aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorAndrew Lewis <nerf@judo.za.org>2025-02-27 22:23:32 +0200
committerAndrew Lewis <nerf@judo.za.org>2025-03-24 13:54:51 +0200
commit9f0a5321c55138f169d8545c84b21791b0c71788 (patch)
treece70c3ad221ec20b4dc2cdf8f5dd9872d5d86456 /src
parent42d4e3d049b28e3c0e03a63b6eb468b58c0b92cf (diff)
downloadrspamd-9f0a5321c55138f169d8545c84b21791b0c71788.tar.gz
rspamd-9f0a5321c55138f169d8545c84b21791b0c71788.zip
[Feature] Plugin to integrate with Contextal platform
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/contextal.lua302
1 files changed, 302 insertions, 0 deletions
diff --git a/src/plugins/lua/contextal.lua b/src/plugins/lua/contextal.lua
new file mode 100644
index 000000000..341b7a125
--- /dev/null
+++ b/src/plugins/lua/contextal.lua
@@ -0,0 +1,302 @@
+--[[
+Copyright (c) 2025, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local E = {}
+local N = 'contextal'
+
+if confighelp then
+ return
+end
+
+local opts = rspamd_config:get_all_opt(N)
+if not opts then
+ return
+end
+
+local lua_redis = require "lua_redis"
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+
+local contextal_actions = {
+ 'ALERT',
+ 'ALLOW',
+ 'BLOCK',
+ 'QUARANTINE',
+ 'SPAM',
+}
+
+local settings = {
+ action_symbol_prefix = 'CONTEXTAL_ACTION',
+ base_url = 'http://localhost:8080',
+ cache_ttl = 3600,
+ custom_actions = {},
+ http_timeout = 2,
+ key_prefix = 'CXAL',
+ request_ttl = 4,
+ submission_symbol = 'CONTEXTAL_SUBMIT',
+}
+
+local static_boundary = rspamd_util.random_hex(32)
+
+local function cache_key(task)
+ return string.format('%s_%s', settings.key_prefix, task:get_digest())
+end
+
+local function process_actions(task, obj, is_cached)
+ for _, match in ipairs((obj[1] or E).actions) do
+ local act = match.action
+ local scenario = match.scenario
+ if not (act and scenario) then
+ rspamd_logger.err(task, 'bad result: %s', match)
+ elseif contextal_actions[act] then
+ task:insert_result(settings.action_symbol_prefix .. '_' .. act, 1.0, scenario)
+ else
+ rspamd_logger.err(task, 'unknown action: %s', act)
+ end
+ end
+
+ if not redis_params or is_cached then return end
+
+ local cache_obj
+ if (obj[1] or E).actions then
+ cache_obj = {[1] = {["actions"] = obj[1].actions}}
+ elseif (obj[1] or E).work_id then
+ cache_obj = {[1] = {["work_id"] = obj[1].work_id}}
+ else
+ rspamd_logger.err(task, 'bad result: %s', obj)
+ return
+ end
+
+ local key = cache_key(task)
+ local ret = lua_redis.redis_make_request(task,
+ redis_params, -- connect params
+ key, -- hash key
+ true, -- is write
+ redis_get_cb, --callback
+ 'SET', -- command
+ { key, ucl.to_format(cache_obj, 'json-compact') } -- arguments
+ )
+
+ if not ret then
+ rspamd_logger.err(task, 'cannot make redis request to cache result')
+ return
+ end
+end
+
+local function process_cached(task, txt)
+ local parser = ucl.parser()
+ local _, err = parser:parse_string(txt)
+ if err then
+ rspamd_logger.err(task, 'cannot parse JSON (cached): %s', err)
+ return
+ end
+ local obj = parser:get_object()
+ if (obj[1] or E).actions then
+ task:disable_symbol(settings.action_symbol_prefix)
+ return process_actions(task, obj, true)
+ elseif (obj[1] or E).work_id then
+ task:get_mempool():set_variable('contextal_work_id', obj.work_id)
+ else
+ rspamd_logger.err(task, 'bad result (cached): %s', obj)
+ end
+end
+
+local function submit(task)
+
+ local function http_callback(err, code, body, hdrs)
+ if err then
+ rspamd_logger.err(task, 'http error: %s', err)
+ return
+ end
+ if code ~= 201 then
+ rspamd_logger.err(task, 'bad http code: %s', code)
+ return
+ end
+ local parser = ucl.parser()
+ local _, err = parser:parse_string(body)
+ if err then
+ rspamd_logger.err(task, 'cannot parse JSON: %s', err)
+ return
+ end
+ local obj = parser:get_object()
+ local work_id = obj.work_id
+ if work_id then
+ task:get_mempool():set_variable('contextal_work_id', work_id)
+ end
+ task:insert_result(settings.submission_symbol, 1.0,
+ string.format('work_id=%s', work_id or 'nil'))
+ end
+
+ local req = {
+ object_data = {['data'] = task:get_content()},
+ }
+ if settings.request_ttl then
+ req.ttl = {['data'] = tostring(settings.request_ttl)}
+ end
+ if settings.max_recursion then
+ req.maxrec = {['data'] = tostring(settings.max_recursion)}
+ end
+ rspamd_http.request({
+ task = task,
+ url = settings.submit_url,
+ body = lua_util.table_to_multipart_body(req, static_boundary),
+ callback = http_callback,
+ headers = {
+ ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary)
+ },
+ timeout = settings.http_timeout,
+ gzip = settings.gzip,
+ keepalive = settings.keepalive,
+ no_ssl_verify = settings.no_ssl_verify,
+ })
+end
+
+local function submit_cb(task)
+ if redis_params then
+
+ local function redis_get_cb(err, data)
+ if err then
+ rspamd_logger.err(task, 'error querying redis: %s', err)
+ return
+ end
+ if type(data) == 'userdata' then
+ return submit(task)
+ end
+ process_cached(task, data)
+ end
+
+ local key = cache_key(task)
+ local ret = lua_redis.redis_make_request(task,
+ redis_params, -- connect params
+ key, -- hash key
+ false, -- is write
+ redis_get_cb, --callback
+ 'GET', -- command
+ { key } -- arguments
+ )
+
+ if not ret then
+ rspamd_logger.err(task, 'cannot make redis request to check results')
+ return
+ end
+
+ else
+ return submit(task)
+ end
+end
+
+local function action_cb(task)
+ local work_id = task:get_mempool():get_variable('contextal_work_id', 'string')
+ if not work_id then
+ rspamd_logger.err(task, 'no work id found in mempool')
+ return
+ end
+
+ local function http_callback(err, code, body, hdrs)
+ if err then
+ rspamd_logger.err(task, 'http error: %s', err)
+ return
+ end
+ if code ~= 200 then
+ rspamd_logger.err(task, 'bad http code: %s', code)
+ return
+ end
+ local parser = ucl.parser()
+ local _, err = parser:parse_string(body)
+ if err then
+ rspamd_logger.err(task, 'cannot parse JSON: %s', err)
+ return
+ end
+ local obj = parser:get_object()
+ if (obj[1] or E).actions then
+ return process_actions(task, obj, false)
+ end
+ end
+
+ rspamd_http.request({
+ task = task,
+ url = settings.actions_url .. work_id,
+ callback = http_callback,
+ timeout = settings.http_timeout,
+ gzip = settings.gzip,
+ keepalive = settings.keepalive,
+ no_ssl_verify = settings.no_ssl_verify,
+ })
+end
+
+local function set_url_path(base, path)
+ local ts = base:sub(#base) == '/' and '' or '/'
+ return base .. ts .. path
+end
+
+local opts = rspamd_config:get_all_opt(N)
+if not opts then return end
+
+settings = lua_util.override_defaults(settings, opts)
+
+contextal_actions = lua_util.list_to_hash(contextal_actions)
+for _, k in ipairs(settings.custom_actions) do
+ contextal_actions[k] = true
+end
+
+if not settings.base_url then
+ if not (settings.submit_url and settings.actions_url) then
+ rspamd_logger.err(rspamd_config, 'no URL configured for contextal')
+ lua_util.disable_module(N, 'config')
+ return
+ end
+else
+ if not settings.submit_url then
+ settings.submit_url = set_url_path(settings.base_url, 'api/v1/submit')
+ end
+ if not settings.actions_url then
+ settings.actions_url = set_url_path(settings.base_url, 'api/v1/actions/')
+ end
+end
+
+redis_params = lua_redis.parse_redis_server(N)
+if redis_params then
+ lua_redis.register_prefix(settings.key_prefix .. '_*', N,
+ 'Cache for contextal plugin')
+end
+
+rspamd_config:register_symbol({
+ name = settings.submission_symbol,
+ priority = lua_util.symbols_priorities.top,
+ type = 'prefilter',
+ group = N,
+ callback = submit_cb
+})
+
+local id = rspamd_config:register_symbol({
+ name = settings.action_symbol_prefix,
+ type = 'postfilter',
+ priority = lua_util.symbols_priorities.high - 1,
+ group = N,
+ callback = action_cb
+})
+
+for k in pairs(contextal_actions) do
+ rspamd_config:register_symbol({
+ name = settings.action_symbol_prefix .. '_' .. k,
+ parent = id,
+ type = 'virtual',
+ group = N,
+ })
+end