aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@rspamd.com> 2025-02-25 11:02:23 +0000
committer Vsevolod Stakhov <vsevolod@rspamd.com> 2025-02-25 11:02:23 +0000
commit 20ab502c23bb64d70cd56a10fae3345ad6d28d98 (patch)
tree b39fb336218f4fc8fc396c32e43344e2c0cc0bee
parent 7414eea791ecb8dba7ce7f50945caba0e42f3363 (diff)
downloadrspamd-20ab502c23bb64d70cd56a10fae3345ad6d28d98.tar.gz
rspamd-20ab502c23bb64d70cd56a10fae3345ad6d28d98.zip
[Feature] Improve prompt and use plaintext instead of JSON
-rw-r--r--src/plugins/lua/gpt.lua115
1 files changed, 104 insertions, 11 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 0fb6123e1..270d0fdfc 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -50,6 +50,8 @@ gpt {
allow_ham = false;
# Add header with reason (null to disable)
reason_header = "X-GPT-Reason";
+ # Use JSON format for response
+ json = false;
}
]])
return
@@ -89,6 +91,7 @@ local settings = {
symbols_to_trigger = nil, -- Exclude/include logic
allow_passthrough = false,
allow_ham = false,
+ json = false,
}
local function default_condition(task)
@@ -217,7 +220,7 @@ local function maybe_extract_json(str)
return nil
end
-local function default_conversion(task, input)
+local function default_openai_json_conversion(task, input)
local parser = ucl.parser()
local res, err = parser:parse_string(input)
if not res then
@@ -273,14 +276,99 @@ local function default_conversion(task, input)
rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
end
- return spam_score, reply.reason
+ return spam_score, reply.reason, {}
end
rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
return
end
-local function ollama_conversion(task, input)
+--- Convert a plain-text reply wrapped in an OpenAI-style chat completion.
+-- The message content is read line by line: line 1 is the spam probability,
+-- line 2 is the reason and an optional line 3 is a comma-separated list of
+-- additional symbols (the default prompt only requests the first two lines,
+-- so a missing third line yields whatever `str_split` returns for '').
+-- @param task task object, used only for logging here
+-- @param input raw reply body; parsed with UCL and expected to hold `choices`
+-- @return spam score (number), reason (line 2, may be nil), symbols (table);
+--         returns nothing when the reply cannot be parsed
+local function default_openai_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+    rspamd_logger.errx(task, 'no choices in reply')
+    return
+  end
+
+  -- Guard the nested lookup: a choice without a `message` table must be
+  -- reported as a bad reply instead of raising an "index a nil value" error
+  local message = reply.choices[1].message
+  local first_message = type(message) == 'table' and message.content or nil
+
+  -- Only a string can be split into lines below
+  if type(first_message) ~= 'string' then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  -- Strip an optional one-character list marker, drop everything that is not
+  -- a digit or a dot, then remove a trailing dot and anything after '..'
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+                                 :gsub("[^%d%.]", "")
+                                 :gsub("%.$", "")
+                                 :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  -- The format string has two placeholders: the offending first line and the
+  -- complete message, so both must be supplied
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', first_line, first_message)
+  return
+end
+
+--- Convert a plain-text reply wrapped in an Ollama-style chat response.
+-- The message content is read line by line: line 1 is the spam probability,
+-- line 2 is the reason and an optional line 3 is a comma-separated list of
+-- additional symbols.
+-- @param task task object, used only for logging here
+-- @param input raw reply body; parsed with UCL and expected to hold `message`
+-- @return spam score (number), reason (line 2, may be nil), symbols (table);
+--         returns nothing when the reply cannot be parsed
+local function default_ollama_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.message) ~= 'table' then
+    rspamd_logger.errx(task, 'bad message in reply')
+    return
+  end
+
+  local first_message = reply.message.content
+
+  -- Only a string can be split into lines below
+  if type(first_message) ~= 'string' then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  -- Strip an optional one-character list marker, drop everything that is not
+  -- a digit or a dot, then remove a trailing dot and anything after '..'
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+                                 :gsub("[^%d%.]", "")
+                                 :gsub("%.$", "")
+                                 :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  -- The format string has two placeholders: the offending first line and the
+  -- complete message, so both must be supplied
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', first_line, first_message)
+  return
+end
+
+local function default_ollama_json_conversion(task, input)
local parser = ucl.parser()
local res, err = parser:parse_string(input)
if not res then
@@ -456,7 +544,7 @@ local function default_llm_check(task)
return
end
- local reply, reason = settings.reply_conversion(task, body)
+ local reply, reason, _symbols = settings.reply_conversion(task, body)
results[idx].model = model
@@ -657,13 +745,17 @@ local types_map = {
openai = {
check = default_llm_check,
condition = default_condition,
- conversion = default_conversion,
+ conversion = function(is_json)
+ return is_json and default_openai_json_conversion or default_openai_plain_conversion
+ end,
require_passkey = true,
},
ollama = {
check = ollama_check,
condition = default_condition,
- conversion = ollama_conversion,
+ conversion = function(is_json)
+ return is_json and default_ollama_json_conversion or default_ollama_plain_conversion
+ end,
require_passkey = false,
},
}
@@ -673,10 +765,11 @@ if opts then
settings = lua_util.override_defaults(settings, opts)
if not settings.prompt then
- settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
- "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
- "output result as JSON with 'probability' field and " ..
- "add 'reason' field with 1 sentence description why you have made that decision."
+ settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+ "FROM and url domains. Evaluate spam probability (0-1). " ..
+ "Output ONLY 2 lines:\n" ..
+ "1. Numeric score (0.00-1.00)\n" ..
+ "2. One-sentence reason citing strongest red flag"
end
if not settings.symbols_to_except then
@@ -700,7 +793,7 @@ if opts then
if settings.reply_conversion then
settings.reply_conversion = load(settings.reply_conversion)()
else
- settings.reply_conversion = llm_type.conversion
+ settings.reply_conversion = llm_type.conversion(settings.json)
end
if not settings.api_key and llm_type.require_passkey then