diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-06-27 15:39:09 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-06-27 15:39:09 +0100 |
commit | 932a8a1cf4671232dcb70c9931b1d7d75e8220b9 (patch) | |
tree | 28aa965a1b297d20351b3327c70a17eb0d94024c /src | |
parent | 7b3fd1688c8d6634b67acced10f770792c928a91 (diff) | |
download | rspamd-932a8a1cf4671232dcb70c9931b1d7d75e8220b9.tar.gz rspamd-932a8a1cf4671232dcb70c9931b1d7d75e8220b9.zip |
[Project] Add GPT plugin
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/gpt.lua | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua new file mode 100644 index 000000000..96c632d4c --- /dev/null +++ b/src/plugins/lua/gpt.lua @@ -0,0 +1,256 @@ +--[[ +Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local N = "gpt" + +if confighelp then + rspamd_config:add_example(nil, 'greylist', + "Performs adaptive greylisting using Redis", + [[ +gpt { + # Supported types: openai + type = "openai"; + # Your key to access the API + api_key = "xxx"; + # Model name + model = "gpt-3.5-turbo"; + # Maximum tokens to generate + max_tokens = 100; + # Temperature for sampling + temperature = 0.7; + # Top p for sampling + top_p = 0.9; + # Timeout for requests + timeout = 10s; + # Prompt for the model (use default if not set) + prompt = "xxx"; + # Custom condition (lua function) + condition = "xxx"; + # Autolearn if gpt classified + autolearn = true; + # Reply conversion (lua code) + reply_conversion = "xxx"; +} + ]]) + return +end + +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" + +local settings = { + type = 'openai', + api_key = nil, + model = 'gpt-3.5-turbo', + max_tokens = 100, + temperature = 0.7, + top_p = 0.9, + timeout = 10, + prompt = nil, + condition = nil, + autolearn = false, +} + +local function default_condition(task) + return true +end + +local function default_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then + rspamd_logger.errx(task, 'no choices in reply') + return + end + + local first_message = reply.choices[1].message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + + local spam_score = tonumber(first_message) + if not spam_score then + rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) + return + end + + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + return spam_score +end + +local function openai_gpt_check(task) + if not settings.condition(task) then + lua_util.debugm(N, task, "skip checking gpt as the condition is not met") + return + end + + local function on_reply(err, code, body) + + if err then + rspamd_logger.errx(task, 'request failed: %s', err) + return + end + + lua_util.debugm(N, task, "got reply: %s", body) + if code ~= 200 then + rspamd_logger.errx(task, 'bad reply: %s', body) + return + end + + local reply = settings.reply_conversion(task, body) + if not reply then + return + end + + if reply > 0.75 then + task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply)) + elseif reply < 0.25 then + task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply)) + else + lua_util.debugm(N, task, "uncertain result: %s", reply) + end + + -- TODO: add autolearn here + end + + local mp = task:get_parts() or {} + local content + for _, mime_part in ipairs(mp) do + if mime_part:is_text() then + local part = mime_part:get_text() + if part:is_html() then + -- We prefer html content + content = part:get_content_oneline() + elseif not content then + content = part:get_content_oneline() + end + end + end + + if not content then + lua_util.debugm(N, task, "no content to send to gpt classification") + end + + local body = { + model = settings.model, + max_tokens = settings.max_tokens, + temperature = settings.temperature, + top_p = settings.top_p, + messages = { + { + role = 'system', + content = settings.prompt + }, + { + role = 'user', + content = content + } + } + } + local http_params = { + url = 'https://api.openai.com/v1/chat/completions', + mime_type = 'application/json', + timeout = settings.timeout, + log_obj = task, + callback = on_reply, + headers = { + ['Authorization'] = 'Bearer ' .. settings.api_key, + }, + body = ucl.to_format(body, 'json-compact', true), + task = task, + } + + rspamd_http.request(http_params) +end + +local function gpt_check(task) + return settings.specific_check(task) +end + +local opts = rspamd_config:get_all_opt('gpt') +if opts then + settings = lua_util.override_defaults(settings, opts) + + if not settings.api_key then + rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module') + lua_util.disable_module(N, "config") + + return + end + if settings.condition then + settings.condition = load(settings.condition)() + else + settings.condition = default_condition + end + + if settings.reply_conversion then + settings.reply_conversion = load(settings.reply_conversion)() + else + settings.reply_conversion = default_conversion + end + + if not settings.prompt then + settings.prompt = "You will be provided with a text of the email, " .. + "and your task is to classify its probability to be spam, " .. + "output resulting probability as a single floating point number from 0.0 to 1.0." + end + + if settings.type == 'openai' then + settings.specific_check = openai_gpt_check + else + rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type) + lua_util.disable_module(N, "config") + return + end + + local id = rspamd_config:register_symbol({ + name = 'GPT_CHECK', + type = 'postfilter', + callback = gpt_check, + priority = lua_util.symbols_priorities.medium, + augmentations = { string.format("timeout=%f", settings.timeout or 0.0) }, + }) + + rspamd_config:register_symbol({ + name = 'GPT_SPAM', + type = 'virtual', + parent = id, + score = 5.0, + }) + rspamd_config:register_symbol({ + name = 'GPT_HAM', + type = 'virtual', + parent = id, + score = -2.0, + }) +end
\ No newline at end of file |