diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2025-02-27 17:15:10 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-27 17:15:10 +0600 |
commit | fb26ba69d7ae746fc9b22d5a45a97503d3245820 (patch) | |
tree | 2ea8cb11c79f848d483e8e7991eece10ebe2b5d6 | |
parent | 4aa341918bfe89162c6812669fa07111447080d4 (diff) | |
parent | 1cdafc964390d5ef45973af9d6cde3e57a146056 (diff) | |
download | rspamd-fb26ba69d7ae746fc9b22d5a45a97503d3245820.tar.gz rspamd-fb26ba69d7ae746fc9b22d5a45a97503d3245820.zip |
Merge pull request #5356 from rspamd/vstakhov-gpt-tunes2
More features to GPT plugin
-rw-r--r-- | src/plugins/lua/gpt.lua | 435 |
1 files changed, 357 insertions, 78 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index 971bfbd29..625450fd9 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -15,6 +15,7 @@ limitations under the License. ]] -- local N = "gpt" +local REDIS_PREFIX = "rsllm_" local E = {} if confighelp then @@ -48,10 +49,10 @@ gpt { allow_passthrough = false; # Check messages that are apparent ham (no action and negative score) allow_ham = false; - # default send response_format field { type = "json_object" } - include_response_format = true, # Add header with reason (null to disable) reason_header = "X-GPT-Reason"; + # Use JSON format for response + json = false; } ]]) return @@ -61,6 +62,7 @@ local lua_util = require "lua_util" local rspamd_http = require "rspamd_http" local rspamd_logger = require "rspamd_logger" local lua_mime = require "lua_mime" +local lua_redis = require "lua_redis" local ucl = require "ucl" local fun = require "fun" @@ -75,6 +77,32 @@ local default_symbols_to_except = { BOUNCE = -1, } +local default_extra_symbols = { + GPT_MARKETING = { + score = 0.0, + description = 'GPT model detected marketing content', + category = 'marketing', + }, + GPT_PHISHING = { + score = 3.0, + description = 'GPT model detected phishing content', + category = 'phishing', + }, + GPT_SCAM = { + score = 3.0, + description = 'GPT model detected scam content', + category = 'scam', + }, + GPT_MALWARE = { + score = 3.0, + description = 'GPT model detected malware content', + category = 'malware', + }, +} + +-- Should be filled from extra symbols +local categories_map = {} + local settings = { type = 'openai', api_key = nil, @@ -88,9 +116,14 @@ local settings = { reason_header = nil, url = 'https://api.openai.com/v1/chat/completions', symbols_to_except = nil, + symbols_to_trigger = nil, -- Exclude/include logic allow_passthrough = false, allow_ham = false, + json = false, + redis_cache_expire = 3600 * 24, + extra_symbols = nil, } +local redis_params local function default_condition(task) -- Check result @@ -113,22 +146,44 @@ local function default_condition(task) return false, 'negative score, already decided as ham' end end - -- We also exclude some symbols - for s, required_weight in pairs(settings.symbols_to_except) do - if task:has_symbol(s) then - if required_weight > 0 then - -- Also check score - local sym = task:get_symbol(s) or E - -- Must exist as we checked it before with `has_symbol` - if sym.weight then - if math.abs(sym.weight) >= required_weight then - return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')' + + if settings.symbols_to_except then + for s, required_weight in pairs(settings.symbols_to_except) do + if task:has_symbol(s) then + if required_weight > 0 then + -- Also check score + local sym = task:get_symbol(s) or E + -- Must exist as we checked it before with `has_symbol` + if sym.weight then + if math.abs(sym.weight) >= required_weight then + return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')' + end + end + lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, + sym.weight, required_weight) + else + return false, 'skip as "' .. s .. '" is found' + end + end + end + end + if settings.symbols_to_trigger then + for s, required_weight in pairs(settings.symbols_to_trigger) do + if task:has_symbol(s) then + if required_weight > 0 then + -- Also check score + local sym = task:get_symbol(s) or E + -- Must exist as we checked it before with `has_symbol` + if sym.weight then + if math.abs(sym.weight) < required_weight then + return false, 'skip as "' .. s .. '" is found with low weight (weight: ' .. sym.weight .. ')' + end end + lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, + sym.weight, required_weight) end - lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, - sym.weight, required_weight) else - return false, 'skip as "' .. s .. '" is found' + return false, 'skip as "' .. s .. '" is not found' end end end @@ -152,10 +207,10 @@ local function default_condition(task) local words = sel_part:get_words('norm') nwords = #words if nwords > settings.max_tokens then - return true, table.concat(words, ' ', 1, settings.max_tokens) + return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part end end - return true, sel_part:get_content_oneline() + return true, sel_part:get_content_oneline(), sel_part end local function maybe_extract_json(str) @@ -196,7 +251,7 @@ local function maybe_extract_json(str) return nil end -local function default_conversion(task, input) +local function default_openai_json_conversion(task, input) local parser = ucl.parser() local res, err = parser:parse_string(input) if not res then @@ -252,14 +307,99 @@ local function default_conversion(task, input) rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) end - return spam_score, reply.reason + return spam_score, reply.reason, {} end rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) return end -local function ollama_conversion(task, input) +-- Remove what we don't need +local function clean_reply_line(line) + if not line then + return '' + end + return lua_util.str_trim(line):gsub("^%d%.%s+", "") +end + +-- Assume that we have 3 lines: probability, reason, additional symbols +local function default_openai_plain_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then + rspamd_logger.errx(task, 'no choices in reply') + return + end + + local first_message = reply.choices[1].message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + local lines = lua_util.str_split(first_message, '\n') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') + + if spam_score then + return spam_score, reason, categories + end + + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1]) + return +end + +local function default_ollama_plain_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.message) ~= 'table' then + rspamd_logger.errx(task, 'bad message in reply') + return + end + + local first_message = reply.message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + local lines = lua_util.str_split(first_message, '\n') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') + + if spam_score then + return spam_score, reason, categories + end + + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1]) + return +end + +local function default_ollama_json_conversion(task, input) local parser = ucl.parser() local res, err = parser:parse_string(input) if not res then @@ -322,7 +462,79 @@ local function ollama_conversion(task, input) return end -local function check_consensus(task, results) +-- Make cache specific to all settings to avoid conflicts +local env_digest = nil + +local function redis_cache_key(sel_part) + if not env_digest then + local hasher = require "rspamd_cryptobox_hash" + local digest = hasher.create() + digest:update(settings.prompt) + digest:update(settings.model) + digest:update(settings.url) + env_digest = digest:hex():sub(1, 4) + end + return string.format('%s%s_%s', REDIS_PREFIX, env_digest, + sel_part:get_mimepart():get_digest():sub(1, 24)) +end + +local function maybe_save_cache(task, result, sel_part) + if not sel_part or not redis_params then + lua_util.debugm(N, task, 'cannot save cache: no part or no redis') + return -- cannot save intentionally + end + + local cache_key = redis_cache_key(sel_part) + lua_util.debugm(N, task, 'saving cache for %s', cache_key) + local result_json = ucl.to_format(result, 'json-compact') + lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, _) + if err then + rspamd_logger.errx(task, 'cannot save cache: %s', err) + end + end, + 'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json }) +end + +local function process_categories(task, categories) + for _, category in ipairs(categories) do + local sym = categories_map[category:lower()] + if sym then + task:insert_result(sym.name, 1.0) + end + end +end + +local function insert_results(task, result, sel_part) + if not result.probability then + rspamd_logger.errx(task, 'no probability in result') + return + end + if result.probability > 0.5 then + task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability)) + if settings.autolearn then + task:set_flag("learn_spam") + end + + if result.categories then + process_categories(task, result.categories) + end + else + if result.reason and settings.reason_header then + lua_mime.modify_headers(task, + { add = { [settings.reason_header] = { value = 'value', order = 1 } } }) + end + task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability)) + if settings.autolearn then + task:set_flag("learn_ham") + end + if result.categories then + process_categories(task, result.categories) + end + end + maybe_save_cache(task, result, sel_part) +end + +local function check_consensus_and_insert_results(task, results, sel_part) for _, result in ipairs(results) do if not result.checked then return @@ -348,7 +560,7 @@ local function check_consensus(task, results) end if result.reason then - table.insert(reasons, result.reason) + table.insert(reasons, result) end end end @@ -357,24 +569,19 @@ local function check_consensus(task, results) local reason = reasons[1] or nil if nspam > nham and max_spam_prob > 0.75 then - task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob)) - if settings.autolearn then - task:set_flag("learn_spam") - end - - if reason and settings.reason_header then - lua_mime.modify_headers(task, - { add = { [settings.reason_header] = { value = 'value', order = 1 } } }) - end + insert_results(task, { + probability = max_spam_prob, + reason = reason.reason, + categories = reason.categories, + }, + sel_part) elseif nham > nspam and max_ham_prob < 0.25 then - task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob)) - if settings.autolearn then - task:set_flag("learn_ham") - end - if reason and settings.reason_header then - lua_mime.modify_headers(task, - { add = { [settings.reason_header] = { value = 'value', order = 1 } } }) - end + insert_results(task, { + probability = max_ham_prob, + reason = reason.reason, + categories = reason.categories, + }, + sel_part) else -- No consensus lua_util.debugm(N, task, "no consensus") @@ -399,19 +606,42 @@ local function get_meta_llm_content(task) return url_content, from_content end -local function default_llm_check(task) - local ret, content = settings.condition(task) +local function check_llm_uncached(task, content, sel_part) + return settings.specific_check(task, content, sel_part) +end - if not ret then - rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) - return - end +local function check_llm_cached(task, content, sel_part) + local cache_key = redis_cache_key(sel_part) - if not content then - lua_util.debugm(N, task, "no content to send to gpt classification") - return + local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, data) + if err then + rspamd_logger.errx(task, 'cannot check cache: %s', err) + check_llm_uncached(task, content, sel_part) + end + + if type(data) == 'string' then + local parser = ucl.parser() + local res, parse_err = parser:parse_string(data) + if not res then + rspamd_logger.errx(task, 'Cannot parse cached response: %s', parse_err) + check_llm_uncached(task, content, sel_part) + else + rspamd_logger.infox(task, 'found cached response %s', cache_key) + insert_results(task, parser:get_object()) + end + else + check_llm_uncached(task, content, sel_part) + end + end, + 'GET', { cache_key }) + + if not ret then + rspamd_logger.errx(task, 'cannot query cache for request') + check_llm_uncached(task, content, sel_part) end +end +local function openai_check(task, content, sel_part) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream @@ -424,7 +654,7 @@ local function default_llm_check(task) if err then rspamd_logger.errx(task, '%s: request failed: %s', model, err) upstream:fail() - check_consensus(task, results) + check_consensus_and_insert_results(task, results, sel_part) return end @@ -435,7 +665,7 @@ local function default_llm_check(task) return end - local reply, reason = settings.reply_conversion(task, body) + local reply, reason, categories = settings.reply_conversion(task, body) results[idx].model = model @@ -443,9 +673,13 @@ local function default_llm_check(task) results[idx].success = true results[idx].probability = reply results[idx].reason = reason + + if categories then + results[idx].categories = categories + end end - check_consensus(task, results) + check_consensus_and_insert_results(task, results, sel_part) end end @@ -518,19 +752,7 @@ local function default_llm_check(task) end end -local function ollama_check(task) - local ret, content = settings.condition(task) - - if not ret then - rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) - return - end - - if not content then - lua_util.debugm(N, task, "no content to send to gpt classification") - return - end - +local function ollama_check(task, content, sel_part) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream @@ -542,7 +764,7 @@ local function ollama_check(task) if err then rspamd_logger.errx(task, '%s: request failed: %s', model, err) upstream:fail() - check_consensus(task, results) + check_consensus_and_insert_results(task, results, sel_part) return end @@ -563,7 +785,7 @@ local function ollama_check(task) results[idx].reason = reason end - check_consensus(task, results) + check_consensus_and_insert_results(task, results, sel_part) end end @@ -629,39 +851,65 @@ local function ollama_check(task) end local function gpt_check(task) - return settings.specific_check(task) + local ret, content, sel_part = settings.condition(task) + + if not ret then + rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) + return + end + + if not content then + lua_util.debugm(N, task, "no content to send to gpt classification") + return + end + + if sel_part then + -- Check digest + check_llm_cached(task, content, sel_part) + else + check_llm_uncached(task, content) + end end local types_map = { openai = { - check = default_llm_check, + check = openai_check, condition = default_condition, - conversion = default_conversion, + conversion = function(is_json) + return is_json and default_openai_json_conversion or default_openai_plain_conversion + end, require_passkey = true, }, ollama = { check = ollama_check, condition = default_condition, - conversion = ollama_conversion, + conversion = function(is_json) + return is_json and default_ollama_json_conversion or default_ollama_plain_conversion + end, require_passkey = false, }, } -local opts = rspamd_config:get_all_opt('gpt') +local opts = rspamd_config:get_all_opt(N) if opts then + redis_params = lua_redis.parse_redis_server(N, opts) settings = lua_util.override_defaults(settings, opts) - if not settings.prompt then - settings.prompt = "You will be provided with the email message, subject, from and url domains, " .. - "and your task is to evaluate the probability to be spam as number from 0 to 1, " .. - "output result as JSON with 'probability' field and " .. - "add 'reason' field with 1 sentence description why you have made that decision." + if redis_params then + lua_redis.register_prefix(REDIS_PREFIX .. '*', N, + 'Cache of LLM requests', { + type = 'string', + }) end if not settings.symbols_to_except then settings.symbols_to_except = default_symbols_to_except end + if not settings.extra_symbols then + settings.extra_symbols = default_extra_symbols + end + local llm_type = types_map[settings.type] if not llm_type then rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) @@ -679,7 +927,7 @@ if opts then if settings.reply_conversion then settings.reply_conversion = load(settings.reply_conversion)() else - settings.reply_conversion = llm_type.conversion + settings.reply_conversion = llm_type.conversion(settings.json) end if not settings.api_key and llm_type.require_passkey then @@ -703,7 +951,7 @@ if opts then name = 'GPT_SPAM', type = 'virtual', parent = id, - score = 5.0, + score = 3.0, }) rspamd_config:register_symbol({ name = 'GPT_HAM', @@ -711,4 +959,35 @@ if opts then parent = id, score = -2.0, }) + + if settings.extra_symbols then + for sym, data in pairs(settings.extra_symbols) do + rspamd_config:register_symbol({ + name = sym, + type = 'virtual', + parent = id, + score = data.score, + description = data.description, + }) + data.name = sym + categories_map[data.category] = data + end + end + + if not settings.prompt then + if settings.extra_symbols then + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 3 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing strongest red flag\n" .. + "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ') + else + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 2 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing strongest red flag\n" + end + end end |