From c30cfcb21a5d1b2fa14944ff3633198a95ea3c2d Mon Sep 17 00:00:00 2001 From: Andrew Lewis Date: Sat, 10 Jun 2017 15:45:29 +0200 Subject: [PATCH] [Feature] Bayes expiry plugin --- lualib/lua_redis.lua | 11 +- lualib/lua_util.lua | 10 + src/plugins/lua/arc.lua | 15 +- src/plugins/lua/bayes_expiry.lua | 296 ++++++++++++++++++++++++++ src/plugins/lua/dkim_signing.lua | 15 +- src/plugins/lua/metadata_exporter.lua | 15 +- 6 files changed, 320 insertions(+), 42 deletions(-) create mode 100644 src/plugins/lua/bayes_expiry.lua diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua index 42a0aacef..0dc5872fe 100644 --- a/lualib/lua_redis.lua +++ b/lualib/lua_redis.lua @@ -5,7 +5,7 @@ local exports = {} -- This function parses redis server definition using either -- specific server string for this module or global -- redis section -local function rspamd_parse_redis_server(module_name) +local function rspamd_parse_redis_server(module_name, module_opts, no_fallback) local result = {} local default_port = 6379 @@ -71,7 +71,12 @@ local function rspamd_parse_redis_server(module_name) end -- Try local options - local opts = rspamd_config:get_all_opt(module_name) + local opts + if not module_opts then + opts = rspamd_config:get_all_opt(module_name) + else + opts = module_opts + end local ret = false if opts then @@ -82,6 +87,8 @@ local function rspamd_parse_redis_server(module_name) return result end + if no_fallback then return nil end + -- Try global options opts = rspamd_config:get_all_opt('redis') diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua index 0a824dca1..1f53d51ed 100644 --- a/lualib/lua_util.lua +++ b/lualib/lua_util.lua @@ -30,4 +30,14 @@ exports.round = function(num, numDecimalPlaces) return math.floor(num * mult) / mult end +exports.template = function(tmpl, keys) + local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } + local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) } + local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") } + + local template_grammar = lpeg.Cs((var + var_braced + 1)^0) + + return lpeg.match(template_grammar, tmpl) +end + return exports diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua index fc1c65769..a29417d74 100644 --- a/src/plugins/lua/arc.lua +++ b/src/plugins/lua/arc.lua @@ -335,19 +335,6 @@ rspamd_config:register_symbol({ rspamd_config:register_dependency(id, symbols['spf_allow_symbol']) rspamd_config:register_dependency(id, symbols['dkim_allow_symbol']) --- Signatures part -local function simple_template(tmpl, keys) - local lpeg = require "lpeg" - - local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } - local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) } - local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") } - - local template_grammar = lpeg.Cs((var + var_braced + 1)^0) - - return lpeg.match(template_grammar, tmpl) -end - local function arc_sign_seal(task, params, header) local arc_sigs = task:cache_get('arc-sigs') local arc_seals = task:cache_get('arc-seals') @@ -514,7 +501,7 @@ local function arc_signing_cb(task) end else if (p.key and p.selector) then - p.key = simple_template(p.key, {domain = p.domain, selector = p.selector}) + p.key = lua_util.template(p.key, {domain = p.domain, selector = p.selector}) local dret, hdr = dkim_sign(task, p) if dret then return arc_sign_seal(task, p, hdr) diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua new file mode 100644 index 000000000..d7df264a1 --- /dev/null +++ b/src/plugins/lua/bayes_expiry.lua @@ -0,0 +1,296 @@ +--[[ +Copyright (c) 2017, Andrew Lewis +Copyright (c) 2017, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +if confighelp then + return +end + +local N = 'bayes_expiry' +local logger = require "rspamd_logger" +local mempool = require "rspamd_mempool" +local util = require "rspamd_util" +local lutil = require "lua_util" +local lredis = require "lua_redis" + +local pool = mempool.create() +local settings = { + interval = 604800, + statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'), + variables = { + ot_bayes_ttl = 31536000, -- one year + ot_min_age = 7776000, -- 90 days + ot_min_count = 5, + }, + symbols = {}, + timeout = 60, +} + +local VAR_NAME = 'bayes_expired' +local EXPIRE_SCRIPT_TMPL = [[local result = {} +local OT_BAYES_TTL = ${ot_bayes_ttl} +local OT_MIN_AGE = ${ot_min_age} +local OT_MIN_COUNT = ${ot_min_count} +local symbol = ARGV[1] +local prefixes = redis.call('SMEMBERS', symbol .. '_keys') +for _, pfx in ipairs(prefixes) do + local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*') + local cursor, data = res[1], res[2] + while data do + local key_name = table.remove(data) + if key_name then + local h, s = redis.call('HMGET', key_name, 'H', 'S') + if (h or s) then + if not s then s = 0 else s = tonumber(s) end + if not h then h = 0 else h = tonumber(h) end + if s < OT_MIN_COUNT and h < OT_MIN_COUNT then + local ttl = redis.call('TTL', key_name) + if ttl > 0 then + local age = OT_BAYES_TTL - ttl + if age > OT_MIN_AGE then + table.insert(result, key_name) + end + end + end + end + else + if cursor == "0" then + data = nil + else + local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*') + cursor, data = res[1], res[2] + end + end + end +end +return table.concat(result, string.char(31))]] + +local function configure_bayes_expiry() + local opts = rspamd_config:get_all_opt(N) + if not type(opts) == 'table' then return false end + for k, v in pairs(opts) do + settings[k] = v + end + if not settings.symbols[1] then + logger.warn('No symbols configured, not enabling expiry') + return false + end + return true +end + +if not configure_bayes_expiry() then return end + +local function get_redis_params(ev_base, symbol) + local redis_params + local copts = rspamd_config:get_all_opt('classifier') + if not type(copts) == 'table' then + logger.errx(ev_base, "Couldn't get classifier configuration") + return + end + if type(copts.backend) == 'table' then + redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) + end + if redis_params then return redis_params end + if type(copts.statfile) == 'table' then + for _, stf in ipairs(copts.statfile) do + if stf.name == symbol then + redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) + end + end + end + if redis_params then return redis_params end + redis_params = lredis.rspamd_parse_redis_server(nil, copts, false) + redis_params.timeout = settings.timeout + return redis_params +end + +rspamd_config:add_on_load(function (_, ev_base, worker) + local processed_symbols, expire_script_sha + -- Exit unless we're the first 'normal' worker + if not (worker:get_name() == 'normal' and worker:get_index() == 0) then return end + -- Persist mempool variable to statefile on shutdown + rspamd_config:register_finish_script(function () + local stamp = pool:get_variable(VAR_NAME, 'double') + if not stamp then + logger.warnx(ev_base, 'No last bayes expiry to persist to disk') + return + end + local f, err = io.open(settings['statefile'], 'w') + if err then + logger.errx(ev_base, 'Unable to write statefile to disk: %s', err) + return + end + if f then + f:write(pool:get_variable(VAR_NAME, 'double')) + f:close() + end + end) + local expire_symbol + local function load_scripts(redis_params, cont, p1, p2) + local function load_script_cb(err, data) + if err then + logger.errx(ev_base, 'Error loading script: %s', err) + else + if type(data) == 'string' then + expire_script_sha = data + logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha) + if type(cont) == 'function' then + cont(p1, p2) + end + end + end + end + local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables) + local ret = lredis.redis_make_request_taskless(ev_base, + rspamd_config, + redis_params, + nil, + true, -- is write + load_script_cb, --callback + 'SCRIPT', -- command + {'LOAD', scripttxt} + ) + if not ret then + logger.errx(ev_base, 'Error loading script') + end + end + local function continue_expire() + for _, symbol in ipairs(settings.symbols) do + if not processed_symbols[symbol] then + local redis_params = get_redis_params(ev_base, symbol) + if not redis_params then + processed_symbols[symbol] = true + logger.errx(ev_base, "Couldn't get redis params") + else + load_scripts(redis_params, expire_symbol, redis_params, symbol) + break + end + end + end + end + expire_symbol = function(redis_params, symbol) + local function del_keys_cb(err, data) + if err then + logger.errx(ev_base, 'Redis request failed: %s', err) + end + processed_symbols[symbol] = true + continue_expire() + end + local function get_keys_cb(err, data) + if err then + logger.errx(ev_base, 'Redis request failed: %s', err) + processed_symbols[symbol] = true + continue_expire() + else + if type(data) == 'string' then + if data == "" then + data = {} + else + data = lutil.rspamd_str_split(data, string.char(31)) + end + end + if type(data) == 'table' then + if not data[1] then + logger.warnx(ev_base, 'No keys to delete: %s', symbol) + processed_symbols[symbol] = true + continue_expire() + else + local ret = lredis.redis_make_request_taskless(ev_base, + rspamd_config, + redis_params, + nil, + true, -- is write + del_keys_cb, --callback + 'DEL', -- command + data + ) + if not ret then + logger.errx(ev_base, 'Redis request failed') + processed_symbols[symbol] = true + continue_expire() + end + end + else + logger.warnx(ev_base, 'No keys to delete: %s', symbol) + processed_symbols[symbol] = true + continue_expire() + end + end + end + local ret = lredis.redis_make_request_taskless(ev_base, + rspamd_config, + redis_params, + nil, + false, -- is write + get_keys_cb, --callback + 'EVALSHA', -- command + {expire_script_sha, 0, symbol} + ) + if not ret then + logger.errx(ev_base, 'Redis request failed') + processed_symbols[symbol] = true + continue_expire() + end + end + local function begin_expire(time) + local stamp = time or util.get_time() + pool:set_variable(VAR_NAME, stamp) + processed_symbols = {} + continue_expire() + end + -- Expire tokens at regular intervals + local function schedule_regular_expiry() + rspamd_config:add_periodic(ev_base, settings['interval'], function () + begin_expire() + return true + end) + end + -- Expire tokens and reschedule expiry + local function schedule_intermediate_expiry(when) + rspamd_config:add_periodic(ev_base, when, function () + begin_expire() + schedule_regular_expiry() + return false + end) + end + -- Try read statefile on startup + local stamp + local f, err = io.open(settings['statefile'], 'r') + if err then + logger.warnx(ev_base, 'Failed to open statefile: %s', err) + end + if f then + io.input(f) + stamp = tonumber(io.read()) + pool:set_variable(VAR_NAME, stamp) + end + local time = util.get_time() + if not stamp then + logger.debugm(N, ev_base, 'No state found - expiring stats immediately') + begin_expire(time) + schedule_regular_expiry() + return + end + local delta = stamp - time + settings['interval'] + if delta <= 0 then + logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately') + begin_expire(time) + schedule_regular_expiry() + return + end + logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta) + schedule_intermediate_expiry(delta) +end) diff --git a/src/plugins/lua/dkim_signing.lua b/src/plugins/lua/dkim_signing.lua index 78505bf92..edb4db2f6 100644 --- a/src/plugins/lua/dkim_signing.lua +++ b/src/plugins/lua/dkim_signing.lua @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. ]]-- +local lutil = require "lua_util" local rspamd_logger = require "rspamd_logger" local dkim_sign_tools = require "dkim_sign_tools" @@ -46,18 +47,6 @@ local N = 'dkim_signing' local redis_params local sign_func = rspamd_plugins.dkim.sign -local function simple_template(tmpl, keys) - local lpeg = require "lpeg" - - local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } - local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) } - local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") } - - local template_grammar = lpeg.Cs((var + var_braced + 1)^0) - - return lpeg.match(template_grammar, tmpl) -end - local function dkim_signing_cb(task) local ret,p = dkim_sign_tools.prepare_dkim_signing(N, task, settings) @@ -123,7 +112,7 @@ local function dkim_signing_cb(task) end else if (p.key and p.selector) then - p.key = simple_template(p.key, {domain = p.domain, selector = p.selector}) + p.key = lutil.template(p.key, {domain = p.domain, selector = p.selector}) local sret, _ = sign_func(task, p) return sret else diff --git a/src/plugins/lua/metadata_exporter.lua b/src/plugins/lua/metadata_exporter.lua index f0b3f175d..2268c86f5 100644 --- a/src/plugins/lua/metadata_exporter.lua +++ b/src/plugins/lua/metadata_exporter.lua @@ -22,6 +22,7 @@ end -- A plugin that pushes metadata (or whole messages) to external services local redis_params +local lutil = require "lua_util" local rspamd_http = require "rspamd_http" local rspamd_tcp = require "rspamd_tcp" local rspamd_util = require "rspamd_util" @@ -136,18 +137,6 @@ local function get_general_metadata(task, flatten, no_content) return r end -local function simple_template(tmpl, keys) - local lpeg = require "lpeg" - - local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } - local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) } - local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") } - - local template_grammar = lpeg.Cs((var + var_braced + 1)^0) - - return lpeg.match(template_grammar, tmpl) -end - local formatters = { default = function(task) return task:get_content() @@ -158,7 +147,7 @@ local formatters = { meta.mail_to = rule.mail_to or settings.mail_to meta.our_message_id = rspamd_util.random_hex(12) .. '@rspamd' meta.date = rspamd_util.time_to_string(rspamd_util.get_time()) - return simple_template(rule.email_template or settings.email_template, meta) + return lutil.template(rule.email_template or settings.email_template, meta) end, json = function(task) return ucl.to_format(get_general_metadata(task), 'json-compact') -- 2.39.5