diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-12-18 16:32:37 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-12-18 16:32:37 +0000 |
commit | b33cbea86d46015ca0f2ca0eb4182220b9bfa2f3 (patch) | |
tree | f1bd013b30dd89870496a1105433d80675ffbc83 | |
parent | d25238fead4b5c07fe05bc2a29df759075a93bdc (diff) | |
download | rspamd-vstakhov-gpt-ollama.tar.gz rspamd-vstakhov-gpt-ollama.zip |
[Feature] GPT: Add ollama support (vstakhov-gpt-ollama)
-rw-r--r-- | src/plugins/lua/gpt.lua | 254 |
1 files changed, 223 insertions, 31 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index 36938c0d1..feccae73f 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -22,7 +22,7 @@ if confighelp then "Performs postfiltering using GPT model", [[ gpt { - # Supported types: openai + # Supported types: openai, ollama type = "openai"; # Your key to access the API api_key = "xxx"; @@ -155,13 +155,17 @@ end local function maybe_extract_json(str) -- Find the first opening brace - local startPos = str:find("{") + local startPos, endPos = str:find('json%s*{') + if not startPos then + startPos, endPos = str:find('{') + end if not startPos then return nil end + startPos = endPos - 1 local openBraces = 0 - local endPos = startPos + endPos = startPos local len = #str -- Iterate through the string to find matching braces @@ -225,6 +229,7 @@ local function default_conversion(task, input) reply = parser:get_object() if type(reply) == 'table' and reply.probability then + lua_util.debugm(N, task, 'extracted probability: %s', reply.probability) local spam_score = tonumber(reply.probability) if not spam_score then @@ -249,7 +254,87 @@ local function default_conversion(task, input) return end -local function openai_gpt_check(task) +local function ollama_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.message) ~= 'table' then + rspamd_logger.errx(task, 'bad message in reply') + return + end + + local first_message = reply.message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + + -- Apply heuristic to extract JSON + first_message = maybe_extract_json(first_message) or first_message + + parser = ucl.parser() + res, err = 
parser:parse_string(first_message) + if not res then + rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err) + return + end + + reply = parser:get_object() + + if type(reply) == 'table' and reply.probability then + lua_util.debugm(N, task, 'extracted probability: %s', reply.probability) + local spam_score = tonumber(reply.probability) + + if not spam_score then + -- Maybe we need GPT to convert GPT reply here? + if reply.probability == "high" then + spam_score = 0.9 + elseif reply.probability == "low" then + spam_score = 0.1 + else + rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability) + end + end + + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + return spam_score + end + + rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + return +end + +local function get_meta_llm_content(task) + local url_content = "Url domains: no urls found" + if task:has_urls() then + local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 } + url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u) + return u:get_tld() or '' + end, urls or {})), ', ') + end + + local from_or_empty = ((task:get_from('mime') or E)[1] or E) + local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr) + lua_util.debugm(N, task, "gpt urls: %s", url_content) + lua_util.debugm(N, task, "gpt from: %s", from_content) + + return url_content, from_content +end + +local function default_llm_check(task) local ret, content = settings.condition(task) if not ret then @@ -302,18 +387,7 @@ local function openai_gpt_check(task) end - local url_content = "Url domains: no urls found" - if task:has_urls() then - local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 } - url_content = "Url domains: " .. 
table.concat(fun.totable(fun.map(function(u) - return u:get_tld() or '' - end, urls or {})), ', ') - end - - local from_or_empty = ((task:get_from('mime') or E)[1] or E) - local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr) - lua_util.debugm(N, task, "gpt urls: %s", url_content) - lua_util.debugm(N, task, "gpt from: %s", from_content) + local from_content, url_content = get_meta_llm_content(task) local body = { model = settings.model, @@ -364,43 +438,161 @@ local function openai_gpt_check(task) rspamd_http.request(http_params) end +local function ollama_check(task) + local ret, content = settings.condition(task) + + if not ret then + rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) + return + end + + if not content then + lua_util.debugm(N, task, "no content to send to gpt classification") + return + end + + lua_util.debugm(N, task, "sending content to gpt: %s", content) + + local upstream + + local function on_reply(err, code, body) + + if err then + rspamd_logger.errx(task, 'request failed: %s', err) + upstream:fail() + return + end + + upstream:ok() + lua_util.debugm(N, task, "got reply: %s", body) + if code ~= 200 then + rspamd_logger.errx(task, 'bad reply: %s', body) + return + end + + local reply = settings.reply_conversion(task, body) + if not reply then + return + end + + if reply > 0.75 then + task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply)) + if settings.autolearn then + task:set_flag("learn_spam") + end + elseif reply < 0.25 then + task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply)) + if settings.autolearn then + task:set_flag("learn_ham") + end + else + lua_util.debugm(N, task, "uncertain result: %s", reply) + end + + end + + local from_content, url_content = get_meta_llm_content(task) + + local body = { + stream = false, + model = settings.model, + max_tokens = settings.max_tokens, + temperature = settings.temperature, + response_format = { 
type = "json_object" }, + messages = { + { + role = 'system', + content = settings.prompt + }, + { + role = 'user', + content = 'Subject: ' .. task:get_subject() or '', + }, + { + role = 'user', + content = from_content, + }, + { + role = 'user', + content = url_content, + }, + { + role = 'user', + content = content + } + } + } + + upstream = settings.upstreams:get_upstream_round_robin() + local http_params = { + url = settings.url, + mime_type = 'application/json', + timeout = settings.timeout, + log_obj = task, + callback = on_reply, + keepalive = true, + body = ucl.to_format(body, 'json-compact', true), + task = task, + upstream = upstream, + use_gzip = true, + } + + rspamd_http.request(http_params) +end + local function gpt_check(task) return settings.specific_check(task) end +local types_map = { + openai = { + check = default_llm_check, + condition = default_condition, + conversion = default_conversion, + require_passkey = true, + }, + ollama = { + check = ollama_check, + condition = default_condition, + conversion = ollama_conversion, + require_passkey = false, + }, +} + local opts = rspamd_config:get_all_opt('gpt') if opts then settings = lua_util.override_defaults(settings, opts) - if not settings.api_key then - rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module') - lua_util.disable_module(N, "config") + if not settings.prompt then + settings.prompt = "You will be provided with the email message, subject, from and url domains, " .. + "and your task is to evaluate the probability to be spam as number from 0 to 1, " .. + "output result as JSON with 'probability' field." 
+ end + local llm_type = types_map[settings.type] + if not llm_type then + rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) + lua_util.disable_module(N, "config") return end + settings.specific_check = llm_type.check + if settings.condition then settings.condition = load(settings.condition)() else - settings.condition = default_condition + settings.condition = llm_type.condition end if settings.reply_conversion then settings.reply_conversion = load(settings.reply_conversion)() else - settings.reply_conversion = default_conversion - end - - if not settings.prompt then - settings.prompt = "You will be provided with the email message, subject, from and url domains, " .. - "and your task is to evaluate the probability to be spam as number from 0 to 1, " .. - "output result as JSON with 'probability' field." + settings.reply_conversion = llm_type.conversion end - if settings.type == 'openai' then - settings.specific_check = openai_gpt_check - else - rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) + if not settings.api_key and llm_type.require_passkey then + rspamd_logger.warnx(rspamd_config, 'no api_key is specified for LLM type %s, disabling module', settings.type) lua_util.disable_module(N, "config") + return end |