aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lualib/lua_maps.lua83
1 files changed, 77 insertions, 6 deletions
diff --git a/lualib/lua_maps.lua b/lualib/lua_maps.lua
index 6dad3b6ad..3a2b29d30 100644
--- a/lualib/lua_maps.lua
+++ b/lualib/lua_maps.lua
@@ -88,16 +88,64 @@ end
local external_map_schema = ts.shape {
external = ts.equivalent(true), -- must be true
- backend = ts.string, -- where to get data, required
+ backend = ts.string:is_optional(), -- where to get data, required for HTTP
+ cdb = ts.string:is_optional(), -- path to CDB file, required for CDB
method = ts.one_of { "body", "header", "query" }, -- how to pass input
encode = ts.one_of { "json", "messagepack" }:is_optional(), -- how to encode input (if relevant)
timeout = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(),
}
+-- Storage for CDB instances
+local cdb_maps = {}
+local cdb_finisher_set = false
+
local rspamd_http = require "rspamd_http"
local ucl = require "ucl"
+-- Function to handle CDB maps
+local function handle_cdb_map(map_config, key, callback, task)
+ local rspamd_cdb = require "rspamd_cdb"
+ local hash_key = map_config.cdb
+
+ -- Check if we need to open the CDB file
+ if not cdb_maps[hash_key] then
+ local cdb_file = map_config.cdb
+ -- Provide ev_base to monitor changes
+ local cdb_handle = rspamd_cdb.open(cdb_file, task:get_ev_base())
+
+ if not cdb_handle then
+ local err_msg = string.format("Failed to open CDB file: %s", cdb_file)
+ rspamd_logger.errx(task, err_msg)
+ if callback then
+ callback(false, err_msg, 500, task)
+ end
+ return nil
+ else
+ cdb_maps[hash_key] = cdb_handle
+ end
+ end
+
+ -- Look up the key in CDB
+ local result = cdb_maps[hash_key]:find(key)
+
+ if callback then
+ if result then
+ callback(true, result, 200, task)
+ else
+ callback(false, 'not found', 404, task)
+ end
+ return nil
+ end
+
+ return result
+end
+
local function query_external_map(map_config, upstreams, key, callback, task)
+ -- Check if this is a CDB map
+ if map_config.cdb then
+ return handle_cdb_map(map_config, key, callback, task)
+ end
+ -- Fallback to HTTP
local http_method = (map_config.method == 'body' or map_config.method == 'form') and 'POST' or 'GET'
local upstream = upstreams:get_upstream_round_robin()
local http_headers = {
@@ -138,7 +186,8 @@ local function query_external_map(map_config, upstreams, key, callback, task)
local params_table = {}
for k, v in pairs(key) do
if type(v) == 'string' then
- table.insert(params_table, string.format('%s=%s', lua_util.url_encode_string(k), lua_util.url_encode_string(v)))
+ table.insert(params_table,
+ string.format('%s=%s', lua_util.url_encode_string(k), lua_util.url_encode_string(v)))
end
end
url = string.format('%s?%s', url, table.concat(params_table, '&'))
@@ -448,17 +497,39 @@ local function rspamd_map_add_from_ucl(opt, mtype, description, callback)
local parse_res, parse_err = external_map_schema(opt)
if parse_res then
- ret.__upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), opt.backend)
- if ret.__upstreams then
+ if opt.cdb then
ret.__data = opt
ret.__external = true
setmetatable(ret, ret_mt)
maybe_register_selector()
+ if not cdb_finisher_set then
+ -- Register a finalize script to close all CDB handles when Rspamd stops
+ rspamd_config:register_finish_script(function()
+ for path, _ in pairs(cdb_maps) do
+ rspamd_logger.infox(rspamd_config, 'closing CDB map: %s', path)
+ cdb_maps[path] = nil
+ end
+ end)
+ cdb_finisher_set = true
+ end
+
return ret
+ elseif opt.backend then
+ ret.__upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), opt.backend)
+ if ret.__upstreams then
+ ret.__data = opt
+ ret.__external = true
+ setmetatable(ret, ret_mt)
+ maybe_register_selector()
+
+ return ret
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot parse external map upstreams: %s',
+ opt.backend)
+ end
else
- rspamd_logger.errx(rspamd_config, 'cannot parse external map upstreams: %s',
- opt.backend)
+ rspamd_logger.errx(rspamd_config, 'external map requires either "cdb" or "backend" parameter')
end
else
rspamd_logger.errx(rspamd_config, 'cannot parse external map: %s',