about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author    Vsevolod Stakhov <vsevolod@rspamd.com>  2024-06-30 21:19:32 +0600
committer GitHub <noreply@github.com>             2024-06-30 21:19:32 +0600
commit    d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b (patch)
tree      bcbdd0a680c1f8441606ad92342cd81e20135a05
parent    7b3fd1688c8d6634b67acced10f770792c928a91 (diff)
parent    63f0ffe317038a34fa8467f03ef9e9c43161d1bd (diff)
download  rspamd-d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b.tar.gz
          rspamd-d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b.zip
Merge pull request #5032 from rspamd/vstakhov-gpt-plugin
Add GPT plugin
-rw-r--r--  conf/modules.d/gpt.conf               55
-rw-r--r--  lualib/rspamadm/classifier_test.lua   25
-rw-r--r--  src/plugins/lua/gpt.lua              348
3 files changed, 423 insertions, 5 deletions
diff --git a/conf/modules.d/gpt.conf b/conf/modules.d/gpt.conf
new file mode 100644
index 000000000..7a2e11d40
--- /dev/null
+++ b/conf/modules.d/gpt.conf
@@ -0,0 +1,55 @@
+# Please don't modify this file as your changes might be overwritten with
+# the next update.
+#
+# You can modify 'local.d/gpt.conf' to add and merge
+# parameters defined inside this section
+#
+# You can modify 'override.d/gpt.conf' to strictly override all
+# parameters defined inside this section
+#
+# See https://rspamd.com/doc/faq.html#what-are-the-locald-and-overrided-directories
+# for details
+#
+# Module documentation can be found at https://rspamd.com/doc/modules/gpt.html
+
+gpt {
+ # Supported types: openai
+ type = "openai";
+ # Your key to access the API (add this to enable this plugin)
+ #api_key = "xxx";
+ # Model name
+ model = "gpt-3.5-turbo";
+ # Maximum tokens to generate
+ max_tokens = 1000;
+ # Temperature for sampling
+ temperature = 0.7;
+ # Top p for sampling
+ top_p = 0.9;
+ # Timeout for requests
+ timeout = 10s;
+ # Prompt for the model (use default if not set)
+ #prompt = "xxx";
+ # Custom condition (lua function)
+ #condition = "xxx";
+ # Autolearn if gpt classified
+ #autolearn = true;
+ # Reply conversion (lua code)
+ #reply_conversion = "xxx";
+
+ # Default set of symbols to be excepted
+ #symbols_to_except = [
+ # 'BAYES_SPAM',
+ # 'WHITELIST_SPF',
+ # 'WHITELIST_DKIM',
+ # 'WHITELIST_DMARC',
+ # 'FUZZY_DENIED',
+ #];
+
+ # Be sure to enable module after you specify the API key
+ enabled = false;
+
+ # Include dynamic conf for the rule
+ .include(try=true,priority=5) "${DBDIR}/dynamic/gpt.conf"
+ .include(try=true,priority=1,duplicate=merge) "$LOCAL_CONFDIR/local.d/gpt.conf"
+ .include(try=true,priority=10) "$LOCAL_CONFDIR/override.d/gpt.conf"
+}
\ No newline at end of file
diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua
index 21af14fc1..4148a7538 100644
--- a/lualib/rspamadm/classifier_test.lua
+++ b/lualib/rspamadm/classifier_test.lua
@@ -40,6 +40,14 @@ parser:option "-c --cv-fraction"
:argname("<fraction>")
:convert(tonumber)
:default('0.7')
+parser:option "--spam-symbol"
+ :description("Use specific spam symbol (instead of BAYES_SPAM)")
+ :argname("<symbol>")
+ :default('BAYES_SPAM')
+parser:option "--ham-symbol"
+ :description("Use specific ham symbol (instead of BAYES_HAM)")
+ :argname("<symbol>")
+ :default('BAYES_HAM')
local opts
@@ -82,11 +90,12 @@ local function train_classifier(files, command)
end
-- Function to classify files and return results
-local function classify_files(files)
+local function classify_files(files, known_spam_files, known_ham_files)
local fname = os.tmpname()
list_to_file(files, fname)
- local settings_header = '--header Settings=\"{symbols_enabled=[BAYES_SPAM, BAYES_HAM]}\"'
+ local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"',
+ opts.spam_symbol, opts.ham_symbol)
local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s",
opts.rspamc,
settings_header,
@@ -107,9 +116,15 @@ local function classify_files(files)
local file = obj.filename
local symbols = obj.symbols or {}
- if symbols["BAYES_SPAM"] then
+ if symbols[opts.spam_symbol] then
table.insert(results, { result = "spam", file = file })
- elseif symbols["BAYES_HAM"] then
+ if known_ham_files[file] then
+ rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
+ end
+ elseif symbols[opts.ham_symbol] then
+ if known_spam_files[file] then
+ rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
+ end
table.insert(results, { result = "ham", file = file })
end
end
@@ -207,7 +222,7 @@ local function handler(args)
print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
-- Get classification results
local t = rspamd_util.get_time()
- local results = classify_files(cv_files)
+ local results = classify_files(cv_files, known_spam_files, known_ham_files)
local elapsed = rspamd_util.get_time() - t
-- Evaluate results
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
new file mode 100644
index 000000000..c982f57a2
--- /dev/null
+++ b/src/plugins/lua/gpt.lua
@@ -0,0 +1,348 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]] --
+
+local N = "gpt"
+local E = {}
+
+if confighelp then
+ rspamd_config:add_example(nil, 'gpt',
+ "Performs postfiltering using GPT model",
+ [[
+gpt {
+ # Supported types: openai
+ type = "openai";
+ # Your key to access the API
+ api_key = "xxx";
+ # Model name
+ model = "gpt-3.5-turbo";
+ # Maximum tokens to generate
+ max_tokens = 1000;
+ # Temperature for sampling
+ temperature = 0.7;
+ # Top p for sampling
+ top_p = 0.9;
+ # Timeout for requests
+ timeout = 10s;
+ # Prompt for the model (use default if not set)
+ prompt = "xxx";
+ # Custom condition (lua function)
+ condition = "xxx";
+ # Autolearn if gpt classified
+ autolearn = true;
+ # Reply conversion (lua code)
+ reply_conversion = "xxx";
+ # URL for the API
+ url = "https://api.openai.com/v1/chat/completions";
+}
+ ]])
+ return
+end
+
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+
+-- Exclude checks if one of those is found
+local default_symbols_to_except = {
+ 'BAYES_SPAM', -- We already know that it is spam, so we can safely skip it; the same logic does not apply to HAM!
+ 'WHITELIST_SPF',
+ 'WHITELIST_DKIM',
+ 'WHITELIST_DMARC',
+ 'FUZZY_DENIED',
+ 'REPLY',
+ 'BOUNCE',
+}
+
+local settings = {
+ type = 'openai',
+ api_key = nil,
+ model = 'gpt-3.5-turbo',
+ max_tokens = 1000,
+ temperature = 0.7,
+ top_p = 0.9,
+ timeout = 10,
+ prompt = nil,
+ condition = nil,
+ autolearn = false,
+ url = 'https://api.openai.com/v1/chat/completions',
+ symbols_to_except = default_symbols_to_except,
+}
+
+local function default_condition(task)
+ -- Check result
+ -- 1) Skip passthrough
+ -- 2) Skip already decided as spam
+ -- 3) Skip already decided as ham
+ local result = task:get_metric_result()
+ if result then
+ if result.passthrough then
+ return false, 'passthrough'
+ end
+ local score = result.score
+ local action = result.action
+
+ if action == 'reject' and result.npositive > 1 then
+ return true, 'already decided as spam'
+ end
+
+ if action == 'no action' and score < 0 then
+ return true, 'negative score, already decided as ham'
+ end
+ end
+ -- We also exclude some symbols
+ for _, s in ipairs(settings.symbols_to_except) do
+ if task:has_symbol(s) then
+ return false, 'skip as "' .. s .. '" is found'
+ end
+ end
+
+ -- Check if we have text at all
+ local mp = task:get_parts() or {}
+ local sel_part
+ for _, mime_part in ipairs(mp) do
+ if mime_part:is_text() then
+ local part = mime_part:get_text()
+ if part:is_html() then
+ -- We prefer html content
+ sel_part = part
+ elseif not sel_part then
+ sel_part = part
+ end
+ end
+ end
+
+ if not sel_part then
+ return false, 'no text part found'
+ end
+
+ -- Check limits and size sanity
+ local nwords = sel_part:get_words_count()
+
+ if nwords < 5 then
+ return false, 'less than 5 words'
+ end
+
+ if nwords > settings.max_tokens then
+ -- We need to truncate words (sometimes get_words_count returns a different number comparing to `get_words`)
+ local words = sel_part:get_words('norm')
+ nwords = #words
+ if nwords > settings.max_tokens then
+ return true, table.concat(words, ' ', 1, settings.max_tokens)
+ end
+ end
+ return true, sel_part:get_content_oneline()
+end
+
+local function default_conversion(task, input)
+ local parser = ucl.parser()
+ local res, err = parser:parse_string(input)
+ if not res then
+ rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+ return
+ end
+ local reply = parser:get_object()
+ if not reply then
+ rspamd_logger.errx(task, 'cannot get object from reply')
+ return
+ end
+
+ if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+ rspamd_logger.errx(task, 'no choices in reply')
+ return
+ end
+
+ local first_message = reply.choices[1].message.content
+
+ if not first_message then
+ rspamd_logger.errx(task, 'no content in the first message')
+ return
+ end
+
+ local spam_score = tonumber(first_message)
+ if not spam_score then
+ rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+ return
+ end
+
+ if type(reply.usage) == 'table' then
+ rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+ end
+
+ return spam_score
+end
+
+local function openai_gpt_check(task)
+ local ret, content = settings.condition(task)
+
+ if not ret then
+ rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+ return
+ end
+
+ if not content then
+ lua_util.debugm(N, task, "no content to send to gpt classification")
+ return
+ end
+
+ lua_util.debugm(N, task, "sending content to gpt: %s", content)
+
+ local upstream
+
+ local function on_reply(err, code, body)
+
+ if err then
+ rspamd_logger.errx(task, 'request failed: %s', err)
+ upstream:fail()
+ return
+ end
+
+ upstream:ok()
+ lua_util.debugm(N, task, "got reply: %s", body)
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'bad reply: %s', body)
+ return
+ end
+
+ local reply = settings.reply_conversion(task, body)
+ if not reply then
+ return
+ end
+
+ if reply > 0.75 then
+ task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
+ if settings.autolearn then
+ task:set_flag("learn_spam")
+ end
+ elseif reply < 0.25 then
+ task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
+ if settings.autolearn then
+ task:set_flag("learn_ham")
+ end
+ else
+ lua_util.debugm(N, task, "uncertain result: %s", reply)
+ end
+
+ end
+
+ local body = {
+ model = settings.model,
+ max_tokens = settings.max_tokens,
+ temperature = settings.temperature,
+ top_p = settings.top_p,
+ messages = {
+ {
+ role = 'system',
+ content = settings.prompt
+ },
+ {
+ role = 'user',
+ content = 'Subject: ' .. task:get_subject() or '',
+ },
+ {
+ role = 'user',
+ content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '',
+ },
+ {
+ role = 'user',
+ content = content
+ }
+ }
+ }
+
+ upstream = settings.upstreams:get_upstream_round_robin()
+ local http_params = {
+ url = settings.url,
+ mime_type = 'application/json',
+ timeout = settings.timeout,
+ log_obj = task,
+ callback = on_reply,
+ headers = {
+ ['Authorization'] = 'Bearer ' .. settings.api_key,
+ },
+ keepalive = true,
+ body = ucl.to_format(body, 'json-compact', true),
+ task = task,
+ upstream = upstream,
+ use_gzip = true,
+ }
+
+ rspamd_http.request(http_params)
+end
+
+local function gpt_check(task)
+ return settings.specific_check(task)
+end
+
+local opts = rspamd_config:get_all_opt('gpt')
+if opts then
+ settings = lua_util.override_defaults(settings, opts)
+
+ if not settings.api_key then
+ rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module')
+ lua_util.disable_module(N, "config")
+
+ return
+ end
+ if settings.condition then
+ settings.condition = load(settings.condition)()
+ else
+ settings.condition = default_condition
+ end
+
+ if settings.reply_conversion then
+ settings.reply_conversion = load(settings.reply_conversion)()
+ else
+ settings.reply_conversion = default_conversion
+ end
+
+ if not settings.prompt then
+ settings.prompt = "You will be provided with a text of the email, " ..
+ "and your task is to classify its probability to be spam, " ..
+ "output resulting probability as a single floating point number from 0.0 to 1.0."
+ end
+
+ if settings.type == 'openai' then
+ settings.specific_check = openai_gpt_check
+ else
+ rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+ lua_util.disable_module(N, "config")
+ return
+ end
+
+ settings.upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.url)
+
+ local id = rspamd_config:register_symbol({
+ name = 'GPT_CHECK',
+ type = 'postfilter',
+ callback = gpt_check,
+ priority = lua_util.symbols_priorities.medium,
+ augmentations = { string.format("timeout=%f", settings.timeout or 0.0) },
+ })
+
+ rspamd_config:register_symbol({
+ name = 'GPT_SPAM',
+ type = 'virtual',
+ parent = id,
+ score = 5.0,
+ })
+ rspamd_config:register_symbol({
+ name = 'GPT_HAM',
+ type = 'virtual',
+ parent = id,
+ score = -2.0,
+ })
+end
\ No newline at end of file