aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2025-02-27 17:15:10 +0600
committerGitHub <noreply@github.com>2025-02-27 17:15:10 +0600
commitfb26ba69d7ae746fc9b22d5a45a97503d3245820 (patch)
tree2ea8cb11c79f848d483e8e7991eece10ebe2b5d6
parent4aa341918bfe89162c6812669fa07111447080d4 (diff)
parent1cdafc964390d5ef45973af9d6cde3e57a146056 (diff)
downloadrspamd-fb26ba69d7ae746fc9b22d5a45a97503d3245820.tar.gz
rspamd-fb26ba69d7ae746fc9b22d5a45a97503d3245820.zip
Merge pull request #5356 from rspamd/vstakhov-gpt-tunes2
More features to GPT plugin
-rw-r--r--src/plugins/lua/gpt.lua435
1 files changed, 357 insertions, 78 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 971bfbd29..625450fd9 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -15,6 +15,7 @@ limitations under the License.
]] --
local N = "gpt"
+local REDIS_PREFIX = "rsllm_"
local E = {}
if confighelp then
@@ -48,10 +49,10 @@ gpt {
allow_passthrough = false;
# Check messages that are apparent ham (no action and negative score)
allow_ham = false;
- # default send response_format field { type = "json_object" }
- include_response_format = true,
# Add header with reason (null to disable)
reason_header = "X-GPT-Reason";
+ # Use JSON format for response
+ json = false;
}
]])
return
@@ -61,6 +62,7 @@ local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local lua_mime = require "lua_mime"
+local lua_redis = require "lua_redis"
local ucl = require "ucl"
local fun = require "fun"
@@ -75,6 +77,32 @@ local default_symbols_to_except = {
BOUNCE = -1,
}
+local default_extra_symbols = {
+ GPT_MARKETING = {
+ score = 0.0,
+ description = 'GPT model detected marketing content',
+ category = 'marketing',
+ },
+ GPT_PHISHING = {
+ score = 3.0,
+ description = 'GPT model detected phishing content',
+ category = 'phishing',
+ },
+ GPT_SCAM = {
+ score = 3.0,
+ description = 'GPT model detected scam content',
+ category = 'scam',
+ },
+ GPT_MALWARE = {
+ score = 3.0,
+ description = 'GPT model detected malware content',
+ category = 'malware',
+ },
+}
+
+-- Should be filled from extra symbols
+local categories_map = {}
+
local settings = {
type = 'openai',
api_key = nil,
@@ -88,9 +116,14 @@ local settings = {
reason_header = nil,
url = 'https://api.openai.com/v1/chat/completions',
symbols_to_except = nil,
+ symbols_to_trigger = nil, -- Exclude/include logic
allow_passthrough = false,
allow_ham = false,
+ json = false,
+ redis_cache_expire = 3600 * 24,
+ extra_symbols = nil,
}
+local redis_params
local function default_condition(task)
-- Check result
@@ -113,22 +146,44 @@ local function default_condition(task)
return false, 'negative score, already decided as ham'
end
end
- -- We also exclude some symbols
- for s, required_weight in pairs(settings.symbols_to_except) do
- if task:has_symbol(s) then
- if required_weight > 0 then
- -- Also check score
- local sym = task:get_symbol(s) or E
- -- Must exist as we checked it before with `has_symbol`
- if sym.weight then
- if math.abs(sym.weight) >= required_weight then
- return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')'
+
+ if settings.symbols_to_except then
+ for s, required_weight in pairs(settings.symbols_to_except) do
+ if task:has_symbol(s) then
+ if required_weight > 0 then
+ -- Also check score
+ local sym = task:get_symbol(s) or E
+ -- Must exist as we checked it before with `has_symbol`
+ if sym.weight then
+ if math.abs(sym.weight) >= required_weight then
+ return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')'
+ end
+ end
+ lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
+ sym.weight, required_weight)
+ else
+ return false, 'skip as "' .. s .. '" is found'
+ end
+ end
+ end
+ end
+ if settings.symbols_to_trigger then
+ for s, required_weight in pairs(settings.symbols_to_trigger) do
+ if task:has_symbol(s) then
+ if required_weight > 0 then
+ -- Also check score
+ local sym = task:get_symbol(s) or E
+ -- Must exist as we checked it before with `has_symbol`
+ if sym.weight then
+ if math.abs(sym.weight) < required_weight then
+ return false, 'skip as "' .. s .. '" is found with low weight (weight: ' .. sym.weight .. ')'
+ end
end
+ lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
+ sym.weight, required_weight)
end
- lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s,
- sym.weight, required_weight)
else
- return false, 'skip as "' .. s .. '" is found'
+ return false, 'skip as "' .. s .. '" is not found'
end
end
end
@@ -152,10 +207,10 @@ local function default_condition(task)
local words = sel_part:get_words('norm')
nwords = #words
if nwords > settings.max_tokens then
- return true, table.concat(words, ' ', 1, settings.max_tokens)
+ return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
end
end
- return true, sel_part:get_content_oneline()
+ return true, sel_part:get_content_oneline(), sel_part
end
local function maybe_extract_json(str)
@@ -196,7 +251,7 @@ local function maybe_extract_json(str)
return nil
end
-local function default_conversion(task, input)
+local function default_openai_json_conversion(task, input)
local parser = ucl.parser()
local res, err = parser:parse_string(input)
if not res then
@@ -252,14 +307,99 @@ local function default_conversion(task, input)
rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
end
- return spam_score, reply.reason
+ return spam_score, reply.reason, {}
end
rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
return
end
-local function ollama_conversion(task, input)
+-- Remove what we don't need
+local function clean_reply_line(line)
+ if not line then
+ return ''
+ end
+ return lua_util.str_trim(line):gsub("^%d%.%s+", "")
+end
+
+-- Assume that we have 3 lines: probability, reason, additional symbols
+local function default_openai_plain_conversion(task, input)
+ local parser = ucl.parser()
+ local res, err = parser:parse_string(input)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+ return
+ end
+ local reply = parser:get_object()
+ if not reply then
+ rspamd_logger.errx(task, 'cannot get object from reply')
+ return
+ end
+
+ if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+ rspamd_logger.errx(task, 'no choices in reply')
+ return
+ end
+
+ local first_message = reply.choices[1].message.content
+
+ if not first_message then
+ rspamd_logger.errx(task, 'no content in the first message')
+ return
+ end
+ local lines = lua_util.str_split(first_message, '\n')
+ local first_line = clean_reply_line(lines[1])
+ local spam_score = tonumber(first_line)
+ local reason = clean_reply_line(lines[2])
+ local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
+
+ if spam_score then
+ return spam_score, reason, categories
+ end
+
+ rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+ return
+end
+
+local function default_ollama_plain_conversion(task, input)
+ local parser = ucl.parser()
+ local res, err = parser:parse_string(input)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+ return
+ end
+ local reply = parser:get_object()
+ if not reply then
+ rspamd_logger.errx(task, 'cannot get object from reply')
+ return
+ end
+
+ if type(reply.message) ~= 'table' then
+ rspamd_logger.errx(task, 'bad message in reply')
+ return
+ end
+
+ local first_message = reply.message.content
+
+ if not first_message then
+ rspamd_logger.errx(task, 'no content in the first message')
+ return
+ end
+ local lines = lua_util.str_split(first_message, '\n')
+ local first_line = clean_reply_line(lines[1])
+ local spam_score = tonumber(first_line)
+ local reason = clean_reply_line(lines[2])
+ local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
+
+ if spam_score then
+ return spam_score, reason, categories
+ end
+
+ rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1])
+ return
+end
+
+local function default_ollama_json_conversion(task, input)
local parser = ucl.parser()
local res, err = parser:parse_string(input)
if not res then
@@ -322,7 +462,79 @@ local function ollama_conversion(task, input)
return
end
-local function check_consensus(task, results)
+-- Make cache specific to all settings to avoid conflicts
+local env_digest = nil
+
+local function redis_cache_key(sel_part)
+ if not env_digest then
+ local hasher = require "rspamd_cryptobox_hash"
+ local digest = hasher.create()
+ digest:update(settings.prompt)
+ digest:update(settings.model)
+ digest:update(settings.url)
+ env_digest = digest:hex():sub(1, 4)
+ end
+ return string.format('%s%s_%s', REDIS_PREFIX, env_digest,
+ sel_part:get_mimepart():get_digest():sub(1, 24))
+end
+
+local function maybe_save_cache(task, result, sel_part)
+ if not sel_part or not redis_params then
+ lua_util.debugm(N, task, 'cannot save cache: no part or no redis')
+ return -- cannot save intentionally
+ end
+
+ local cache_key = redis_cache_key(sel_part)
+ lua_util.debugm(N, task, 'saving cache for %s', cache_key)
+ local result_json = ucl.to_format(result, 'json-compact')
+ lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, _)
+ if err then
+ rspamd_logger.errx(task, 'cannot save cache: %s', err)
+ end
+ end,
+ 'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
+end
+
+local function process_categories(task, categories)
+ for _, category in ipairs(categories) do
+ local sym = categories_map[category:lower()]
+ if sym then
+ task:insert_result(sym.name, 1.0)
+ end
+ end
+end
+
+local function insert_results(task, result, sel_part)
+ if not result.probability then
+ rspamd_logger.errx(task, 'no probability in result')
+ return
+ end
+ if result.probability > 0.5 then
+ task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability))
+ if settings.autolearn then
+ task:set_flag("learn_spam")
+ end
+
+ if result.categories then
+ process_categories(task, result.categories)
+ end
+ else
+ if result.reason and settings.reason_header then
+ lua_mime.modify_headers(task,
+ { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
+ end
+ task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
+ if settings.autolearn then
+ task:set_flag("learn_ham")
+ end
+ if result.categories then
+ process_categories(task, result.categories)
+ end
+ end
+ maybe_save_cache(task, result, sel_part)
+end
+
+local function check_consensus_and_insert_results(task, results, sel_part)
for _, result in ipairs(results) do
if not result.checked then
return
@@ -348,7 +560,7 @@ local function check_consensus(task, results)
end
if result.reason then
- table.insert(reasons, result.reason)
+ table.insert(reasons, result)
end
end
end
@@ -357,24 +569,19 @@ local function check_consensus(task, results)
local reason = reasons[1] or nil
if nspam > nham and max_spam_prob > 0.75 then
- task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
- if settings.autolearn then
- task:set_flag("learn_spam")
- end
-
- if reason and settings.reason_header then
- lua_mime.modify_headers(task,
- { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
- end
+ insert_results(task, {
+ probability = max_spam_prob,
+ reason = reason.reason,
+ categories = reason.categories,
+ },
+ sel_part)
elseif nham > nspam and max_ham_prob < 0.25 then
- task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
- if settings.autolearn then
- task:set_flag("learn_ham")
- end
- if reason and settings.reason_header then
- lua_mime.modify_headers(task,
- { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
- end
+ insert_results(task, {
+ probability = max_ham_prob,
+ reason = reason.reason,
+ categories = reason.categories,
+ },
+ sel_part)
else
-- No consensus
lua_util.debugm(N, task, "no consensus")
@@ -399,19 +606,42 @@ local function get_meta_llm_content(task)
return url_content, from_content
end
-local function default_llm_check(task)
- local ret, content = settings.condition(task)
+local function check_llm_uncached(task, content, sel_part)
+ return settings.specific_check(task, content, sel_part)
+end
- if not ret then
- rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
- return
- end
+local function check_llm_cached(task, content, sel_part)
+ local cache_key = redis_cache_key(sel_part)
- if not content then
- lua_util.debugm(N, task, "no content to send to gpt classification")
- return
+ local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, data)
+ if err then
+ rspamd_logger.errx(task, 'cannot check cache: %s', err)
+ check_llm_uncached(task, content, sel_part)
+ end
+
+ if type(data) == 'string' then
+ local parser = ucl.parser()
+ local res, parse_err = parser:parse_string(data)
+ if not res then
+ rspamd_logger.errx(task, 'Cannot parse cached response: %s', parse_err)
+ check_llm_uncached(task, content, sel_part)
+ else
+ rspamd_logger.infox(task, 'found cached response %s', cache_key)
+ insert_results(task, parser:get_object())
+ end
+ else
+ check_llm_uncached(task, content, sel_part)
+ end
+ end,
+ 'GET', { cache_key })
+
+ if not ret then
+ rspamd_logger.errx(task, 'cannot query cache for request')
+ check_llm_uncached(task, content, sel_part)
end
+end
+local function openai_check(task, content, sel_part)
lua_util.debugm(N, task, "sending content to gpt: %s", content)
local upstream
@@ -424,7 +654,7 @@ local function default_llm_check(task)
if err then
rspamd_logger.errx(task, '%s: request failed: %s', model, err)
upstream:fail()
- check_consensus(task, results)
+ check_consensus_and_insert_results(task, results, sel_part)
return
end
@@ -435,7 +665,7 @@ local function default_llm_check(task)
return
end
- local reply, reason = settings.reply_conversion(task, body)
+ local reply, reason, categories = settings.reply_conversion(task, body)
results[idx].model = model
@@ -443,9 +673,13 @@ local function default_llm_check(task)
results[idx].success = true
results[idx].probability = reply
results[idx].reason = reason
+
+ if categories then
+ results[idx].categories = categories
+ end
end
- check_consensus(task, results)
+ check_consensus_and_insert_results(task, results, sel_part)
end
end
@@ -518,19 +752,7 @@ local function default_llm_check(task)
end
end
-local function ollama_check(task)
- local ret, content = settings.condition(task)
-
- if not ret then
- rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
- return
- end
-
- if not content then
- lua_util.debugm(N, task, "no content to send to gpt classification")
- return
- end
-
+local function ollama_check(task, content, sel_part)
lua_util.debugm(N, task, "sending content to gpt: %s", content)
local upstream
@@ -542,7 +764,7 @@ local function ollama_check(task)
if err then
rspamd_logger.errx(task, '%s: request failed: %s', model, err)
upstream:fail()
- check_consensus(task, results)
+ check_consensus_and_insert_results(task, results, sel_part)
return
end
@@ -563,7 +785,7 @@ local function ollama_check(task)
results[idx].reason = reason
end
- check_consensus(task, results)
+ check_consensus_and_insert_results(task, results, sel_part)
end
end
@@ -629,39 +851,65 @@ local function ollama_check(task)
end
local function gpt_check(task)
- return settings.specific_check(task)
+ local ret, content, sel_part = settings.condition(task)
+
+ if not ret then
+ rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+ return
+ end
+
+ if not content then
+ lua_util.debugm(N, task, "no content to send to gpt classification")
+ return
+ end
+
+ if sel_part then
+ -- Check digest
+ check_llm_cached(task, content, sel_part)
+ else
+ check_llm_uncached(task, content)
+ end
end
local types_map = {
openai = {
- check = default_llm_check,
+ check = openai_check,
condition = default_condition,
- conversion = default_conversion,
+ conversion = function(is_json)
+ return is_json and default_openai_json_conversion or default_openai_plain_conversion
+ end,
require_passkey = true,
},
ollama = {
check = ollama_check,
condition = default_condition,
- conversion = ollama_conversion,
+ conversion = function(is_json)
+ return is_json and default_ollama_json_conversion or default_ollama_plain_conversion
+ end,
require_passkey = false,
},
}
-local opts = rspamd_config:get_all_opt('gpt')
+local opts = rspamd_config:get_all_opt(N)
if opts then
+ redis_params = lua_redis.parse_redis_server(N, opts)
settings = lua_util.override_defaults(settings, opts)
- if not settings.prompt then
- settings.prompt = "You will be provided with the email message, subject, from and url domains, " ..
- "and your task is to evaluate the probability to be spam as number from 0 to 1, " ..
- "output result as JSON with 'probability' field and " ..
- "add 'reason' field with 1 sentence description why you have made that decision."
+ if redis_params then
+ lua_redis.register_prefix(REDIS_PREFIX .. '*', N,
+ 'Cache of LLM requests', {
+ type = 'string',
+ })
end
if not settings.symbols_to_except then
settings.symbols_to_except = default_symbols_to_except
end
+ if not settings.extra_symbols then
+ settings.extra_symbols = default_extra_symbols
+ end
+
local llm_type = types_map[settings.type]
if not llm_type then
rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
@@ -679,7 +927,7 @@ if opts then
if settings.reply_conversion then
settings.reply_conversion = load(settings.reply_conversion)()
else
- settings.reply_conversion = llm_type.conversion
+ settings.reply_conversion = llm_type.conversion(settings.json)
end
if not settings.api_key and llm_type.require_passkey then
@@ -703,7 +951,7 @@ if opts then
name = 'GPT_SPAM',
type = 'virtual',
parent = id,
- score = 5.0,
+ score = 3.0,
})
rspamd_config:register_symbol({
name = 'GPT_HAM',
@@ -711,4 +959,35 @@ if opts then
parent = id,
score = -2.0,
})
+
+ if settings.extra_symbols then
+ for sym, data in pairs(settings.extra_symbols) do
+ rspamd_config:register_symbol({
+ name = sym,
+ type = 'virtual',
+ parent = id,
+ score = data.score,
+ description = data.description,
+ })
+ data.name = sym
+ categories_map[data.category] = data
+ end
+ end
+
+ if not settings.prompt then
+ if settings.extra_symbols then
+ settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+ "FROM and url domains. Evaluate spam probability (0-1). " ..
+ "Output ONLY 3 lines:\n" ..
+ "1. Numeric score (0.00-1.00)\n" ..
+ "2. One-sentence reason citing strongest red flag\n" ..
+ "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+ else
+ settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+ "FROM and url domains. Evaluate spam probability (0-1). " ..
+ "Output ONLY 2 lines:\n" ..
+ "1. Numeric score (0.00-1.00)\n" ..
+ "2. One-sentence reason citing strongest red flag\n"
+ end
+ end
end