diff options
author | Andrew Lewis <nerf@judo.za.org> | 2025-02-27 22:23:32 +0200 |
---|---|---|
committer | Andrew Lewis <nerf@judo.za.org> | 2025-03-24 13:54:51 +0200 |
commit | 9f0a5321c55138f169d8545c84b21791b0c71788 (patch) | |
tree | ce70c3ad221ec20b4dc2cdf8f5dd9872d5d86456 /src | |
parent | 42d4e3d049b28e3c0e03a63b6eb468b58c0b92cf (diff) | |
download | rspamd-9f0a5321c55138f169d8545c84b21791b0c71788.tar.gz rspamd-9f0a5321c55138f169d8545c84b21791b0c71788.zip |
[Feature] Plugin to integrate with Contextal platform
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/contextal.lua | 302 |
1 files changed, 302 insertions, 0 deletions
diff --git a/src/plugins/lua/contextal.lua b/src/plugins/lua/contextal.lua new file mode 100644 index 000000000..341b7a125 --- /dev/null +++ b/src/plugins/lua/contextal.lua @@ -0,0 +1,302 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local E = {} +local N = 'contextal' + +if confighelp then + return +end + +local opts = rspamd_config:get_all_opt(N) +if not opts then + return +end + +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local ucl = require "ucl" + +local contextal_actions = { + 'ALERT', + 'ALLOW', + 'BLOCK', + 'QUARANTINE', + 'SPAM', +} + +local settings = { + action_symbol_prefix = 'CONTEXTAL_ACTION', + base_url = 'http://localhost:8080', + cache_ttl = 3600, + custom_actions = {}, + http_timeout = 2, + key_prefix = 'CXAL', + request_ttl = 4, + submission_symbol = 'CONTEXTAL_SUBMIT', +} + +local static_boundary = rspamd_util.random_hex(32) + +local function cache_key(task) + return string.format('%s_%s', settings.key_prefix, task:get_digest()) +end + +local function process_actions(task, obj, is_cached) + for _, match in ipairs((obj[1] or E).actions) do + local act = match.action + local scenario = match.scenario + if not (act and scenario) then + rspamd_logger.err(task, 'bad result: %s', match) + elseif contextal_actions[act] then + task:insert_result(settings.action_symbol_prefix .. '_' .. act, 1.0, scenario) + else + rspamd_logger.err(task, 'unknown action: %s', act) + end + end + + if not redis_params or is_cached then return end + + local cache_obj + if (obj[1] or E).actions then + cache_obj = {[1] = {["actions"] = obj[1].actions}} + elseif (obj[1] or E).work_id then + cache_obj = {[1] = {["work_id"] = obj[1].work_id}} + else + rspamd_logger.err(task, 'bad result: %s', obj) + return + end + + local key = cache_key(task) + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_get_cb, --callback + 'SET', -- command + { key, ucl.to_format(cache_obj, 'json-compact') } -- arguments + ) + + if not ret then + rspamd_logger.err(task, 'cannot make redis request to cache result') + return + end +end + +local function process_cached(task, txt) + local parser = ucl.parser() + local _, err = parser:parse_string(txt) + if err then + rspamd_logger.err(task, 'cannot parse JSON (cached): %s', err) + return + end + local obj = parser:get_object() + if (obj[1] or E).actions then + task:disable_symbol(settings.action_symbol_prefix) + return process_actions(task, obj, true) + elseif (obj[1] or E).work_id then + task:get_mempool():set_variable('contextal_work_id', obj.work_id) + else + rspamd_logger.err(task, 'bad result (cached): %s', obj) + end +end + +local function submit(task) + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + return + end + if code ~= 201 then + rspamd_logger.err(task, 'bad http code: %s', code) + return + end + local parser = ucl.parser() + local _, err = parser:parse_string(body) + if err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + return + end + local obj = parser:get_object() + local work_id = obj.work_id + if work_id then + task:get_mempool():set_variable('contextal_work_id', work_id) + end + task:insert_result(settings.submission_symbol, 1.0, + string.format('work_id=%s', work_id or 'nil')) + end + + local req = { + object_data = {['data'] = task:get_content()}, + } + if settings.request_ttl then + req.ttl = {['data'] = tostring(settings.request_ttl)} + end + if settings.max_recursion then + req.maxrec = {['data'] = tostring(settings.max_recursion)} + end + rspamd_http.request({ + task = task, + url = settings.submit_url, + body = lua_util.table_to_multipart_body(req, static_boundary), + callback = http_callback, + headers = { + ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary) + }, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = settings.no_ssl_verify, + }) +end + +local function submit_cb(task) + if redis_params then + + local function redis_get_cb(err, data) + if err then + rspamd_logger.err(task, 'error querying redis: %s', err) + return + end + if type(data) == 'userdata' then + return submit(task) + end + process_cached(task, data) + end + + local key = cache_key(task) + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_get_cb, --callback + 'GET', -- command + { key } -- arguments + ) + + if not ret then + rspamd_logger.err(task, 'cannot make redis request to check results') + return + end + + else + return submit(task) + end +end + +local function action_cb(task) + local work_id = task:get_mempool():get_variable('contextal_work_id', 'string') + if not work_id then + rspamd_logger.err(task, 'no work id found in mempool') + return + end + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + return + end + if code ~= 200 then + rspamd_logger.err(task, 'bad http code: %s', code) + return + end + local parser = ucl.parser() + local _, err = parser:parse_string(body) + if err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + return + end + local obj = parser:get_object() + if (obj[1] or E).actions then + return process_actions(task, obj, false) + end + end + + rspamd_http.request({ + task = task, + url = settings.actions_url .. work_id, + callback = http_callback, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = settings.no_ssl_verify, + }) +end + +local function set_url_path(base, path) + local ts = base:sub(#base) == '/' and '' or '/' + return base .. ts .. path +end + +local opts = rspamd_config:get_all_opt(N) +if not opts then return end + +settings = lua_util.override_defaults(settings, opts) + +contextal_actions = lua_util.list_to_hash(contextal_actions) +for _, k in ipairs(settings.custom_actions) do + contextal_actions[k] = true +end + +if not settings.base_url then + if not (settings.submit_url and settings.actions_url) then + rspamd_logger.err(rspamd_config, 'no URL configured for contextal') + lua_util.disable_module(N, 'config') + return + end +else + if not settings.submit_url then + settings.submit_url = set_url_path(settings.base_url, 'api/v1/submit') + end + if not settings.actions_url then + settings.actions_url = set_url_path(settings.base_url, 'api/v1/actions/') + end +end + +redis_params = lua_redis.parse_redis_server(N) +if redis_params then + lua_redis.register_prefix(settings.key_prefix .. '_*', N, + 'Cache for contextal plugin') +end + +rspamd_config:register_symbol({ + name = settings.submission_symbol, + priority = lua_util.symbols_priorities.top, + type = 'prefilter', + group = N, + callback = submit_cb +}) + +local id = rspamd_config:register_symbol({ + name = settings.action_symbol_prefix, + type = 'postfilter', + priority = lua_util.symbols_priorities.high - 1, + group = N, + callback = action_cb +}) + +for k in pairs(contextal_actions) do + rspamd_config:register_symbol({ + name = settings.action_symbol_prefix .. '_' .. k, + parent = id, + type = 'virtual', + group = N, + }) +end |