aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/libstat/stat_process.c4
-rw-r--r--src/plugins/lua/gpt.lua45
2 files changed, 40 insertions, 9 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index e7b6b43f0..ad976e713 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -884,7 +884,10 @@ rspamd_stat_learn(struct rspamd_task *task,
st_ctx = rspamd_stat_get_ctx();
g_assert(st_ctx != NULL);
+ msg_debug_bayes("learn stage %d has been called", stage);
+
if (st_ctx->classifiers->len == 0) {
+ msg_debug_bayes("no classifiers defined");
task->processed_stages |= stage;
return ret;
}
@@ -894,6 +897,7 @@ rspamd_stat_learn(struct rspamd_task *task,
rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+ msg_debug_bayes("cache check failed, skip learning");
return RSPAMD_STAT_PROCESS_ERROR;
}
}
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index c982f57a2..ddd2f0186 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -55,6 +55,7 @@ local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
+local fun = require "fun"
-- Exclude checks if one of those is found
local default_symbols_to_except = {
@@ -172,17 +173,26 @@ local function default_conversion(task, input)
return
end
- local spam_score = tonumber(first_message)
- if not spam_score then
- rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ parser = ucl.parser()
+ res, err = parser:parse_string(first_message)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err)
return
end
- if type(reply.usage) == 'table' then
- rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ reply = parser:get_object()
+
+ if type(reply) == 'table' and reply.probability then
+ local spam_score = tonumber(reply.probability)
+ if type(reply.usage) == 'table' then
+ rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ end
+
+ return spam_score
end
- return spam_score
+ rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ return
end
local function openai_gpt_check(task)
@@ -238,6 +248,19 @@ local function openai_gpt_check(task)
end
+ local url_content = "Url domains: no urls found"
+ if task:has_urls() then
+ local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+ url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
+ return u:get_tld() or ''
+ end, urls or {})), ', ')
+ end
+
+ local from_or_empty = ((task:get_from('mime') or E)[1] or E)
+ local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
+ lua_util.debugm(N, task, "gpt urls: %s", url_content)
+ lua_util.debugm(N, task, "gpt from: %s", from_content)
+
local body = {
model = settings.model,
max_tokens = settings.max_tokens,
@@ -254,7 +277,11 @@ local function openai_gpt_check(task)
},
{
role = 'user',
- content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '',
+ content = from_content,
+ },
+ {
+ role = 'user',
+ content = url_content,
},
{
role = 'user',
@@ -310,9 +337,9 @@ if opts then
end
if not settings.prompt then
- settings.prompt = "You will be provided with a text of the email, " ..
+ settings.prompt = "You will be provided with the email message, " ..
"and your task is to classify its probability to be spam, " ..
- "output resulting probability as a single floating point number from 0.0 to 1.0."
+ "output result as JSON with 'probability' field."
end
if settings.type == 'openai' then