author     Vsevolod Stakhov <vsevolod@rspamd.com>  2025-01-27 19:19:28 +0000
committer  Vsevolod Stakhov <vsevolod@rspamd.com>  2025-01-27 19:19:28 +0000
commit     4bc3ef8f30f3f943aa79755aa4448fc8f1cae1de (patch)
tree       9902c322c3618db078178d5756d7371dc5dd53bb
parent     de7717d9dabcc1cce1eb8a3a6e7b88b9f9f3ec3a (diff)
[Feature] Support LLM models consensus (branch: vstakhov-gpt-consensus)
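This patch lets the plugin query several LLM models for one message and
insert GPT_SPAM / GPT_HAM only when the replies agree: the majority side
wins, and the symbol score is taken from that side's most confident
probability (above 0.75 for spam, below 0.25 for ham). A plain string
`model` is normalised to a one-element list, so single-model setups keep
the old behaviour.

A minimal configuration sketch (hedged example: the key and model names
are placeholders; the list form of `model` is what this patch adds):

    # local.d/gpt.conf
    api_key = "<your key>";
    model = ["gpt-4o", "gpt-4o-mini"]; # two or more models enable consensus voting
    autolearn = true;                  # set learn flags when consensus is confident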
-rw-r--r--  src/plugins/lua/gpt.lua  236
1 file changed, 145 insertions(+), 91 deletions(-)
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index e4a77c6dd..4888eaa19 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -319,6 +319,47 @@ local function ollama_conversion(task, input)
   return
 end
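+-- Called after every model reply: returns early until all models have
+-- been checked, then tallies a majority vote between spam and ham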
+local function check_consensus(task, results)
+  for _, result in ipairs(results) do
+    if not result.checked then
+      return
+    end
+  end
+
+  local nspam, nham = 0, 0
+  local max_spam_prob, max_ham_prob = 0, 1 -- the ham side tracks a minimum, so it must start at 1
+
+  for _, result in ipairs(results) do
+    if result.success then
+      if result.probability > 0.5 then
+        nspam = nspam + 1
+        max_spam_prob = math.max(max_spam_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; spam: %s", result.model, result.probability)
+      else
+        nham = nham + 1
+        max_ham_prob = math.min(max_ham_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; ham: %s", result.model, result.probability)
+      end
+    end
+  end
+
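+  -- Scale the winning probability onto the symbol score: for spam,
+  -- 0.75 maps to 0.0 and 1.0 maps to 1.0; the ham branch mirrors this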
+  if nspam > nham and max_spam_prob > 0.75 then
+    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  elseif nham > nspam and max_ham_prob < 0.25 then
+    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  else
+    -- No consensus
+    lua_util.debugm(N, task, "no consensus")
+  end
+end
+
 local function get_meta_llm_content(task)
   local url_content = "Url domains: no urls found"
   if task:has_urls() then
@@ -353,40 +394,36 @@ local function default_llm_check(task)
   local upstream
-  local function on_reply(err, code, body)
+  local results = {}
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        -- the result is already marked checked, so evaluate consensus here
+        -- too, otherwise a bad reply from the last model stalls the vote
+        check_consensus(task, results)
+        return
+      end
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      results[idx].model = model
+
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      check_consensus(task, results)
+    end
   end

   local from_content, url_content = get_meta_llm_content(task)
@@ -424,24 +461,38 @@ local function default_llm_check(task)
     body.response_format = { type = "json_object" }
   end

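+  -- Accept either a single model name or a list of models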
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    headers = {
-      ['Authorization'] = 'Bearer ' .. settings.api_key,
-    },
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
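+  -- Fan out one request per configured model; each reply is recorded
+  -- in results[idx] and feeds check_consensus()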
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    body.model = model
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, idx),
+      headers = {
+        ['Authorization'] = 'Bearer ' .. settings.api_key,
+      },
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
-  rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      -- the request could not even be enqueued; mark it checked so the
+      -- consensus is not blocked forever
+      results[idx].checked = true
+      check_consensus(task, results)
+    end
+  end
 end

 local function ollama_check(task)
@@ -460,45 +511,44 @@ local function ollama_check(task)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
   local upstream
+  local results = {}
+
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus(task, results)
+        return
+      end
-  local function on_reply(err, code, body)
-
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        -- as above: keep the consensus evaluation alive even on a bad reply
+        check_consensus(task, results)
+        return
+      end
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      local reply = settings.reply_conversion(task, body)
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      results[idx].model = model
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      check_consensus(task, results)
+    end
   end

   local from_content, url_content = get_meta_llm_content(task)
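+  -- Same normalisation as in default_llm_check: a plain string means one model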
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   local body = {
     stream = false,
     model = settings.model,
@@ -528,26 +578,30 @@ local function ollama_check(task)
     }
   }

-  -- Conditionally add response_format
-  if settings.include_response_format then
-    body.response_format = { type = "json_object" }
-  end
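+  -- Fan out one request per model, mirroring default_llm_check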
+  for i, model in ipairs(settings.model) do
+    -- initialise the per-model slot, otherwise the reply closure would
+    -- index a nil entry
+    results[i] = {
+      success = false,
+      checked = false
+    }
+    -- Conditionally add response_format
+    if settings.include_response_format then
+      body.response_format = { type = "json_object" }
+    end
-  upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+    body.model = model
+
+    upstream = settings.upstreams:get_upstream_round_robin()
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, i),
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
-  rspamd_http.request(http_params)
+    if not rspamd_http.request(http_params) then
+      results[i].checked = true -- a failed enqueue must not stall consensus
+      check_consensus(task, results)
+    end
+  end
 end

 local function gpt_check(task)