diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-06-30 21:19:32 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-30 21:19:32 +0600 |
commit | d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b (patch) | |
tree | bcbdd0a680c1f8441606ad92342cd81e20135a05 | |
parent | 7b3fd1688c8d6634b67acced10f770792c928a91 (diff) | |
parent | 63f0ffe317038a34fa8467f03ef9e9c43161d1bd (diff) | |
download | rspamd-d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b.tar.gz rspamd-d147f5466b3e36ae15ee18d2d4ff3a08ee5a152b.zip |
Merge pull request #5032 from rspamd/vstakhov-gpt-plugin
Add GPT plugin
-rw-r--r-- | conf/modules.d/gpt.conf | 55 | ||||
-rw-r--r-- | lualib/rspamadm/classifier_test.lua | 25 | ||||
-rw-r--r-- | src/plugins/lua/gpt.lua | 348 |
3 files changed, 423 insertions, 5 deletions
diff --git a/conf/modules.d/gpt.conf b/conf/modules.d/gpt.conf new file mode 100644 index 000000000..7a2e11d40 --- /dev/null +++ b/conf/modules.d/gpt.conf @@ -0,0 +1,55 @@ +# Please don't modify this file as your changes might be overwritten with +# the next update. +# +# You can modify 'local.d/gpt.conf' to add and merge +# parameters defined inside this section +# +# You can modify 'override.d/gpt.conf' to strictly override all +# parameters defined inside this section +# +# See https://rspamd.com/doc/faq.html#what-are-the-locald-and-overrided-directories +# for details +# +# Module documentation can be found at https://rspamd.com/doc/modules/gpt.html + +gpt { + # Supported types: openai + type = "openai"; + # Your key to access the API (add this to enable this plugin) + #api_key = "xxx"; + # Model name + model = "gpt-3.5-turbo"; + # Maximum tokens to generate + max_tokens = 1000; + # Temperature for sampling + temperature = 0.7; + # Top p for sampling + top_p = 0.9; + # Timeout for requests + timeout = 10s; + # Prompt for the model (use default if not set) + #prompt = "xxx"; + # Custom condition (lua function) + #condition = "xxx"; + # Autolearn if gpt classified + #autolearn = true; + # Reply conversion (lua code) + #reply_conversion = "xxx"; + + # Default set of symbols to be excepted + #symbols_to_except = [ + # 'BAYES_SPAM', + # 'WHITELIST_SPF', + # 'WHITELIST_DKIM', + # 'WHITELIST_DMARC', + # 'FUZZY_DENIED', + #]; + + # Be sure to enable module after you specify the API key + enabled = false; + + # Include dynamic conf for the rule + .include(try=true,priority=5) "${DBDIR}/dynamic/gpt.conf" + .include(try=true,priority=1,duplicate=merge) "$LOCAL_CONFDIR/local.d/gpt.conf" + .include(try=true,priority=10) "$LOCAL_CONFDIR/override.d/gpt.conf" +}
\ No newline at end of file diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua index 21af14fc1..4148a7538 100644 --- a/lualib/rspamadm/classifier_test.lua +++ b/lualib/rspamadm/classifier_test.lua @@ -40,6 +40,14 @@ parser:option "-c --cv-fraction" :argname("<fraction>") :convert(tonumber) :default('0.7') +parser:option "--spam-symbol" + :description("Use specific spam symbol (instead of BAYES_SPAM)") + :argname("<symbol>") + :default('BAYES_SPAM') +parser:option "--ham-symbol" + :description("Use specific ham symbol (instead of BAYES_HAM)") + :argname("<symbol>") + :default('BAYES_HAM') local opts @@ -82,11 +90,12 @@ local function train_classifier(files, command) end -- Function to classify files and return results -local function classify_files(files) +local function classify_files(files, known_spam_files, known_ham_files) local fname = os.tmpname() list_to_file(files, fname) - local settings_header = '--header Settings=\"{symbols_enabled=[BAYES_SPAM, BAYES_HAM]}\"' + local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"', + opts.spam_symbol, opts.ham_symbol) local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s", opts.rspamc, settings_header, @@ -107,9 +116,15 @@ local function classify_files(files) local file = obj.filename local symbols = obj.symbols or {} - if symbols["BAYES_SPAM"] then + if symbols[opts.spam_symbol] then table.insert(results, { result = "spam", file = file }) - elseif symbols["BAYES_HAM"] then + if known_ham_files[file] then + rspamd_logger.message("FP: %s is classified as spam but is known ham", file) + end + elseif symbols[opts.ham_symbol] then + if known_spam_files[file] then + rspamd_logger.message("FN: %s is classified as ham but is known spam", file) + end table.insert(results, { result = "ham", file = file }) end end @@ -207,7 +222,7 @@ local function handler(args) print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns)) -- Get classification results local t = rspamd_util.get_time() - local results = classify_files(cv_files) + local results = classify_files(cv_files, known_spam_files, known_ham_files) local elapsed = rspamd_util.get_time() - t -- Evaluate results diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua new file mode 100644 index 000000000..c982f57a2 --- /dev/null +++ b/src/plugins/lua/gpt.lua @@ -0,0 +1,348 @@ +--[[ +Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local N = "gpt" +local E = {} + +if confighelp then + rspamd_config:add_example(nil, 'gpt', + "Performs postfiltering using GPT model", + [[ +gpt { + # Supported types: openai + type = "openai"; + # Your key to access the API + api_key = "xxx"; + # Model name + model = "gpt-3.5-turbo"; + # Maximum tokens to generate + max_tokens = 1000; + # Temperature for sampling + temperature = 0.7; + # Top p for sampling + top_p = 0.9; + # Timeout for requests + timeout = 10s; + # Prompt for the model (use default if not set) + prompt = "xxx"; + # Custom condition (lua function) + condition = "xxx"; + # Autolearn if gpt classified + autolearn = true; + # Reply conversion (lua code) + reply_conversion = "xxx"; + # URL for the API + url = "https://api.openai.com/v1/chat/completions"; +} + ]]) + return +end + +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" + +-- Exclude checks if one of those is found +local default_symbols_to_except = { + 'BAYES_SPAM', -- We already know that it is a spam, so we can safely skip it, but no same logic for HAM! + 'WHITELIST_SPF', + 'WHITELIST_DKIM', + 'WHITELIST_DMARC', + 'FUZZY_DENIED', + 'REPLY', + 'BOUNCE', +} + +local settings = { + type = 'openai', + api_key = nil, + model = 'gpt-3.5-turbo', + max_tokens = 1000, + temperature = 0.7, + top_p = 0.9, + timeout = 10, + prompt = nil, + condition = nil, + autolearn = false, + url = 'https://api.openai.com/v1/chat/completions', + symbols_to_except = default_symbols_to_except, +} + +local function default_condition(task) + -- Check result + -- 1) Skip passthrough + -- 2) Skip already decided as spam + -- 3) Skip already decided as ham + local result = task:get_metric_result() + if result then + if result.passthrough then + return false, 'passthrough' + end + local score = result.score + local action = result.action + + if action == 'reject' and result.npositive > 1 then + return true, 'already decided as spam' + end + + if action == 'no action' and score < 0 then + return true, 'negative score, already decided as ham' + end + end + -- We also exclude some symbols + for _, s in ipairs(settings.symbols_to_except) do + if task:has_symbol(s) then + return false, 'skip as "' .. s .. '" is found' + end + end + + -- Check if we have text at all + local mp = task:get_parts() or {} + local sel_part + for _, mime_part in ipairs(mp) do + if mime_part:is_text() then + local part = mime_part:get_text() + if part:is_html() then + -- We prefer html content + sel_part = part + elseif not sel_part then + sel_part = part + end + end + end + + if not sel_part then + return false, 'no text part found' + end + + -- Check limits and size sanity + local nwords = sel_part:get_words_count() + + if nwords < 5 then + return false, 'less than 5 words' + end + + if nwords > settings.max_tokens then + -- We need to truncate words (sometimes get_words_count returns a different number comparing to `get_words`) + local words = sel_part:get_words('norm') + nwords = #words + if nwords > settings.max_tokens then + return true, table.concat(words, ' ', 1, settings.max_tokens) + end + end + return true, sel_part:get_content_oneline() +end + +local function default_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then + rspamd_logger.errx(task, 'no choices in reply') + return + end + + local first_message = reply.choices[1].message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + + local spam_score = tonumber(first_message) + if not spam_score then + rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + return + end + + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + return spam_score +end + +local function openai_gpt_check(task) + local ret, content = settings.condition(task) + + if not ret then + rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) + return + end + + if not content then + lua_util.debugm(N, task, "no content to send to gpt classification") + return + end + + lua_util.debugm(N, task, "sending content to gpt: %s", content) + + local upstream + + local function on_reply(err, code, body) + + if err then + rspamd_logger.errx(task, 'request failed: %s', err) + upstream:fail() + return + end + + upstream:ok() + lua_util.debugm(N, task, "got reply: %s", body) + if code ~= 200 then + rspamd_logger.errx(task, 'bad reply: %s', body) + return + end + + local reply = settings.reply_conversion(task, body) + if not reply then + return + end + + if reply > 0.75 then + task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply)) + if settings.autolearn then + task:set_flag("learn_spam") + end + elseif reply < 0.25 then + task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply)) + if settings.autolearn then + task:set_flag("learn_ham") + end + else + lua_util.debugm(N, task, "uncertain result: %s", reply) + end + + end + + local body = { + model = settings.model, + max_tokens = settings.max_tokens, + temperature = settings.temperature, + top_p = settings.top_p, + messages = { + { + role = 'system', + content = settings.prompt + }, + { + role = 'user', + content = 'Subject: ' .. task:get_subject() or '', + }, + { + role = 'user', + content = 'From: ' .. ((task:get_from('mime') or E)[1] or E).name or '', + }, + { + role = 'user', + content = content + } + } + } + + upstream = settings.upstreams:get_upstream_round_robin() + local http_params = { + url = settings.url, + mime_type = 'application/json', + timeout = settings.timeout, + log_obj = task, + callback = on_reply, + headers = { + ['Authorization'] = 'Bearer ' .. settings.api_key, + }, + keepalive = true, + body = ucl.to_format(body, 'json-compact', true), + task = task, + upstream = upstream, + use_gzip = true, + } + + rspamd_http.request(http_params) +end + +local function gpt_check(task) + return settings.specific_check(task) +end + +local opts = rspamd_config:get_all_opt('gpt') +if opts then + settings = lua_util.override_defaults(settings, opts) + + if not settings.api_key then + rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module') + lua_util.disable_module(N, "config") + + return + end + if settings.condition then + settings.condition = load(settings.condition)() + else + settings.condition = default_condition + end + + if settings.reply_conversion then + settings.reply_conversion = load(settings.reply_conversion)() + else + settings.reply_conversion = default_conversion + end + + if not settings.prompt then + settings.prompt = "You will be provided with a text of the email, " .. + "and your task is to classify its probability to be spam, " .. + "output resulting probability as a single floating point number from 0.0 to 1.0." + end + + if settings.type == 'openai' then + settings.specific_check = openai_gpt_check + else + rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) + lua_util.disable_module(N, "config") + return + end + + settings.upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.url) + + local id = rspamd_config:register_symbol({ + name = 'GPT_CHECK', + type = 'postfilter', + callback = gpt_check, + priority = lua_util.symbols_priorities.medium, + augmentations = { string.format("timeout=%f", settings.timeout or 0.0) }, + }) + + rspamd_config:register_symbol({ + name = 'GPT_SPAM', + type = 'virtual', + parent = id, + score = 5.0, + }) + rspamd_config:register_symbol({ + name = 'GPT_HAM', + type = 'virtual', + parent = id, + score = -2.0, + }) +end
\ No newline at end of file |