]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add GPT plugin
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 27 Jun 2024 14:39:09 +0000 (15:39 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 27 Jun 2024 14:39:09 +0000 (15:39 +0100)
conf/modules.d/gpt.conf [new file with mode: 0644]
src/plugins/lua/gpt.lua [new file with mode: 0644]

diff --git a/conf/modules.d/gpt.conf b/conf/modules.d/gpt.conf
new file mode 100644 (file)
index 0000000..1dd5405
--- /dev/null
@@ -0,0 +1,43 @@
+# Please don't modify this file as your changes might be overwritten with
+# the next update.
+#
+# You can modify 'local.d/gpt.conf' to add and merge
+# parameters defined inside this section
+#
+# You can modify 'override.d/gpt.conf' to strictly override all
+# parameters defined inside this section
+#
+# See https://rspamd.com/doc/faq.html#what-are-the-locald-and-overrided-directories
+# for details
+#
+# Module documentation can be found at  https://rspamd.com/doc/modules/gpt.html
+
+gpt {
+  # Supported types: openai
+  type = "openai";
+  # Your key to access the API (add this to enable this plugin)
+  #api_key = "xxx";
+  # Model name
+  model = "gpt-3.5-turbo";
+  # Maximum tokens to generate
+  max_tokens = 100;
+  # Temperature for sampling
+  temperature = 0.7;
+  # Top p for sampling
+  top_p = 0.9;
+  # Timeout for requests
+  timeout = 10s;
+  # Prompt for the model (use default if not set)
+  #prompt = "xxx";
+  # Custom condition (lua function)
+  #condition = "xxx";
+  # Autolearn if gpt classified
+  #autolearn = true;
+  # Reply conversion (lua code)
+  #reply_conversion = "xxx";
+
+  # Include dynamic conf for the rule
+  .include(try=true,priority=5) "${DBDIR}/dynamic/gpt.conf"
+  .include(try=true,priority=1,duplicate=merge) "$LOCAL_CONFDIR/local.d/gpt.conf"
+  .include(try=true,priority=10) "$LOCAL_CONFDIR/override.d/gpt.conf"
+}
\ No newline at end of file
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
new file mode 100644 (file)
index 0000000..96c632d
--- /dev/null
@@ -0,0 +1,256 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]] --
+
+local N = "gpt"
+
+if confighelp then
+  rspamd_config:add_example(nil, 'greylist',
+      "Performs adaptive greylisting using Redis",
+      [[
+gpt {
+  # Supported types: openai
+  type = "openai";
+  # Your key to access the API
+  api_key = "xxx";
+  # Model name
+  model = "gpt-3.5-turbo";
+  # Maximum tokens to generate
+  max_tokens = 100;
+  # Temperature for sampling
+  temperature = 0.7;
+  # Top p for sampling
+  top_p = 0.9;
+  # Timeout for requests
+  timeout = 10s;
+  # Prompt for the model (use default if not set)
+  prompt = "xxx";
+  # Custom condition (lua function)
+  condition = "xxx";
+  # Autolearn if gpt classified
+  autolearn = true;
+  # Reply conversion (lua code)
+  reply_conversion = "xxx";
+}
+  ]])
+  return
+end
+
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+
+local settings = {
+  type = 'openai',
+  api_key = nil,
+  model = 'gpt-3.5-turbo',
+  max_tokens = 100,
+  temperature = 0.7,
+  top_p = 0.9,
+  timeout = 10,
+  prompt = nil,
+  condition = nil,
+  autolearn = false,
+}
+
+local function default_condition(task)
+  return true
+end
+
+local function default_conversion(task, input)
+  local parser = ucl.parser()
+  local res, err = parser:parse_string(input)
+  if not res then
+    rspamd_logger.errx(task, 'cannot parse reply: %s', err)
+    return
+  end
+  local reply = parser:get_object()
+  if not reply then
+    rspamd_logger.errx(task, 'cannot get object from reply')
+    return
+  end
+
+  if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then
+    rspamd_logger.errx(task, 'no choices in reply')
+    return
+  end
+
+  local first_message = reply.choices[1].message.content
+
+  if not first_message then
+    rspamd_logger.errx(task, 'no content in the first message')
+    return
+  end
+
+  local spam_score = tonumber(first_message)
+  if not spam_score then
+    rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message)
+    return
+  end
+
+  if type(reply.usage) == 'table' then
+    rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens)
+  end
+
+  return spam_score
+end
+
+local function openai_gpt_check(task)
+  if not settings.condition(task) then
+    lua_util.debugm(N, task, "skip checking gpt as the condition is not met")
+    return
+  end
+
+  local function on_reply(err, code, body)
+
+    if err then
+      rspamd_logger.errx(task, 'request failed: %s', err)
+      return
+    end
+
+    lua_util.debugm(N, task, "got reply: %s", body)
+    if code ~= 200 then
+      rspamd_logger.errx(task, 'bad reply: %s', body)
+      return
+    end
+
+    local reply = settings.reply_conversion(task, body)
+    if not reply then
+      return
+    end
+
+    if reply > 0.75 then
+      task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
+    elseif reply < 0.25 then
+      task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
+    else
+      lua_util.debugm(N, task, "uncertain result: %s", reply)
+    end
+
+    -- TODO: add autolearn here
+  end
+
+  local mp = task:get_parts() or {}
+  local content
+  for _, mime_part in ipairs(mp) do
+    if mime_part:is_text() then
+      local part = mime_part:get_text()
+      if part:is_html() then
+        -- We prefer html content
+        content = part:get_content_oneline()
+      elseif not content then
+        content = part:get_content_oneline()
+      end
+    end
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+  end
+
+  local body = {
+    model = settings.model,
+    max_tokens = settings.max_tokens,
+    temperature = settings.temperature,
+    top_p = settings.top_p,
+    messages = {
+      {
+        role = 'system',
+        content = settings.prompt
+      },
+      {
+        role = 'user',
+        content = content
+      }
+    }
+  }
+  local http_params = {
+    url = 'https://api.openai.com/v1/chat/completions',
+    mime_type = 'application/json',
+    timeout = settings.timeout,
+    log_obj = task,
+    callback = on_reply,
+    headers = {
+      ['Authorization'] = 'Bearer ' .. settings.api_key,
+    },
+    body = ucl.to_format(body, 'json-compact', true),
+    task = task,
+  }
+
+  rspamd_http.request(http_params)
+end
+
+local function gpt_check(task)
+  return settings.specific_check(task)
+end
+
+local opts = rspamd_config:get_all_opt('gpt')
+if opts then
+  settings = lua_util.override_defaults(settings, opts)
+
+  if not settings.api_key then
+    rspamd_logger.warnx(rspamd_config, 'no api_key is specified, disabling module')
+    lua_util.disable_module(N, "config")
+
+    return
+  end
+  if settings.condition then
+    settings.condition = load(settings.condition)()
+  else
+    settings.condition = default_condition
+  end
+
+  if settings.reply_conversion then
+    settings.reply_conversion = load(settings.reply_conversion)()
+  else
+    settings.reply_conversion = default_conversion
+  end
+
+  if not settings.prompt then
+    settings.prompt = "You will be provided with a text of the email, " ..
+        "and your task is to classify its probability to be spam, " ..
+        "output resulting probability as a single floating point number from 0.0 to 1.0."
+  end
+
+  if settings.type == 'openai' then
+    settings.specific_check = openai_gpt_check
+  else
+    rspamd_logger.warnx(rspamd_config, 'unsupported gpt type: %s', settings.type)
+    lua_util.disable_module(N, "config")
+    return
+  end
+
+  local id = rspamd_config:register_symbol({
+    name = 'GPT_CHECK',
+    type = 'postfilter',
+    callback = gpt_check,
+    priority = lua_util.symbols_priorities.medium,
+    augmentations = { string.format("timeout=%f", settings.timeout or 0.0) },
+  })
+
+  rspamd_config:register_symbol({
+    name = 'GPT_SPAM',
+    type = 'virtual',
+    parent = id,
+    score = 5.0,
+  })
+  rspamd_config:register_symbol({
+    name = 'GPT_HAM',
+    type = 'virtual',
+    parent = id,
+    score = -2.0,
+  })
+end
\ No newline at end of file