diff options
-rw-r--r-- | src/libstat/stat_process.c | 4 | ||||
-rw-r--r-- | src/plugins/lua/gpt.lua | 45 |
2 files changed, 40 insertions, 9 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index e7b6b43f0..ad976e713 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -884,7 +884,10 @@ rspamd_stat_learn(struct rspamd_task *task, st_ctx = rspamd_stat_get_ctx(); g_assert(st_ctx != NULL); + msg_debug_bayes("learn stage %d has been called", stage); + if (st_ctx->classifiers->len == 0) { + msg_debug_bayes("no classifiers defined"); task->processed_stages |= stage; return ret; } @@ -894,6 +897,7 @@ rspamd_stat_learn(struct rspamd_task *task, rspamd_stat_preprocess(st_ctx, task, TRUE, spam); if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) { + msg_debug_bayes("cache check failed, skip learning"); return RSPAMD_STAT_PROCESS_ERROR; } } diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index c982f57a2..ddd2f0186 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -55,6 +55,7 @@ local lua_util = require "lua_util" local rspamd_http = require "rspamd_http" local rspamd_logger = require "rspamd_logger" local ucl = require "ucl" +local fun = require "fun" -- Exclude checks if one of those is found local default_symbols_to_except = { @@ -172,17 +173,26 @@ local function default_conversion(task, input) return end - local spam_score = tonumber(first_message) - if not spam_score then - rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + parser = ucl.parser() + res, err = parser:parse_string(first_message) + if not res then + rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err) return end - if type(reply.usage) == 'table' then - rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + reply = parser:get_object() + + if type(reply) == 'table' and reply.probability then + local spam_score = tonumber(reply.probability) + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + return spam_score end - return spam_score + rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + return end local function openai_gpt_check(task) @@ -238,6 +248,19 @@ local function openai_gpt_check(task) end + local url_content = "Url domains: no urls found" + if task:has_urls() then + local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 } + url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u) + return u:get_tld() or '' + end, urls or {})), ', ') + end + + local from_or_empty = ((task:get_from('mime') or E)[1] or E) + local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr) + lua_util.debugm(N, task, "gpt urls: %s", url_content) + lua_util.debugm(N, task, "gpt from: %s", from_content) + local body = { model = settings.model, max_tokens = settings.max_tokens, @@ -254,7 +277,11 @@ local function openai_gpt_check(task) }, { role = 'user', - content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '', + content = from_content, + }, + { + role = 'user', + content = url_content, }, { role = 'user', @@ -310,9 +337,9 @@ if opts then end if not settings.prompt then - settings.prompt = "You will be provided with a text of the email, " .. + settings.prompt = "You will be provided with the email message, " .. "and your task is to classify its probability to be spam, " .. - "output resulting probability as a single floating point number from 0.0 to 1.0." + "output result as JSON with 'probability' field." end if settings.type == 'openai' then |