author     Vsevolod Stakhov <vsevolod@rspamd.com>    2025-01-27 19:19:28 +0000
committer  Vsevolod Stakhov <vsevolod@rspamd.com>    2025-01-27 19:19:28 +0000
commit     4bc3ef8f30f3f943aa79755aa4448fc8f1cae1de (patch)
tree       9902c322c3618db078178d5756d7371dc5dd53bb
parent     de7717d9dabcc1cce1eb8a3a6e7b88b9f9f3ec3a (diff)
[Feature] Support LLM models consensus (vstakhov-gpt-consensus)
-rw-r--r--  src/plugins/lua/gpt.lua  236
1 file changed, 145 insertions(+), 91 deletions(-)
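With this change the `model` option may name either a single LLM or a list of them: a plain string is normalised to a one-element table, one HTTP request per model is fanned out through the same round-robin upstream, and check_consensus() is re-evaluated as each reply (or failure) comes back, producing a verdict only once every model has reported. A hypothetical local.d/gpt.conf fragment for the multi-model case (option names mirror the settings.* keys referenced in the diff below; the model names are placeholders and unrelated options are omitted):

    gpt {
      type = "openai";                    # assumed to dispatch to default_llm_check()
      api_key = "xxx";                    # placeholder
      # a list instead of a single string enables the consensus logic
      model = ["gpt-4o-mini", "gpt-4o"];
      autolearn = true;                   # lets consensus set learn_spam/learn_ham
      timeout = 10s;
    }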
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index e4a77c6dd..4888eaa19 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -319,6 +319,47 @@ local function ollama_conversion(task, input)
   return
 end
 
+local function check_consensus(task, results)
+  for _, result in ipairs(results) do
+    if not result.checked then
+      return
+    end
+  end
+
+  local nspam, nham = 0, 0
+  local max_spam_prob, max_ham_prob = 0, 0
+
+  for _, result in ipairs(results) do
+    if result.success then
+      if result.probability > 0.5 then
+        nspam = nspam + 1
+        max_spam_prob = math.max(max_spam_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; spam: %s", result.model, result.probability)
+      else
+        nham = nham + 1
+        max_ham_prob = math.min(max_ham_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; ham: %s", result.model, result.probability)
+      end
+    end
+  end
+
+  if nspam > nham and max_spam_prob > 0.75 then
+    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  elseif nham > nspam and max_ham_prob < 0.25 then
+    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  else
+    -- No consensus
+    lua_util.debugm(N, task, "no consensus")
+  end
+
+end
+
 local function get_meta_llm_content(task)
   local url_content = "Url domains: no urls found"
   if task:has_urls() then
@@ -353,40 +394,36 @@ local function default_llm_check(task)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
 
-  local function on_reply(err, code, body)
+  local results = {}
 
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
 
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      results[idx].model = model
+
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      check_consensus(task, results)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
@@ -424,24 +461,38 @@
     body.response_format = { type = "json_object" }
   end
 
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    headers = {
-      ['Authorization'] = 'Bearer ' .. settings.api_key,
-    },
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    body.model = model
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, idx),
+      headers = {
+        ['Authorization'] = 'Bearer ' .. settings.api_key,
+      },
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
 
-  rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      results[idx].checked = true
+    end
+
+  end
 end
 
 local function ollama_check(task)
@@ -460,45 +511,44 @@
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
+  local results = {}
+
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
 
-  local function on_reply(err, code, body)
-
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      results[idx].model = model
 
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      check_consensus(task, results)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
 
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   local body = {
     stream = false,
     model = settings.model,
@@ -528,26 +578,30 @@
     }
   }
 
-  -- Conditionally add response_format
-  if settings.include_response_format then
-    body.response_format = { type = "json_object" }
-  end
+  for i, model in ipairs(settings.model) do
+    -- Conditionally add response_format
+    if settings.include_response_format then
+      body.response_format = { type = "json_object" }
+    end
 
-  upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+    body.model = model
+
+    upstream = settings.upstreams:get_upstream_round_robin()
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, i),
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
 
-  rspamd_http.request(http_params)
+    rspamd_http.request(http_params)
+  end
 end
 
 local function gpt_check(task)
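Two details of check_consensus() are worth flagging for anyone adapting this patch. First, max_ham_prob starts at 0 while the ham branch updates it with math.min(), so it can never move off 0; once ham votes outnumber spam votes the `max_ham_prob < 0.25` test is vacuously true and the ham score is always (0.25 - 0) * 4 = 1.0. The ham-side tracker presumably needs to start at 1.0. Second, the ollama_check() loop never seeds results[i] before issuing a request, so its callback's `results[idx].checked = true` would index a nil value; default_llm_check() shows the intended per-model initialisation (results[idx] = { success = false, checked = false }). Below is a minimal standalone sketch of the voting rule with the first issue corrected; consensus_verdict() is a hypothetical helper for illustration, not part of the patch:

    -- Sketch of the consensus rule, with the ham-side tracker started at 1.0
    -- so that math.min() can actually lower it. Entries in 'results' mimic
    -- what gen_reply_closure() records per model.
    local function consensus_verdict(results)
      local nspam, nham = 0, 0
      local max_spam_prob, min_ham_prob = 0, 1.0

      for _, r in ipairs(results) do
        if not r.checked then
          return nil -- some model is still pending, no verdict yet
        end
        if r.success then
          if r.probability > 0.5 then
            nspam = nspam + 1
            max_spam_prob = math.max(max_spam_prob, r.probability)
          else
            nham = nham + 1
            min_ham_prob = math.min(min_ham_prob, r.probability)
          end
        end
      end

      if nspam > nham and max_spam_prob > 0.75 then
        return 'GPT_SPAM', (max_spam_prob - 0.75) * 4 -- score scaled into (0, 1]
      elseif nham > nspam and min_ham_prob < 0.25 then
        return 'GPT_HAM', (0.25 - min_ham_prob) * 4 -- score scaled into (0, 1]
      end
      return nil -- no consensus
    end

    -- Two models vote spam (0.9 and 0.8), one votes ham (0.3): nspam = 2,
    -- nham = 1, max_spam_prob = 0.9, so the verdict is GPT_SPAM with score
    -- (0.9 - 0.75) * 4 = 0.6.
    print(consensus_verdict({
      { checked = true, success = true, probability = 0.9 },
      { checked = true, success = true, probability = 0.8 },
      { checked = true, success = true, probability = 0.3 },
    }))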