aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-07-01 19:15:10 +0100
committerVsevolod Stakhov <vsevolod@rspamd.com>2024-07-01 19:15:10 +0100
commit35ceeaff18fd5c7020bd79adbb8e92568d7d6591 (patch)
treed9b2a68478c8ea68169088d230e11b72f532881c /src/plugins/lua
parent1d865d80a50c4cea8b2aa61222996aa2d2b0a83d (diff)
downloadrspamd-35ceeaff18fd5c7020bd79adbb8e92568d7d6591.tar.gz
rspamd-35ceeaff18fd5c7020bd79adbb8e92568d7d6591.zip
[Project] Trying to test various things with GPT
Diffstat (limited to 'src/plugins/lua')
-rw-r--r--src/plugins/lua/gpt.lua45
1 files changed, 36 insertions, 9 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index c982f57a2..ddd2f0186 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -55,6 +55,7 @@ local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
+local fun = require "fun"
-- Exclude checks if one of those is found
local default_symbols_to_except = {
@@ -172,17 +173,26 @@ local function default_conversion(task, input)
return
end
- local spam_score = tonumber(first_message)
- if not spam_score then
- rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ parser = ucl.parser()
+ res, err = parser:parse_string(first_message)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err)
return
end
- if type(reply.usage) == 'table' then
- rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ reply = parser:get_object()
+
+ if type(reply) == 'table' and reply.probability then
+ local spam_score = tonumber(reply.probability)
+ if type(reply.usage) == 'table' then
+ rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ end
+
+ return spam_score
end
- return spam_score
+ rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ return
end
local function openai_gpt_check(task)
@@ -238,6 +248,19 @@ local function openai_gpt_check(task)
end
+ local url_content = "Url domains: no urls found"
+ if task:has_urls() then
+ local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+ url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u)
+ return u:get_tld() or ''
+ end, urls or {})), ', ')
+ end
+
+ local from_or_empty = ((task:get_from('mime') or E)[1] or E)
+ local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr)
+ lua_util.debugm(N, task, "gpt urls: %s", url_content)
+ lua_util.debugm(N, task, "gpt from: %s", from_content)
+
local body = {
model = settings.model,
max_tokens = settings.max_tokens,
@@ -254,7 +277,11 @@ local function openai_gpt_check(task)
},
{
role = 'user',
- content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '',
+ content = from_content,
+ },
+ {
+ role = 'user',
+ content = url_content,
},
{
role = 'user',
@@ -310,9 +337,9 @@ if opts then
end
if not settings.prompt then
- settings.prompt = "You will be provided with a text of the email, " ..
+ settings.prompt = "You will be provided with the email message, " ..
"and your task is to classify its probability to be spam, " ..
- "output resulting probability as a single floating point number from 0.0 to 1.0."
+ "output result as JSON with 'probability' field."
end
if settings.type == 'openai' then