Diffstat (limited to 'src/plugins/lua/gpt.lua')
-rw-r--r--	src/plugins/lua/gpt.lua	668
1 file changed, 519 insertions(+), 149 deletions(-)
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index feccae73f..331dbbce2 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -15,13 +15,14 @@ limitations under the License.
 ]] --
 
 local N = "gpt"
+local REDIS_PREFIX = "rsllm"
 local E = {}
 
 if confighelp then
   rspamd_config:add_example(nil, 'gpt',
-    "Performs postfiltering using GPT model",
-    [[
-gpt {
+      "Performs postfiltering using GPT model",
+      [[
+  gpt {
   # Supported types: openai, ollama
   type = "openai";
   # Your key to access the API
@@ -48,7 +49,11 @@ gpt {
   allow_passthrough = false;
   # Check messages that are apparent ham (no action and negative score)
   allow_ham = false;
-}
+  # Add header with reason (null to disable)
+  reason_header = "X-GPT-Reason";
+  # Use JSON format for response
+  json = false;
+  }
 ]])
   return
 end
@@ -57,8 +62,10 @@
 local lua_util = require "lua_util"
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
 local lua_mime = require "lua_mime"
+local lua_redis = require "lua_redis"
 local ucl = require "ucl"
 local fun = require "fun"
+local lua_cache = require "lua_cache"
 
 -- Exclude checks if one of those is found
@@ -71,6 +78,32 @@ local default_symbols_to_except = {
   BOUNCE = -1,
 }
 
+local default_extra_symbols = {
+  GPT_MARKETING = {
+    score = 0.0,
+    description = 'GPT model detected marketing content',
+    category = 'marketing',
+  },
+  GPT_PHISHING = {
+    score = 3.0,
+    description = 'GPT model detected phishing content',
+    category = 'phishing',
+  },
+  GPT_SCAM = {
+    score = 3.0,
+    description = 'GPT model detected scam content',
+    category = 'scam',
+  },
+  GPT_MALWARE = {
+    score = 3.0,
+    description = 'GPT model detected malware content',
+    category = 'malware',
+  },
+}
+
+-- Should be filled from extra symbols
+local categories_map = {}
+
 local settings = {
   type = 'openai',
   api_key = nil,
@@ -81,11 +114,18 @@ local settings = {
   prompt = nil,
   condition = nil,
   autolearn = false,
+  reason_header = nil,
   url = 'https://api.openai.com/v1/chat/completions',
-  symbols_to_except = default_symbols_to_except,
+  symbols_to_except = nil,
+  symbols_to_trigger = nil, -- Exclude/include logic
   allow_passthrough = false,
   allow_ham = false,
+  json = false,
+  extra_symbols = nil,
+  cache_prefix = REDIS_PREFIX,
 }
+local redis_params
+local cache_context
 
 local function default_condition(task)
   -- Check result
@@ -108,22 +148,44 @@ local function default_condition(task)
       return false, 'negative score, already decided as ham'
     end
   end
-  -- We also exclude some symbols
-  for s, required_weight in pairs(settings.symbols_to_except) do
-    if task:has_symbol(s) then
-      if required_weight > 0 then
-        -- Also check score
-        local sym = task:get_symbol(s) or E
-        -- Must exist as we checked it before with `has_symbol`
-        if sym.weight then
-          if math.abs(sym.weight) >= required_weight then
-            return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')'
+
+  if settings.symbols_to_except then
+    for s, required_weight in pairs(settings.symbols_to_except) do
+      if task:has_symbol(s) then
+        if required_weight > 0 then
+          -- Also check score
+          local sym = task:get_symbol(s) or E
+          -- Must exist as we checked it before with `has_symbol`
+          if sym.weight then
+            if math.abs(sym.weight) >= required_weight then
+              return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')'
+            end
           end
+          lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
+            sym.weight, required_weight)
+        else
+          return false, 'skip as "' .. s .. '" is found'
         end
-        lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
+      end
+    end
+  end
+  if settings.symbols_to_trigger then
+    for s, required_weight in pairs(settings.symbols_to_trigger) do
+      if task:has_symbol(s) then
+        if required_weight > 0 then
+          -- Also check score
+          local sym = task:get_symbol(s) or E
+          -- Must exist as we checked it before with `has_symbol`
+          if sym.weight then
+            if math.abs(sym.weight) < required_weight then
+              return false, 'skip as "' .. s .. '" is found with low weight (weight: ' .. sym.weight .. ')'
+            end
+          end
+          lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
             sym.weight, required_weight)
+        end
       else
-        return false, 'skip as "' .. s .. '" is found'
+        return false, 'skip as "' .. s .. '" is not found'
       end
     end
   end
@@ -147,10 +209,10 @@ local function default_condition(task)
     local words = sel_part:get_words('norm')
     nwords = #words
    if nwords > settings.max_tokens then
-      return true, table.concat(words, ' ', 1, settings.max_tokens)
+      return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
     end
   end
-  return true, sel_part:get_content_oneline()
+  return true, sel_part:get_content_oneline(), sel_part
 end
 
 local function maybe_extract_json(str)
@@ -191,7 +253,16 @@
   return nil
 end
 
-local function default_conversion(task, input)
+-- Helper function to remove <think>...</think> and trim leading newlines
+local function clean_gpt_response(text)
+  -- Remove <think>...</think> including multiline
+  text = text:gsub("<think>.-</think>", "")
+  -- Trim leading whitespace and newlines
+  text = text:gsub("^%s*\n*", "")
+  return text
+end
+
+local function default_openai_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -239,7 +310,7 @@
     elseif reply.probability == "low" then
       spam_score = 0.1
     else
-      rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+      rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
     end
   end
 
@@ -247,14 +318,111 @@
     if type(reply.usage) == 'table' then
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score
+    return spam_score, reply.reason, {}
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
   return
 end
 
-local function ollama_conversion(task, input)
+-- Remove what we don't need
+local function clean_reply_line(line)
+  if not line then
+    return ''
+  end
+  return lua_util.str_trim(line):gsub("^%d%.%s+", "")
+end
+
+-- Assume that we have 3 lines: probability, reason, additional symbols
+local function default_openai_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+    rspamd_logger.errx(task, 'no choices in reply')
+    return
+  end
+
+  local first_message = reply.choices[1].message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+
+  -- Clean message
+  first_message = clean_gpt_response(first_message)
+
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = clean_reply_line(lines[1])
+  local spam_score = tonumber(first_line)
+  local reason = clean_reply_line(lines[2])
+  local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
+
+  if type(reply.usage) == 'table' then
+    rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+  end
+
+  if spam_score then
+    return spam_score, reason, categories
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
+  return
+end
+
+local function default_ollama_plain_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.message) ~= 'table' then
+    rspamd_logger.errx(task, 'bad message in reply')
+    return
+  end
+
+  local first_message = reply.message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+
+  -- Clean message
+  first_message = clean_gpt_response(first_message)
+
+  local lines = lua_util.str_split(first_message, '\n')
+  local first_line = clean_reply_line(lines[1])
+  local spam_score = tonumber(first_line)
+  local reason = clean_reply_line(lines[2])
+  local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
+
+  if spam_score then
+    return spam_score, reason, categories
+  end
+
+  rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message)
+  return
+end
+
+local function default_ollama_json_conversion(task, input)
   local parser = ucl.parser()
   local res, err = parser:parse_string(input)
   if not res then
@@ -302,7 +470,7 @@
     elseif reply.probability == "low" then
      spam_score = 0.1
     else
-      rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability)
+      rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability)
     end
   end
 
@@ -310,13 +478,126 @@
     if type(reply.usage) == 'table' then
       rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
     end
 
-    return spam_score
+    return spam_score, reply.reason
   end
 
   rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
   return
 end
 
+-- Make cache specific to all settings to avoid conflicts
+local env_digest = nil
+
+local function redis_cache_key(sel_part)
+  if not env_digest then
+    local hasher = require "rspamd_cryptobox_hash"
+    local digest = hasher.create()
+    digest:update(settings.prompt)
+    digest:update(settings.model)
+    digest:update(settings.url)
+    env_digest = digest:hex():sub(1, 4)
+  end
+  return string.format('%s_%s', env_digest,
+    sel_part:get_mimepart():get_digest():sub(1, 24))
+end
+
+local function process_categories(task, categories)
+  for _, category in ipairs(categories) do
+    local sym = categories_map[category:lower()]
+    if sym then
+      task:insert_result(sym.name, 1.0)
+    end
+  end
+end
+
+local function insert_results(task, result, sel_part)
+  if not result.probability then
+    rspamd_logger.errx(task, 'no probability in result')
+    return
+  end
+
+  if result.probability > 0.5 then
+    task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+
+    if result.categories then
+      process_categories(task, result.categories)
+    end
+  else
+    task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+    if result.categories then
+      process_categories(task, result.categories)
+    end
+  end
+  if result.reason and settings.reason_header then
+    lua_mime.modify_headers(task,
+      { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } })
+  end
+
+  if cache_context then
+    lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context)
+  end
+end
+
+local function check_consensus_and_insert_results(task, results, sel_part)
+  for _, result in ipairs(results) do
+    if not result.checked then
+      return
+    end
+  end
+
+  local nspam, nham = 0, 0
+  local max_spam_prob, max_ham_prob = 0, 0
+  local reasons = {}
+
+  for _, result in ipairs(results) do
+    if result.success then
+      if result.probability > 0.5 then
+        nspam = nspam + 1
+        max_spam_prob = math.max(max_spam_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'",
+          result.model, result.probability, result.reason)
+      else
+        nham = nham + 1
+        max_ham_prob = math.min(max_ham_prob, result.probability)
+        lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'",
+          result.model, result.probability, result.reason)
+      end
+
+      if result.reason then
+        table.insert(reasons, result)
+      end
+    end
+  end
+
+  lua_util.shuffle(reasons)
+  local reason = reasons[1] or nil
+
+  if nspam > nham and max_spam_prob > 0.75 then
+    insert_results(task, {
+        probability = max_spam_prob,
+        reason = reason.reason,
+        categories = reason.categories,
+      },
+      sel_part)
+  elseif nham > nspam and max_ham_prob < 0.25 then
+    insert_results(task, {
+        probability = max_ham_prob,
+        reason = reason.reason,
+        categories = reason.categories,
+      },
+      sel_part)
+  else
+    -- No consensus
+    lua_util.debugm(N, task, "no consensus")
+  end
+end
+
 local function get_meta_llm_content(task)
   local url_content = "Url domains: no urls found"
   if task:has_urls() then
@@ -334,57 +615,70 @@ local function get_meta_llm_content(task)
   return url_content, from_content
 end
 
-local function default_llm_check(task)
-  local ret, content = settings.condition(task)
+local function check_llm_uncached(task, content, sel_part)
+  return settings.specific_check(task, content, sel_part)
+end
 
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
+local function check_llm_cached(task, content, sel_part)
+  local cache_key = redis_cache_key(sel_part)
 
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
-  end
+  lua_cache.cache_get(task, cache_key, cache_context, settings.timeout * 1.5, function()
+    check_llm_uncached(task, content, sel_part)
+  end, function(_, err, data)
+    if err then
+      rspamd_logger.errx(task, 'cannot get cache: %s', err)
+      check_llm_uncached(task, content, sel_part)
+    end
+    if data then
+      rspamd_logger.infox(task, 'found cached response %s', cache_key)
+      insert_results(task, data, sel_part)
+    else
+      check_llm_uncached(task, content, sel_part)
+    end
+  end)
+end
 
+local function openai_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
 
-  local function on_reply(err, code, body)
+  local results = {}
 
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus_and_insert_results(task, results, sel_part)
+        return
+      end
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
+      local reply, reason, categories = settings.reply_conversion(task, body)
 
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
+      results[idx].model = model
+
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
+        results[idx].reason = reason
+
+        if categories then
+          results[idx].categories = categories
+        end
       end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      check_consensus_and_insert_results(task, results, sel_part)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
@@ -393,7 +687,6 @@
     model = settings.model,
     max_tokens = settings.max_tokens,
     temperature = settings.temperature,
-    response_format = { type = "json_object" },
     messages = {
       {
        role = 'system',
@@ -401,7 +694,7 @@
       },
       {
         role = 'user',
-        content = 'Subject: ' .. task:get_subject() or '',
+        content = 'Subject: ' .. (task:get_subject() or ''),
       },
       {
         role = 'user',
@@ -418,87 +711,92 @@
     }
   }
 
-  upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    headers = {
-      ['Authorization'] = 'Bearer ' .. settings.api_key,
-    },
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
-
-  rspamd_http.request(http_params)
-end
-
-local function ollama_check(task)
-  local ret, content = settings.condition(task)
+  -- Conditionally add response_format
+  if settings.include_response_format then
+    body.response_format = { type = "json_object" }
+  end
 
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
   end
 
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
+  upstream = settings.upstreams:get_upstream_round_robin()
+  for idx, model in ipairs(settings.model) do
+    results[idx] = {
+      success = false,
+      checked = false
+    }
+    body.model = model
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, idx),
+      headers = {
+        ['Authorization'] = 'Bearer ' .. settings.api_key,
+      },
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
+
+    if not rspamd_http.request(http_params) then
+      results[idx].checked = true
+    end
   end
+end
 
+local function ollama_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
+  local results = {}
+
+  local function gen_reply_closure(model, idx)
+    return function(err, code, body)
+      results[idx].checked = true
+      if err then
+        rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+        upstream:fail()
+        check_consensus_and_insert_results(task, results, sel_part)
+        return
+      end
 
-  local function on_reply(err, code, body)
+      upstream:ok()
+      lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+      if code ~= 200 then
+        rspamd_logger.errx(task, 'bad reply: %s', body)
+        return
+      end
 
-    if err then
-      rspamd_logger.errx(task, 'request failed: %s', err)
-      upstream:fail()
-      return
-    end
+      local reply, reason = settings.reply_conversion(task, body)
 
-    upstream:ok()
-    lua_util.debugm(N, task, "got reply: %s", body)
-    if code ~= 200 then
-      rspamd_logger.errx(task, 'bad reply: %s', body)
-      return
-    end
+      results[idx].model = model
 
-    local reply = settings.reply_conversion(task, body)
-    if not reply then
-      return
-    end
-
-    if reply > 0.75 then
-      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_spam")
-      end
-    elseif reply < 0.25 then
-      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
-      if settings.autolearn then
-        task:set_flag("learn_ham")
-      end
-    else
-      lua_util.debugm(N, task, "uncertain result: %s", reply)
-    end
+      if reply then
+        results[idx].success = true
+        results[idx].probability = reply
+        results[idx].reason = reason
+      end
+      check_consensus_and_insert_results(task, results, sel_part)
+    end
   end
 
   local from_content, url_content = get_meta_llm_content(task)
 
+  if type(settings.model) == 'string' then
+    settings.model = { settings.model }
+  end
+
   local body = {
     stream = false,
     model = settings.model,
     max_tokens = settings.max_tokens,
     temperature = settings.temperature,
-    response_format = { type = "json_object" },
     messages = {
       {
         role = 'system',
@@ -523,50 +821,91 @@
     }
   }
 
-  upstream = settings.upstreams:get_upstream_round_robin()
-  local http_params = {
-    url = settings.url,
-    mime_type = 'application/json',
-    timeout = settings.timeout,
-    log_obj = task,
-    callback = on_reply,
-    keepalive = true,
-    body = ucl.to_format(body, 'json-compact', true),
-    task = task,
-    upstream = upstream,
-    use_gzip = true,
-  }
+  for i, model in ipairs(settings.model) do
+    -- Conditionally add response_format
+    if settings.include_response_format then
+      body.response_format = { type = "json_object" }
+    end
+
+    results[i] = {
+      success = false,
+      checked = false
+    }
+    body.model = model
+
+    upstream = settings.upstreams:get_upstream_round_robin()
+    local http_params = {
+      url = settings.url,
+      mime_type = 'application/json',
+      timeout = settings.timeout,
+      log_obj = task,
+      callback = gen_reply_closure(model, i),
+      keepalive = true,
+      body = ucl.to_format(body, 'json-compact', true),
+      task = task,
+      upstream = upstream,
+      use_gzip = true,
+    }
 
-  rspamd_http.request(http_params)
+    rspamd_http.request(http_params)
+  end
 end
 
 local function gpt_check(task)
-  return settings.specific_check(task)
+  local ret, content, sel_part = settings.condition(task)
+
+  if not ret then
+    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    return
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  if sel_part then
+    -- Check digest
+    check_llm_cached(task, content, sel_part)
+  else
+    check_llm_uncached(task, content)
+  end
 end
 
 local types_map = {
   openai = {
-    check = default_llm_check,
+    check = openai_check,
     condition = default_condition,
-    conversion = default_conversion,
+    conversion = function(is_json)
+      return is_json and default_openai_json_conversion or default_openai_plain_conversion
+    end,
     require_passkey = true,
   },
   ollama = {
     check = ollama_check,
    condition = default_condition,
-    conversion = ollama_conversion,
+    conversion = function(is_json)
+      return is_json and default_ollama_json_conversion or default_ollama_plain_conversion
+    end,
     require_passkey = false,
   },
 }
 
-local opts = rspamd_config:get_all_opt('gpt')
+local opts = rspamd_config:get_all_opt(N)
 if opts then
+  redis_params = lua_redis.parse_redis_server(N, opts)
   settings = lua_util.override_defaults(settings, opts)
 
-  if not settings.prompt then
-    settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
-        "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
-        "output result as JSON with 'probability' field."
+  if redis_params then
+    cache_context = lua_cache.create_cache_context(redis_params, settings, N)
+  end
+
+  if not settings.symbols_to_except then
+    settings.symbols_to_except = default_symbols_to_except
+  end
+
+  if not settings.extra_symbols then
+    settings.extra_symbols = default_extra_symbols
   end
 
   local llm_type = types_map[settings.type]
@@ -586,7 +925,7 @@
   if settings.reply_conversion then
     settings.reply_conversion = load(settings.reply_conversion)()
   else
-    settings.reply_conversion = llm_type.conversion
+    settings.reply_conversion = llm_type.conversion(settings.json)
   end
 
   if not settings.api_key and llm_type.require_passkey then
@@ -610,7 +949,7 @@
     name = 'GPT_SPAM',
     type = 'virtual',
     parent = id,
-    score = 5.0,
+    score = 3.0,
   })
   rspamd_config:register_symbol({
     name = 'GPT_HAM',
@@ -618,4 +957,35 @@
     type = 'virtual',
     parent = id,
     score = -2.0,
   })
-end
\ No newline at end of file
+
+  if settings.extra_symbols then
+    for sym, data in pairs(settings.extra_symbols) do
+      rspamd_config:register_symbol({
+        name = sym,
+        type = 'virtual',
+        parent = id,
+        score = data.score,
+        description = data.description,
+      })
+      data.name = sym
+      categories_map[data.category] = data
+    end
+  end
+
+  if not settings.prompt then
+    if settings.extra_symbols then
+      settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+          "FROM and url domains. Evaluate spam probability (0-1). " ..
+          "Output ONLY 3 lines:\n" ..
+          "1. Numeric score (0.00-1.00)\n" ..
+          "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" ..
+          "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+    else
+      settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+          "FROM and url domains. Evaluate spam probability (0-1). " ..
+          "Output ONLY 2 lines:\n" ..
+          "1. Numeric score (0.00-1.00)\n" ..
+          "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n"
+    end
+  end
+end
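
For reference, a minimal sketch of a local.d/gpt.conf that exercises the options this diff introduces: consensus across several models, the plain three-line reply format, the reason header and the Redis-backed verdict cache. The option names come from the diff above; the API key and model names are placeholders, not project defaults:

  gpt {
    type = "openai";
    # Placeholder key (assumption)
    api_key = "xxx";
    url = "https://api.openai.com/v1/chat/completions";
    # A list of models enables the consensus path added here
    model = ["gpt-4o-mini", "gpt-4o"];
    # false selects the new plain three-line replies (score / reason / category);
    # true keeps the JSON 'probability' object handled by the *_json conversions
    json = false;
    # Header that carries the model's one-sentence reason
    reason_header = "X-GPT-Reason";
    allow_passthrough = false;
    allow_ham = false;
  }

With Redis configured for the module, replies are cached per message-part digest under the rsllm prefix (see redis_cache_key() and the lua_cache wiring above), so repeated scans of the same content skip the HTTP round trip.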
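To make the plain-format contract concrete: with the default prompt and extra_symbols in place, a model reply like the following (hypothetical content) is what default_openai_plain_conversion() expects to parse into a probability, a reason and a category list:

  0.85
  Urgent payment request with a mismatched sender domain and a suspicious link
  phishing

With a single model configured this already satisfies the consensus test (nspam > nham and max_spam_prob > 0.75), so insert_results() adds GPT_SPAM with dynamic weight (0.85 - 0.5) * 2 = 0.7 (scaled by the symbol score, now 3.0), stores the verdict in the cache, writes the reason into the reason_header if one is set, and fires GPT_PHISHING via categories_map. A reply whose first line is not numeric fails tonumber() and that model's result is simply excluded from the consensus.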