Diffstat (limited to 'src/plugins/lua/gpt.lua')
-rw-r--r-- | src/plugins/lua/gpt.lua | 73
1 file changed, 44 insertions, 29 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 5d1cf5e06..5776791a1 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -20,9 +20,9 @@ local E = {}
 
 if confighelp then
   rspamd_config:add_example(nil, 'gpt',
-      "Performs postfiltering using GPT model",
-      [[
-gpt {
+    "Performs postfiltering using GPT model",
+    [[
+  gpt {
   # Supported types: openai, ollama
   type = "openai";
   # Your key to access the API
@@ -53,7 +53,7 @@ gpt {
   reason_header = "X-GPT-Reason";
   # Use JSON format for response
   json = false;
-}
+  }
 ]])
   return
 end
@@ -162,7 +162,7 @@ local function default_condition(task)
         end
       end
       lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
-          sym.weight, required_weight)
+        sym.weight, required_weight)
     else
       return false, 'skip as "' .. s .. '" is found'
     end
@@ -182,7 +182,7 @@ local function default_condition(task)
         end
       end
       lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
-          sym.weight, required_weight)
+        sym.weight, required_weight)
     end
   else
     return false, 'skip as "' .. s .. '" is not found'
   end
@@ -301,7 +301,7 @@ local function default_openai_json_conversion(task, input)
     elseif reply.probability == "low" then
       spam_score = 0.1
     else
-      rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+      rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
     end
   end
@@ -355,14 +355,27 @@ local function default_openai_plain_conversion(task, input)
   local reason = clean_reply_line(lines[2])
   local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
 
+  if type(reply.usage) == 'table' then
+    rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+  end
+
   if spam_score then
     return spam_score, reason, categories
   end
 
-  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
   return
 end
 
+-- Helper function to remove <think>...</think> and trim leading newlines
+local function clean_gpt_response(text)
+  -- Remove <think>...</think> including multiline
+  text = text:gsub("<think>.-</think>", "")
+  -- Trim leading whitespace and newlines
+  text = text:gsub("^%s*\n*", "")
+  return text
+end
+
 local function default_ollama_plain_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
@@ -387,6 +400,10 @@
     rspamd_logger.errx(task, 'no content in the first message')
     return
   end
+
+  -- Clean message
+  first_message = clean_gpt_response(first_message)
+
   local lines = lua_util.str_split(first_message, '\n')
   local first_line = clean_reply_line(lines[1])
   local spam_score = tonumber(first_line)
@@ -397,7 +414,7 @@
     return spam_score, reason, categories
   end
 
-  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1])
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
   return
 end
@@ -449,7 +466,7 @@
     elseif reply.probability == "low" then
       spam_score = 0.1
     else
-      rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+      rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
     end
   end
@@ -477,7 +494,7 @@
     env_digest = digest:hex():sub(1, 4)
   end
   return string.format('%s_%s', env_digest,
-      sel_part:get_mimepart():get_digest():sub(1, 24))
+    sel_part:get_mimepart():get_digest():sub(1, 24))
 end
 
 local function process_categories(task, categories)
@@ -514,9 +531,9 @@
     end
   end
   if result.reason and settings.reason_header then
-    lua_mime.modify_headers(task,
-        { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } })
-  end
+    lua_mime.modify_headers(task,
+      { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } })
+  end
 
   if cache_context then
     lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context)
@@ -540,12 +557,12 @@ local function check_consensus_and_insert_results(task, results, sel_part)
       nspam = nspam + 1
       max_spam_prob = math.max(max_spam_prob, result.probability)
       lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'",
-          result.model, result.probability, result.reason)
+        result.model, result.probability, result.reason)
     else
       nham = nham + 1
       max_ham_prob = math.min(max_ham_prob, result.probability)
       lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'",
-          result.model, result.probability, result.reason)
+        result.model, result.probability, result.reason)
     end
 
     if result.reason then
@@ -559,23 +576,22 @@
   end
 
   if nspam > nham and max_spam_prob > 0.75 then
     insert_results(task, {
-        probability = max_spam_prob,
-        reason = reason.reason,
-        categories = reason.categories,
-      },
-      sel_part)
+      probability = max_spam_prob,
+      reason = reason.reason,
+      categories = reason.categories,
+    },
+      sel_part)
   elseif nham > nspam and max_ham_prob < 0.25 then
     insert_results(task, {
-        probability = max_ham_prob,
-        reason = reason.reason,
-        categories = reason.categories,
-      },
-      sel_part)
+      probability = max_ham_prob,
+      reason = reason.reason,
+      categories = reason.categories,
+    },
+      sel_part)
   else
     -- No consensus
     lua_util.debugm(N, task, "no consensus")
   end
-
 end
 
 local function get_meta_llm_content(task)
@@ -674,7 +690,7 @@ local function openai_check(task, content, sel_part)
       },
       {
         role = 'user',
-        content = 'Subject: ' .. task:get_subject() or '',
+        content = 'Subject: ' .. (task:get_subject() or ''),
       },
       {
         role = 'user',
@@ -726,7 +742,6 @@
 
     if not rspamd_http.request(http_params) then
       results[idx].checked = true
    end
-
   end
 end
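
The headline change is the new clean_gpt_response() helper, wired into default_ollama_plain_conversion(): reasoning models served via ollama (deepseek-r1, for instance) prefix their answer with a <think>...</think> block, so without stripping it the first reply line is not the numeric score and tonumber() fails. A runnable illustration of the helper on a fabricated reply:

-- Copy of the clean_gpt_response() helper from the commit, run on an
-- invented reply. In Lua patterns '.' also matches newlines and '-' is
-- the lazy quantifier, so the first gsub strips a multiline <think> block.
local function clean_gpt_response(text)
  text = text:gsub("<think>.-</think>", "")
  text = text:gsub("^%s*\n*", "")
  return text
end

local raw = "<think>\nchecking the sender domain...\n</think>\n\n0.9\nspoofed sender domain\nphishing"
local cleaned = clean_gpt_response(raw)
print(cleaned:match("^[^\n]+"))  --> 0.9, now parseable by tonumber()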
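
Both plain-text converters parse the same three-line reply contract, which is why the cleanup above has to run before the split. A minimal sketch of that contract, with a plain-Lua str_split standing in for rspamd's lua_util.str_split (clean_reply_line is omitted, and the reply text is invented):

-- Three-line contract: line 1 = spam probability, line 2 = short reason,
-- line 3 = comma-separated categories.
local function str_split(s, sep)
  local out = {}
  for tok in s:gmatch('([^' .. sep .. ']+)') do
    out[#out + 1] = tok
  end
  return out
end

local reply = "0.85\nmessage asks for an urgent wire transfer\nscam,phishing"
local lines = str_split(reply, '\n')
local spam_score = tonumber(lines[1])
local reason = lines[2]
local categories = str_split(lines[3], ',')
assert(spam_score == 0.85 and reason ~= nil and #categories == 2)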
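
The Subject change in openai_check() is a real fix, not reformatting: Lua's .. binds tighter than or, so the old expression parsed as ('Subject: ' .. task:get_subject()) or '' and raised a concatenation error on messages without a subject instead of falling back to the empty string:

-- Demonstration of the precedence bug fixed above.
local subject = nil -- e.g. a message with no Subject header

local ok = pcall(function()
  return 'Subject: ' .. subject or '' -- old form: concatenating nil raises
end)
assert(not ok) -- "attempt to concatenate a nil value"

print('Subject: ' .. (subject or '')) --> "Subject: " (new form falls back)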
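
For orientation, the re-indented (but functionally unchanged) consensus block implements a two-sided vote over the configured models. A hedged sketch of that rule; the per-result spam/ham split is not visible in the hunk, so the 0.5 test below is an assumption, and the variable names follow the plugin's:

-- Spam wins only if more models voted spam AND the strongest spam
-- probability exceeds 0.75; ham symmetrically requires the lowest
-- probability to be under 0.25. Anything else is "no consensus".
local function consensus(results)
  local nspam, nham, max_spam_prob, max_ham_prob = 0, 0, 0, 1
  for _, r in ipairs(results) do
    if r.probability > 0.5 then -- assumed split point
      nspam = nspam + 1
      max_spam_prob = math.max(max_spam_prob, r.probability)
    else
      nham = nham + 1
      -- despite the name, the plugin tracks the *lowest* probability here
      max_ham_prob = math.min(max_ham_prob, r.probability)
    end
  end
  if nspam > nham and max_spam_prob > 0.75 then
    return 'spam', max_spam_prob
  elseif nham > nspam and max_ham_prob < 0.25 then
    return 'ham', max_ham_prob
  end
  return nil -- no consensus among models
end

print(consensus({ { probability = 0.9 }, { probability = 0.8 }, { probability = 0.2 } })) --> spam 0.9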