aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/gpt.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/gpt.lua')
-rw-r--r--src/plugins/lua/gpt.lua73
1 files changed, 44 insertions, 29 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 5d1cf5e06..5776791a1 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -20,9 +20,9 @@ local E = {}
if confighelp then
rspamd_config:add_example(nil, 'gpt',
- "Performs postfiltering using GPT model",
- [[
-gpt {
+ "Performs postfiltering using GPT model",
+ [[
+ gpt {
# Supported types: openai, ollama
type = "openai";
# Your key to access the API
@@ -53,7 +53,7 @@ gpt {
reason_header = "X-GPT-Reason";
# Use JSON format for response
json = false;
-}
+ }
]])
return
end
@@ -162,7 +162,7 @@ local function default_condition(task)
end
end
lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
- sym.weight, required_weight)
+ sym.weight, required_weight)
else
return false, 'skip as "' .. s .. '" is found'
end
@@ -182,7 +182,7 @@ local function default_condition(task)
end
end
lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
- sym.weight, required_weight)
+ sym.weight, required_weight)
end
else
return false, 'skip as "' .. s .. '" is not found'
@@ -301,7 +301,7 @@ local function default_openai_json_conversion(task, input)
elseif reply.probability == "low" then
spam_score = 0.1
else
- rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+ rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
end
end
@@ -355,14 +355,27 @@ local function default_openai_plain_conversion(task, input)
local reason = clean_reply_line(lines[2])
local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
+ if type(reply.usage) == 'table' then
+ rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ end
+
if spam_score then
return spam_score, reason, categories
end
- rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+ rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
return
end
+-- Helper function to remove <think>...</think> and trim leading newlines
+local function clean_gpt_response(text)
+ -- Remove <think>...</think> including multiline
+ text = text:gsub("<think>.-</think>", "")
+ -- Trim leading whitespace and newlines
+ text = text:gsub("^%s*\n*", "")
+ return text
+end
+
local function default_ollama_plain_conversion(task, input)
local parser = ucl.parser()
local res, err = parser:parse_string(input)
@@ -387,6 +400,10 @@ local function default_ollama_plain_conversion(task, input)
rspamd_logger.errx(task, 'no content in the first message')
return
end
+
+ -- Clean message
+ first_message = clean_gpt_response(first_message)
+
local lines = lua_util.str_split(first_message, '\n')
local first_line = clean_reply_line(lines[1])
local spam_score = tonumber(first_line)
@@ -397,7 +414,7 @@ local function default_ollama_plain_conversion(task, input)
return spam_score, reason, categories
end
- rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1])
+ rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
return
end
@@ -449,7 +466,7 @@ local function default_ollama_json_conversion(task, input)
elseif reply.probability == "low" then
spam_score = 0.1
else
- rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+ rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
end
end
@@ -477,7 +494,7 @@ local function redis_cache_key(sel_part)
env_digest = digest:hex():sub(1, 4)
end
return string.format('%s_%s', env_digest,
- sel_part:get_mimepart():get_digest():sub(1, 24))
+ sel_part:get_mimepart():get_digest():sub(1, 24))
end
local function process_categories(task, categories)
@@ -514,9 +531,9 @@ local function insert_results(task, result, sel_part)
end
end
if result.reason and settings.reason_header then
- lua_mime.modify_headers(task,
- { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } })
- end
+ lua_mime.modify_headers(task,
+ { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } })
+ end
if cache_context then
lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context)
@@ -540,12 +557,12 @@ local function check_consensus_and_insert_results(task, results, sel_part)
nspam = nspam + 1
max_spam_prob = math.max(max_spam_prob, result.probability)
lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'",
- result.model, result.probability, result.reason)
+ result.model, result.probability, result.reason)
else
nham = nham + 1
max_ham_prob = math.min(max_ham_prob, result.probability)
lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'",
- result.model, result.probability, result.reason)
+ result.model, result.probability, result.reason)
end
if result.reason then
@@ -559,23 +576,22 @@ local function check_consensus_and_insert_results(task, results, sel_part)
if nspam > nham and max_spam_prob > 0.75 then
insert_results(task, {
- probability = max_spam_prob,
- reason = reason.reason,
- categories = reason.categories,
- },
- sel_part)
+ probability = max_spam_prob,
+ reason = reason.reason,
+ categories = reason.categories,
+ },
+ sel_part)
elseif nham > nspam and max_ham_prob < 0.25 then
insert_results(task, {
- probability = max_ham_prob,
- reason = reason.reason,
- categories = reason.categories,
- },
- sel_part)
+ probability = max_ham_prob,
+ reason = reason.reason,
+ categories = reason.categories,
+ },
+ sel_part)
else
-- No consensus
lua_util.debugm(N, task, "no consensus")
end
-
end
local function get_meta_llm_content(task)
@@ -674,7 +690,7 @@ local function openai_check(task, content, sel_part)
},
{
role = 'user',
- content = 'Subject: ' .. task:get_subject() or '',
+ content = 'Subject: ' .. (task:get_subject() or ''),
},
{
role = 'user',
@@ -726,7 +742,6 @@ local function openai_check(task, content, sel_part)
if not rspamd_http.request(http_params) then
results[idx].checked = true
end
-
end
end