author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2025-02-25 11:02:23 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2025-02-25 11:02:23 +0000 |
commit | 20ab502c23bb64d70cd56a10fae3345ad6d28d98 (patch) | |
tree | b39fb336218f4fc8fc396c32e43344e2c0cc0bee | |
parent | 7414eea791ecb8dba7ce7f50945caba0e42f3363 (diff) | |
download | rspamd-20ab502c23bb64d70cd56a10fae3345ad6d28d98.tar.gz rspamd-20ab502c23bb64d70cd56a10fae3345ad6d28d98.zip |
[Feature] Improve prompt and use plaintext instead of JSON
-rw-r--r-- | src/plugins/lua/gpt.lua | 115 |
1 files changed, 104 insertions, 11 deletions
```diff
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 0fb6123e1..270d0fdfc 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -50,6 +50,8 @@ gpt {
   allow_ham = false;
   # Add header with reason (null to disable)
   reason_header = "X-GPT-Reason";
+  # Use JSON format for response
+  json = false;
 }
   ]])
   return
@@ -89,6 +91,7 @@ local settings = {
   symbols_to_trigger = nil, -- Exclude/include logic
   allow_passthrough = false,
   allow_ham = false,
+  json = false,
 }
 
 local function default_condition(task)
@@ -217,7 +220,7 @@ local function maybe_extract_json(str)
   return nil
 end
 
-local function default_conversion(task, input)
+local function default_openai_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -273,14 +276,99 @@
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score, reply.reason
+    return spam_score, reply.reason, {}
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
   return
 end
 
-local function ollama_conversion(task, input)
+-- Assume that we have 3 lines: probability, reason, additional symbols
+local function default_openai_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+    rspamd_logger.errx(task, 'no choices in reply')
+    return
+  end
+
+  local first_message = reply.choices[1].message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+      :gsub("[^%d%.]", "")
+      :gsub("%.$", "")
+      :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  return
+end
+
+local function default_ollama_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.message) ~= 'table' then
+    rspamd_logger.errx(task, 'bad message in reply')
+    return
+  end
+
+  local first_message = reply.message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = lines[1] or ''
+  local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
+      :gsub("[^%d%.]", "")
+      :gsub("%.$", "")
+      :gsub("%.%..*", "")
+  local spam_score = tonumber(cleaned_line)
+  local reason = lines[2]
+  local symbols = lua_util.str_split(lines[3] or '', ',')
+
+  if spam_score then
+    return spam_score, reason, symbols
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  return
+end
+
+local function default_ollama_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -456,7 +544,7 @@ local function default_llm_check(task)
           return
         end
 
-        local reply, reason = settings.reply_conversion(task, body)
+        local reply, reason, _symbols = settings.reply_conversion(task, body)
 
         results[idx].model = model
 
@@ -657,13 +745,17 @@ local types_map = {
   openai = {
     check = default_llm_check,
     condition = default_condition,
-    conversion = default_conversion,
+    conversion = function(is_json)
+      return is_json and default_openai_json_conversion or default_openai_plain_conversion
+    end,
     require_passkey = true,
  },
   ollama = {
     check = ollama_check,
     condition = default_condition,
-    conversion = ollama_conversion,
+    conversion = function(is_json)
+      return is_json and default_ollama_json_conversion or default_ollama_plain_conversion
+    end,
     require_passkey = false,
   },
 }
@@ -673,10 +765,11 @@ if opts then
   settings = lua_util.override_defaults(settings, opts)
 
   if not settings.prompt then
-    settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
-        "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
-        "output result as JSON with 'probability' field and " ..
-        "add 'reason' field with 1 sentence description why you have made that decision."
+    settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+        "FROM and url domains. Evaluate spam probability (0-1). " ..
+        "Output ONLY 2 lines:\n" ..
+        "1. Numeric score (0.00-1.00)\n" ..
+        "2. One-sentence reason citing strongest red flag"
   end
 
   if not settings.symbols_to_except then
@@ -700,7 +793,7 @@
   if settings.reply_conversion then
     settings.reply_conversion = load(settings.reply_conversion)()
   else
-    settings.reply_conversion = llm_type.conversion
+    settings.reply_conversion = llm_type.conversion(settings.json)
   end
 
   if not settings.api_key and llm_type.require_passkey then
```
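With `json = false` (the new default), the plugin no longer asks the model for a JSON object: the prompt requests a bare numeric score on the first line and a one-sentence reason on the second, and the plain-text converters additionally accept an optional comma-separated symbol list on a third line. The sketch below is not part of the commit; it is only an illustration of how `default_openai_plain_conversion` treats such a reply, using an invented sample reply and a plain-Lua `split` helper standing in for rspamd's `lua_util.str_split` (the frontier pattern `%f` needs LuaJIT or Lua 5.2+).

```lua
-- Illustration only (not from the commit): parsing the plain-text reply.
-- `split` is a simple stand-in for lua_util.str_split.
local function split(str, sep)
  local out = {}
  for token in string.gmatch(str, '([^' .. sep .. ']+)') do
    out[#out + 1] = token
  end
  return out
end

-- Invented sample reply in the format the new prompt asks for:
-- line 1 = probability, line 2 = reason, line 3 = optional symbols.
local content = "0.85\nUrgent payment request with a mismatched FROM domain\nPHISHING, SUSPICIOUS_URL"

local lines = split(content, '\n')
local first_line = lines[1] or ''
-- Same clean-up chain as in the patch: drop a stray leading character before
-- the number, keep only digits and dots, trim trailing or repeated dots.
local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
    :gsub("[^%d%.]", "")
    :gsub("%.$", "")
    :gsub("%.%..*", "")

local spam_score = tonumber(cleaned_line)  -- 0.85
local reason = lines[2]                    -- the one-sentence explanation
local symbols = split(lines[3] or '', ',') -- { "PHISHING", " SUSPICIOUS_URL" }

print(spam_score, reason, #symbols)
```

The JSON path is still available: setting `json = true` in the `gpt` module configuration makes `llm_type.conversion(settings.json)` select the old `default_openai_json_conversion` / `default_ollama_json_conversion` handlers instead of the plain-text ones.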