diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-07-01 19:15:10 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-07-01 19:15:10 +0100 |
commit | 35ceeaff18fd5c7020bd79adbb8e92568d7d6591 (patch) | |
tree | d9b2a68478c8ea68169088d230e11b72f532881c /src/plugins/lua | |
parent | 1d865d80a50c4cea8b2aa61222996aa2d2b0a83d (diff) | |
download | rspamd-35ceeaff18fd5c7020bd79adbb8e92568d7d6591.tar.gz rspamd-35ceeaff18fd5c7020bd79adbb8e92568d7d6591.zip |
[Project] Trying to test various things with GPT
Diffstat (limited to 'src/plugins/lua')
-rw-r--r-- | src/plugins/lua/gpt.lua | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index c982f57a2..ddd2f0186 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -55,6 +55,7 @@ local lua_util = require "lua_util" local rspamd_http = require "rspamd_http" local rspamd_logger = require "rspamd_logger" local ucl = require "ucl" +local fun = require "fun" -- Exclude checks if one of those is found local default_symbols_to_except = { @@ -172,17 +173,26 @@ local function default_conversion(task, input) return end - local spam_score = tonumber(first_message) - if not spam_score then - rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + parser = ucl.parser() + res, err = parser:parse_string(first_message) + if not res then + rspamd_logger.errx(task, 'cannot parse JSON gpt reply: %s', err) return end - if type(reply.usage) == 'table' then - rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + reply = parser:get_object() + + if type(reply) == 'table' and reply.probability then + local spam_score = tonumber(reply.probability) + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + return spam_score end - return spam_score + rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + return end local function openai_gpt_check(task) @@ -238,6 +248,19 @@ local function openai_gpt_check(task) end + local url_content = "Url domains: no urls found" + if task:has_urls() then + local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 } + url_content = "Url domains: " .. table.concat(fun.totable(fun.map(function(u) + return u:get_tld() or '' + end, urls or {})), ', ') + end + + local from_or_empty = ((task:get_from('mime') or E)[1] or E) + local from_content = string.format('From: %s <%s>', from_or_empty.name, from_or_empty.addr) + lua_util.debugm(N, task, "gpt urls: %s", url_content) + lua_util.debugm(N, task, "gpt from: %s", from_content) + local body = { model = settings.model, max_tokens = settings.max_tokens, @@ -254,7 +277,11 @@ local function openai_gpt_check(task) }, { role = 'user', - content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '', + content = from_content, + }, + { + role = 'user', + content = url_content, }, { role = 'user', @@ -310,9 +337,9 @@ if opts then end if not settings.prompt then - settings.prompt = "You will be provided with a text of the email, " .. + settings.prompt = "You will be provided with the email message, " .. "and your task is to classify its probability to be spam, " .. - "output resulting probability as a single floating point number from 0.0 to 1.0." + "output result as JSON with 'probability' field." end if settings.type == 'openai' then |