diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/gpt.lua | 140 |
1 files changed, 108 insertions, 32 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index f605b702a..625450fd9 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -77,6 +77,32 @@ local default_symbols_to_except = { BOUNCE = -1, } +local default_extra_symbols = { + GPT_MARKETING = { + score = 0.0, + description = 'GPT model detected marketing content', + category = 'marketing', + }, + GPT_PHISHING = { + score = 3.0, + description = 'GPT model detected phishing content', + category = 'phishing', + }, + GPT_SCAM = { + score = 3.0, + description = 'GPT model detected scam content', + category = 'scam', + }, + GPT_MALWARE = { + score = 3.0, + description = 'GPT model detected malware content', + category = 'malware', + }, +} + +-- Should be filled from extra symbols +local categories_map = {} + local settings = { type = 'openai', api_key = nil, @@ -95,6 +121,7 @@ local settings = { allow_ham = false, json = false, redis_cache_expire = 3600 * 24, + extra_symbols = nil, } local redis_params @@ -287,6 +314,14 @@ local function default_openai_json_conversion(task, input) return end +-- Remove what we don't need +local function clean_reply_line(line) + if not line then + return '' + end + return lua_util.str_trim(line):gsub("^%d%.%s+", "") +end + -- Assume that we have 3 lines: probability, reason, additional symbols local function default_openai_plain_conversion(task, input) local parser = ucl.parser() @@ -313,17 +348,13 @@ local function default_openai_plain_conversion(task, input) return end local lines = lua_util.str_split(first_message, '\n') - local first_line = lines[1] or '' - local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "") - :gsub("[^%d%.]", "") - :gsub("%.$", "") - :gsub("%.%..*", "") - local spam_score = tonumber(cleaned_line) - local reason = lines[2] - local symbols = lua_util.str_split(lines[3] or '', ',') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') if spam_score then - return spam_score, reason, symbols + return spam_score, reason, categories end rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1]) @@ -355,20 +386,16 @@ local function default_ollama_plain_conversion(task, input) return end local lines = lua_util.str_split(first_message, '\n') - local first_line = lines[1] or '' - local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "") - :gsub("[^%d%.]", "") - :gsub("%.$", "") - :gsub("%.%..*", "") - local spam_score = tonumber(cleaned_line) - local reason = lines[2] - local symbols = lua_util.str_split(lines[3] or '', ',') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') if spam_score then - return spam_score, reason, symbols + return spam_score, reason, categories end - rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1]) + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1]) return end @@ -468,6 +495,15 @@ local function maybe_save_cache(task, result, sel_part) 'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json }) end +local function process_categories(task, categories) + for _, category in ipairs(categories) do + local sym = categories_map[category:lower()] + if sym then + task:insert_result(sym.name, 1.0) + end + end +end + local function insert_results(task, result, sel_part) if not result.probability then rspamd_logger.errx(task, 'no probability in result') @@ -478,6 +514,10 @@ local function insert_results(task, result, sel_part) if settings.autolearn then task:set_flag("learn_spam") end + + if result.categories then + process_categories(task, result.categories) + end else if result.reason and settings.reason_header then lua_mime.modify_headers(task, @@ -487,6 +527,9 @@ local function insert_results(task, result, sel_part) if settings.autolearn then task:set_flag("learn_ham") end + if result.categories then + process_categories(task, result.categories) + end end maybe_save_cache(task, result, sel_part) end @@ -517,7 +560,7 @@ local function check_consensus_and_insert_results(task, results, sel_part) end if result.reason then - table.insert(reasons, result.reason) + table.insert(reasons, result) end end end @@ -528,13 +571,15 @@ local function check_consensus_and_insert_results(task, results, sel_part) if nspam > nham and max_spam_prob > 0.75 then insert_results(task, { probability = max_spam_prob, - reason = reason, + reason = reason.reason, + categories = reason.categories, }, sel_part) elseif nham > nspam and max_ham_prob < 0.25 then insert_results(task, { probability = max_ham_prob, - reason = reason, + reason = reason.reason, + categories = reason.categories, }, sel_part) else @@ -620,7 +665,7 @@ local function openai_check(task, content, sel_part) return end - local reply, reason, _symbols = settings.reply_conversion(task, body) + local reply, reason, categories = settings.reply_conversion(task, body) results[idx].model = model @@ -628,6 +673,10 @@ local function openai_check(task, content, sel_part) results[idx].success = true results[idx].probability = reply results[idx].reason = reason + + if categories then + results[idx].categories = categories + end end check_consensus_and_insert_results(task, results, sel_part) @@ -853,18 +902,14 @@ if opts then }) end - if not settings.prompt then - settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. - "FROM and url domains. Evaluate spam probability (0-1). " .. - "Output ONLY 2 lines:\n" .. - "1. Numeric score (0.00-1.00)\n" .. - "2. One-sentence reason citing strongest red flag" - end - if not settings.symbols_to_except then settings.symbols_to_except = default_symbols_to_except end + if not settings.extra_symbols then + settings.extra_symbols = default_extra_symbols + end + local llm_type = types_map[settings.type] if not llm_type then rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) @@ -906,7 +951,7 @@ if opts then name = 'GPT_SPAM', type = 'virtual', parent = id, - score = 5.0, + score = 3.0, }) rspamd_config:register_symbol({ name = 'GPT_HAM', @@ -914,4 +959,35 @@ if opts then parent = id, score = -2.0, }) + + if settings.extra_symbols then + for sym, data in pairs(settings.extra_symbols) do + rspamd_config:register_symbol({ + name = sym, + type = 'virtual', + parent = id, + score = data.score, + description = data.description, + }) + data.name = sym + categories_map[data.category] = data + end + end + + if not settings.prompt then + if settings.extra_symbols then + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 3 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing strongest red flag\n" .. + "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ') + else + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 2 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing strongest red flag\n" + end + end end |