aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-12-18 16:32:37 +0000
committerVsevolod Stakhov <vsevolod@rspamd.com>2024-12-18 16:32:37 +0000
commitb33cbea86d46015ca0f2ca0eb4182220b9bfa2f3 (patch)
treef1bd013b30dd89870496a1105433d80675ffbc83 /src
parentd25238fead4b5c07fe05bc2a29df759075a93bdc (diff)
downloadrspamd-b33cbea86d46015ca0f2ca0eb4182220b9bfa2f3.tar.gz
rspamd-b33cbea86d46015ca0f2ca0eb4182220b9bfa2f3.zip
[Feature] GPT: Add ollama supportvstakhov-gpt-ollama
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/gpt.lua254
1 files changed, 223 insertions, 31 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 36938c0d1..feccae73f 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -22,7 +22,7 @@ if confighelp then
"Performs postfiltering using GPT model",
[[
gpt {
- # Supported types: openai
+ # Supported types: openai, ollama
type = "openai";
# Your key to access the API
api_key = "xxx";
@@ -155,13 +155,17 @@ end
local function maybe_extract_json(str)
-- Find the first opening brace
- local startPos = str:find("{")
+ local startPos, endPos = str:find('json%s*{')
+ if not startPos then
+ startPos, endPos = str:find('{')
+ end
if not startPos then
return nil
end
+ startPos = endPos - 1
local openBraces = 0
- local endPos = startPos
+ endPos = startPos
local len = #str
-- Iterate through the string to find matching braces
@@ -225,6 +229,7 @@ local function default_conversion(task, input)
reply = parser:get_object()
if type(reply) == 'table' and reply.probability then
+ lua_util.debugm(N, task, 'extracted probability: %s', reply.probability)
local spam_score = tonumber(reply.probability)
if not spam_score then
@@ -249,7 +254,87 @@ local function default_conversion(task, input)
return
end
-local function openai_gpt_check(task)
+local function ollama_conversion(task, input)
+ local parser = ucl.parser()
+ local res, err = parser:parse_string(input)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+ return
+ end
+ local reply = parser:get_object()
+ if not reply then
+ rspamd_logger.errx(task, 'cannot get object from reply')
+ return
+ end
+
+ if type(reply.message) ~= 'table' then
+ rspamd_logger.errx(task, 'bad message in reply')
+ return
+ end
+
+ local first_message = reply.message.content
+
+ if not first_message then
+ rspamd_logger.errx(task, 'no content in the first message')
+ return
+ end
+
+ -- Apply heuristic to extract JSON
+ first_message = maybe_extract_json(first_message) or first_message
+
+ parser = ucl.parser()
+ res, err = parser:parse_string(first_message)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err)
+ return
+ end
+
+ reply = parser:get_object()
+
+ if type(reply) == 'table' and reply.probability then
+ lua_util.debugm(N, task, 'extracted probability: %s', reply.probability)
+ local spam_score = tonumber(reply.probability)
+
+ if not spam_score then
+ -- Maybe we need GPT to convert GPT reply here?
+ if reply.probability == "high" then
+ spam_score = 0.9
+ elseif reply.probability == "low" then
+ spam_score = 0.1
+ else
+ rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+ end
+ end
+
+ if type(reply.usage) == 'table' then
+ rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ end
+
+ return spam_score
+ end
+
+ rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ return
+end
+
+local function get_meta_llm_content(task)
+ local url_content = "Url domains: no urls found"
+ if task:has_urls() then
+ local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+ url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
+ return u:get_tld() or ''
+ end, urls or {})), ', ')
+ end
+
+ local from_or_empty = ((task:get_from('mime') or E)[1] or E)
+ local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
+ lua_util.debugm(N, task, "gpt urls: %s", url_content)
+ lua_util.debugm(N, task, "gpt from: %s", from_content)
+
+ return url_content, from_content
+end
+
+local function default_llm_check(task)
local ret, content = settings.condition(task)
if not ret then
@@ -302,18 +387,7 @@ local function openai_gpt_check(task)
end
- local url_content = "Url domains: no urls found"
- if task:has_urls() then
- local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
- url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
- return u:get_tld() or ''
- end, urls or {})), ', ')
- end
-
- local from_or_empty = ((task:get_from('mime') or E)[1] or E)
- local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
- lua_util.debugm(N, task, "gpt urls: %s", url_content)
- lua_util.debugm(N, task, "gpt from: %s", from_content)
+ local from_content, url_content = get_meta_llm_content(task)
local body = {
model = settings.model,
@@ -364,43 +438,161 @@ local function openai_gpt_check(task)
rspamd_http.request(http_params)
end
+local function ollama_check(task)
+ local ret, content = settings.condition(task)
+
+ if not ret then
+ rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+ return
+ end
+
+ if not content then
+ lua_util.debugm(N, task, "no content to send to gpt classification")
+ return
+ end
+
+ lua_util.debugm(N, task, "sending content to gpt: %s", content)
+
+ local upstream
+
+ local function on_reply(err, code, body)
+
+ if err then
+ rspamd_logger.errx(task, 'request failed: %s', err)
+ upstream:fail()
+ return
+ end
+
+ upstream:ok()
+ lua_util.debugm(N, task, "got reply: %s", body)
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'bad reply: %s', body)
+ return
+ end
+
+ local reply = settings.reply_conversion(task, body)
+ if not reply then
+ return
+ end
+
+ if reply > 0.75 then
+ task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
+ if settings.autolearn then
+ task:set_flag("learn_spam")
+ end
+ elseif reply < 0.25 then
+ task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
+ if settings.autolearn then
+ task:set_flag("learn_ham")
+ end
+ else
+ lua_util.debugm(N, task, "uncertain result: %s", reply)
+ end
+
+ end
+
+ local from_content, url_content = get_meta_llm_content(task)
+
+ local body = {
+ stream = false,
+ model = settings.model,
+ max_tokens = settings.max_tokens,
+ temperature = settings.temperature,
+ response_format = { type = "json_object" },
+ messages = {
+ {
+ role = 'system',
+ content = settings.prompt
+ },
+ {
+ role = 'user',
+ content = 'Subject: ' .. task:get_subject() or '',
+ },
+ {
+ role = 'user',
+ content = from_content,
+ },
+ {
+ role = 'user',
+ content = url_content,
+ },
+ {
+ role = 'user',
+ content = content
+ }
+ }
+ }
+
+ upstream = settings.upstreams:get_upstream_round_robin()
+ local http_params = {
+ url = settings.url,
+ mime_type = 'application/json',
+ timeout = settings.timeout,
+ log_obj = task,
+ callback = on_reply,
+ keepalive = true,
+ body = ucl.to_format(body, 'json-compact', true),
+ task = task,
+ upstream = upstream,
+ use_gzip = true,
+ }
+
+ rspamd_http.request(http_params)
+end
+
local function gpt_check(task)
return settings.specific_check(task)
end
+local types_map = {
+ openai = {
+ check = default_llm_check,
+ condition = default_condition,
+ conversion = default_conversion,
+ require_passkey = true,
+ },
+ ollama = {
+ check = ollama_check,
+ condition = default_condition,
+ conversion = ollama_conversion,
+ require_passkey = false,
+ },
+}
+
local opts = rspamd_config:get_all_opt('gpt')
if opts then
settings = lua_util.override_defaults(settings, opts)
- if not settings.api_key then
- rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module')
- lua_util.disable_module(N, "config")
+ if not settings.prompt then
+ settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
+ "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
+ "output result as JSON with 'probability' field."
+ end
+ local llm_type = types_map[settings.type]
+ if not llm_type then
+ rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+ lua_util.disable_module(N, "config")
return
end
+ settings.specific_check = llm_type.check
+
if settings.condition then
settings.condition = load(settings.condition)()
else
- settings.condition = default_condition
+ settings.condition = llm_type.condition
end
if settings.reply_conversion then
settings.reply_conversion = load(settings.reply_conversion)()
else
- settings.reply_conversion = default_conversion
- end
-
- if not settings.prompt then
- settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
- "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
- "output result as JSON with 'probability' field."
+ settings.reply_conversion = llm_type.conversion
end
- if settings.type == 'openai' then
- settings.specific_check = openai_gpt_check
- else
- rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+ if not settings.api_key and llm_type.require_passkey then
+ rspamd_logger.warnx(rspamd_config, 'no api_key is specified for LLM type %s, disabling module', settings.type)
lua_util.disable_module(N, "config")
+
return
end