aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/gpt.lua140
1 files changed, 108 insertions, 32 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index f605b702a..625450fd9 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -77,6 +77,32 @@ local default_symbols_to_except = {
BOUNCE = -1,
}
+local default_extra_symbols = {
+ GPT_MARKETING = {
+ score = 0.0,
+ description = 'GPT model detected marketing content',
+ category = 'marketing',
+ },
+ GPT_PHISHING = {
+ score = 3.0,
+ description = 'GPT model detected phishing content',
+ category = 'phishing',
+ },
+ GPT_SCAM = {
+ score = 3.0,
+ description = 'GPT model detected scam content',
+ category = 'scam',
+ },
+ GPT_MALWARE = {
+ score = 3.0,
+ description = 'GPT model detected malware content',
+ category = 'malware',
+ },
+}
+
+-- Should be filled from extra symbols
+local categories_map = {}
+
local settings = {
type = 'openai',
api_key = nil,
@@ -95,6 +121,7 @@ local settings = {
allow_ham = false,
json = false,
redis_cache_expire = 3600 * 24,
+ extra_symbols = nil,
}
local redis_params
@@ -287,6 +314,14 @@ local function default_openai_json_conversion(task, input)
return
end
+-- Remove what we don't need
+local function clean_reply_line(line)
+ if not line then
+ return ''
+ end
+ return lua_util.str_trim(line):gsub("^%d%.%s+", "")
+end
+
-- Assume that we have 3 lines: probability, reason, additional symbols
local function default_openai_plain_conversion(task, input)
local parser = ucl.parser()
@@ -313,17 +348,13 @@ local function default_openai_plain_conversion(task, input)
return
end
local lines = lua_util.str_split(first_message, '\n')
- local first_line = lines[1] or ''
- local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
- :gsub("[^%d%.]", "")
- :gsub("%.$", "")
- :gsub("%.%..*", "")
- local spam_score = tonumber(cleaned_line)
- local reason = lines[2]
- local symbols = lua_util.str_split(lines[3] or '', ',')
+ local first_line = clean_reply_line(lines[1])
+ local spam_score = tonumber(first_line)
+ local reason = clean_reply_line(lines[2])
+ local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
if spam_score then
- return spam_score, reason, symbols
+ return spam_score, reason, categories
end
rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
@@ -355,20 +386,16 @@ local function default_ollama_plain_conversion(task, input)
return
end
local lines = lua_util.str_split(first_message, '\n')
- local first_line = lines[1] or ''
- local cleaned_line = first_line:gsub("^[%d%p]%s?%f[%d]", "")
- :gsub("[^%d%.]", "")
- :gsub("%.$", "")
- :gsub("%.%..*", "")
- local spam_score = tonumber(cleaned_line)
- local reason = lines[2]
- local symbols = lua_util.str_split(lines[3] or '', ',')
+ local first_line = clean_reply_line(lines[1])
+ local spam_score = tonumber(first_line)
+ local reason = clean_reply_line(lines[2])
+ local categories = lua_util.str_split(clean_reply_line(lines[3]), ',')
if spam_score then
- return spam_score, reason, symbols
+ return spam_score, reason, categories
end
- rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1])
+ rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1])
return
end
@@ -468,6 +495,15 @@ local function maybe_save_cache(task, result, sel_part)
'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
end
+local function process_categories(task, categories)
+ for _, category in ipairs(categories) do
+ local sym = categories_map[category:lower()]
+ if sym then
+ task:insert_result(sym.name, 1.0)
+ end
+ end
+end
+
local function insert_results(task, result, sel_part)
if not result.probability then
rspamd_logger.errx(task, 'no probability in result')
@@ -478,6 +514,10 @@ local function insert_results(task, result, sel_part)
if settings.autolearn then
task:set_flag("learn_spam")
end
+
+ if result.categories then
+ process_categories(task, result.categories)
+ end
else
if result.reason and settings.reason_header then
lua_mime.modify_headers(task,
@@ -487,6 +527,9 @@ local function insert_results(task, result, sel_part)
if settings.autolearn then
task:set_flag("learn_ham")
end
+ if result.categories then
+ process_categories(task, result.categories)
+ end
end
maybe_save_cache(task, result, sel_part)
end
@@ -517,7 +560,7 @@ local function check_consensus_and_insert_results(task, results, sel_part)
end
if result.reason then
- table.insert(reasons, result.reason)
+ table.insert(reasons, result)
end
end
end
@@ -528,13 +571,15 @@ local function check_consensus_and_insert_results(task, results, sel_part)
if nspam > nham and max_spam_prob > 0.75 then
insert_results(task, {
probability = max_spam_prob,
- reason = reason,
+ reason = reason.reason,
+ categories = reason.categories,
},
sel_part)
elseif nham > nspam and max_ham_prob < 0.25 then
insert_results(task, {
probability = max_ham_prob,
- reason = reason,
+ reason = reason.reason,
+ categories = reason.categories,
},
sel_part)
else
@@ -620,7 +665,7 @@ local function openai_check(task, content, sel_part)
return
end
- local reply, reason, _symbols = settings.reply_conversion(task, body)
+ local reply, reason, categories = settings.reply_conversion(task, body)
results[idx].model = model
@@ -628,6 +673,10 @@ local function openai_check(task, content, sel_part)
results[idx].success = true
results[idx].probability = reply
results[idx].reason = reason
+
+ if categories then
+ results[idx].categories = categories
+ end
end
check_consensus_and_insert_results(task, results, sel_part)
@@ -853,18 +902,14 @@ if opts then
})
end
- if not settings.prompt then
- settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
- "FROM and url domains. Evaluate spam probability (0-1). " ..
- "Output ONLY 2 lines:\n" ..
- "1. Numeric score (0.00-1.00)\n" ..
- "2. One-sentence reason citing strongest red flag"
- end
-
if not settings.symbols_to_except then
settings.symbols_to_except = default_symbols_to_except
end
+ if not settings.extra_symbols then
+ settings.extra_symbols = default_extra_symbols
+ end
+
local llm_type = types_map[settings.type]
if not llm_type then
rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
@@ -906,7 +951,7 @@ if opts then
name = 'GPT_SPAM',
type = 'virtual',
parent = id,
- score = 5.0,
+ score = 3.0,
})
rspamd_config:register_symbol({
name = 'GPT_HAM',
@@ -914,4 +959,35 @@ if opts then
parent = id,
score = -2.0,
})
+
+ if settings.extra_symbols then
+ for sym, data in pairs(settings.extra_symbols) do
+ rspamd_config:register_symbol({
+ name = sym,
+ type = 'virtual',
+ parent = id,
+ score = data.score,
+ description = data.description,
+ })
+ data.name = sym
+ categories_map[data.category] = data
+ end
+ end
+
+ if not settings.prompt then
+ if settings.extra_symbols then
+ settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+ "FROM and url domains. Evaluate spam probability (0-1). " ..
+ "Output ONLY 3 lines:\n" ..
+ "1. Numeric score (0.00-1.00)\n" ..
+ "2. One-sentence reason citing strongest red flag\n" ..
+ "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ')
+ else
+ settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
+ "FROM and url domains. Evaluate spam probability (0-1). " ..
+ "Output ONLY 2 lines:\n" ..
+ "1. Numeric score (0.00-1.00)\n" ..
+ "2. One-sentence reason citing strongest red flag\n"
+ end
+ end
end