aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorAndrew Lewis <nerf@judo.za.org>2017-06-10 15:45:29 +0200
committerAndrew Lewis <nerf@judo.za.org>2017-06-10 15:45:29 +0200
commitc30cfcb21a5d1b2fa14944ff3633198a95ea3c2d (patch)
treea947837bbd92ed3bdf4a22b54dd6b492e9500246 /src
parentcd0cc6187ed4b153e3506bbf79d28aa760a85f4a (diff)
downloadrspamd-c30cfcb21a5d1b2fa14944ff3633198a95ea3c2d.tar.gz
rspamd-c30cfcb21a5d1b2fa14944ff3633198a95ea3c2d.zip
[Feature] Bayes expiry plugin
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/arc.lua15
-rw-r--r--src/plugins/lua/bayes_expiry.lua296
-rw-r--r--src/plugins/lua/dkim_signing.lua15
-rw-r--r--src/plugins/lua/metadata_exporter.lua15
4 files changed, 301 insertions, 40 deletions
diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua
index fc1c65769..a29417d74 100644
--- a/src/plugins/lua/arc.lua
+++ b/src/plugins/lua/arc.lua
@@ -335,19 +335,6 @@ rspamd_config:register_symbol({
rspamd_config:register_dependency(id, symbols['spf_allow_symbol'])
rspamd_config:register_dependency(id, symbols['dkim_allow_symbol'])
--- Signatures part
-local function simple_template(tmpl, keys)
- local lpeg = require "lpeg"
-
- local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
- local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
- local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
- local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
- return lpeg.match(template_grammar, tmpl)
-end
-
local function arc_sign_seal(task, params, header)
local arc_sigs = task:cache_get('arc-sigs')
local arc_seals = task:cache_get('arc-seals')
@@ -514,7 +501,7 @@ local function arc_signing_cb(task)
end
else
if (p.key and p.selector) then
- p.key = simple_template(p.key, {domain = p.domain, selector = p.selector})
+ p.key = lua_util.template(p.key, {domain = p.domain, selector = p.selector})
local dret, hdr = dkim_sign(task, p)
if dret then
return arc_sign_seal(task, p, hdr)
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua
new file mode 100644
index 000000000..d7df264a1
--- /dev/null
+++ b/src/plugins/lua/bayes_expiry.lua
@@ -0,0 +1,296 @@
+--[[
+Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org>
+Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]] --
+
+if confighelp then
+ return
+end
+
+local N = 'bayes_expiry'
+local logger = require "rspamd_logger"
+local mempool = require "rspamd_mempool"
+local util = require "rspamd_util"
+local lutil = require "lua_util"
+local lredis = require "lua_redis"
+
+local pool = mempool.create()
+local settings = {
+ interval = 604800,
+ statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'),
+ variables = {
+ ot_bayes_ttl = 31536000, -- one year
+ ot_min_age = 7776000, -- 90 days
+ ot_min_count = 5,
+ },
+ symbols = {},
+ timeout = 60,
+}
+
+local VAR_NAME = 'bayes_expired'
+local EXPIRE_SCRIPT_TMPL = [[local result = {}
+local OT_BAYES_TTL = ${ot_bayes_ttl}
+local OT_MIN_AGE = ${ot_min_age}
+local OT_MIN_COUNT = ${ot_min_count}
+local symbol = ARGV[1]
+local prefixes = redis.call('SMEMBERS', symbol .. '_keys')
+for _, pfx in ipairs(prefixes) do
+ local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*')
+ local cursor, data = res[1], res[2]
+ while data do
+ local key_name = table.remove(data)
+ if key_name then
+ local h, s = redis.call('HMGET', key_name, 'H', 'S')
+ if (h or s) then
+ if not s then s = 0 else s = tonumber(s) end
+ if not h then h = 0 else h = tonumber(h) end
+ if s < OT_MIN_COUNT and h < OT_MIN_COUNT then
+ local ttl = redis.call('TTL', key_name)
+ if ttl > 0 then
+ local age = OT_BAYES_TTL - ttl
+ if age > OT_MIN_AGE then
+ table.insert(result, key_name)
+ end
+ end
+ end
+ end
+ else
+ if cursor == "0" then
+ data = nil
+ else
+ local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*')
+ cursor, data = res[1], res[2]
+ end
+ end
+ end
+end
+return table.concat(result, string.char(31))]]
+
+local function configure_bayes_expiry()
+ local opts = rspamd_config:get_all_opt(N)
+ if not type(opts) == 'table' then return false end
+ for k, v in pairs(opts) do
+ settings[k] = v
+ end
+ if not settings.symbols[1] then
+ logger.warn('No symbols configured, not enabling expiry')
+ return false
+ end
+ return true
+end
+
+if not configure_bayes_expiry() then return end
+
+local function get_redis_params(ev_base, symbol)
+ local redis_params
+ local copts = rspamd_config:get_all_opt('classifier')
+ if not type(copts) == 'table' then
+ logger.errx(ev_base, "Couldn't get classifier configuration")
+ return
+ end
+ if type(copts.backend) == 'table' then
+ redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+ end
+ if redis_params then return redis_params end
+ if type(copts.statfile) == 'table' then
+ for _, stf in ipairs(copts.statfile) do
+ if stf.name == symbol then
+ redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+ end
+ end
+ end
+ if redis_params then return redis_params end
+ redis_params = lredis.rspamd_parse_redis_server(nil, copts, false)
+ redis_params.timeout = settings.timeout
+ return redis_params
+end
+
+rspamd_config:add_on_load(function (_, ev_base, worker)
+ local processed_symbols, expire_script_sha
+ -- Exit unless we're the first 'normal' worker
+ if not (worker:get_name() == 'normal' and worker:get_index() == 0) then return end
+ -- Persist mempool variable to statefile on shutdown
+ rspamd_config:register_finish_script(function ()
+ local stamp = pool:get_variable(VAR_NAME, 'double')
+ if not stamp then
+ logger.warnx(ev_base, 'No last bayes expiry to persist to disk')
+ return
+ end
+ local f, err = io.open(settings['statefile'], 'w')
+ if err then
+ logger.errx(ev_base, 'Unable to write statefile to disk: %s', err)
+ return
+ end
+ if f then
+ f:write(pool:get_variable(VAR_NAME, 'double'))
+ f:close()
+ end
+ end)
+ local expire_symbol
+ local function load_scripts(redis_params, cont, p1, p2)
+ local function load_script_cb(err, data)
+ if err then
+ logger.errx(ev_base, 'Error loading script: %s', err)
+ else
+ if type(data) == 'string' then
+ expire_script_sha = data
+ logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha)
+ if type(cont) == 'function' then
+ cont(p1, p2)
+ end
+ end
+ end
+ end
+ local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables)
+ local ret = lredis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ redis_params,
+ nil,
+ true, -- is write
+ load_script_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', scripttxt}
+ )
+ if not ret then
+ logger.errx(ev_base, 'Error loading script')
+ end
+ end
+ local function continue_expire()
+ for _, symbol in ipairs(settings.symbols) do
+ if not processed_symbols[symbol] then
+ local redis_params = get_redis_params(ev_base, symbol)
+ if not redis_params then
+ processed_symbols[symbol] = true
+ logger.errx(ev_base, "Couldn't get redis params")
+ else
+ load_scripts(redis_params, expire_symbol, redis_params, symbol)
+ break
+ end
+ end
+ end
+ end
+ expire_symbol = function(redis_params, symbol)
+ local function del_keys_cb(err, data)
+ if err then
+ logger.errx(ev_base, 'Redis request failed: %s', err)
+ end
+ processed_symbols[symbol] = true
+ continue_expire()
+ end
+ local function get_keys_cb(err, data)
+ if err then
+ logger.errx(ev_base, 'Redis request failed: %s', err)
+ processed_symbols[symbol] = true
+ continue_expire()
+ else
+ if type(data) == 'string' then
+ if data == "" then
+ data = {}
+ else
+ data = lutil.rspamd_str_split(data, string.char(31))
+ end
+ end
+ if type(data) == 'table' then
+ if not data[1] then
+ logger.warnx(ev_base, 'No keys to delete: %s', symbol)
+ processed_symbols[symbol] = true
+ continue_expire()
+ else
+ local ret = lredis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ redis_params,
+ nil,
+ true, -- is write
+ del_keys_cb, --callback
+ 'DEL', -- command
+ data
+ )
+ if not ret then
+ logger.errx(ev_base, 'Redis request failed')
+ processed_symbols[symbol] = true
+ continue_expire()
+ end
+ end
+ else
+ logger.warnx(ev_base, 'No keys to delete: %s', symbol)
+ processed_symbols[symbol] = true
+ continue_expire()
+ end
+ end
+ end
+ local ret = lredis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ redis_params,
+ nil,
+ false, -- is write
+ get_keys_cb, --callback
+ 'EVALSHA', -- command
+ {expire_script_sha, 0, symbol}
+ )
+ if not ret then
+ logger.errx(ev_base, 'Redis request failed')
+ processed_symbols[symbol] = true
+ continue_expire()
+ end
+ end
+ local function begin_expire(time)
+ local stamp = time or util.get_time()
+ pool:set_variable(VAR_NAME, stamp)
+ processed_symbols = {}
+ continue_expire()
+ end
+ -- Expire tokens at regular intervals
+ local function schedule_regular_expiry()
+ rspamd_config:add_periodic(ev_base, settings['interval'], function ()
+ begin_expire()
+ return true
+ end)
+ end
+ -- Expire tokens and reschedule expiry
+ local function schedule_intermediate_expiry(when)
+ rspamd_config:add_periodic(ev_base, when, function ()
+ begin_expire()
+ schedule_regular_expiry()
+ return false
+ end)
+ end
+ -- Try read statefile on startup
+ local stamp
+ local f, err = io.open(settings['statefile'], 'r')
+ if err then
+ logger.warnx(ev_base, 'Failed to open statefile: %s', err)
+ end
+ if f then
+ io.input(f)
+ stamp = tonumber(io.read())
+ pool:set_variable(VAR_NAME, stamp)
+ end
+ local time = util.get_time()
+ if not stamp then
+ logger.debugm(N, ev_base, 'No state found - expiring stats immediately')
+ begin_expire(time)
+ schedule_regular_expiry()
+ return
+ end
+ local delta = stamp - time + settings['interval']
+ if delta <= 0 then
+ logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately')
+ begin_expire(time)
+ schedule_regular_expiry()
+ return
+ end
+ logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta)
+ schedule_intermediate_expiry(delta)
+end)
diff --git a/src/plugins/lua/dkim_signing.lua b/src/plugins/lua/dkim_signing.lua
index 78505bf92..edb4db2f6 100644
--- a/src/plugins/lua/dkim_signing.lua
+++ b/src/plugins/lua/dkim_signing.lua
@@ -15,6 +15,7 @@ See the License for the specific language governing permissions and
limitations under the License.
]]--
+local lutil = require "lua_util"
local rspamd_logger = require "rspamd_logger"
local dkim_sign_tools = require "dkim_sign_tools"
@@ -46,18 +47,6 @@ local N = 'dkim_signing'
local redis_params
local sign_func = rspamd_plugins.dkim.sign
-local function simple_template(tmpl, keys)
- local lpeg = require "lpeg"
-
- local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
- local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
- local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
- local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
- return lpeg.match(template_grammar, tmpl)
-end
-
local function dkim_signing_cb(task)
local ret,p = dkim_sign_tools.prepare_dkim_signing(N, task, settings)
@@ -123,7 +112,7 @@ local function dkim_signing_cb(task)
end
else
if (p.key and p.selector) then
- p.key = simple_template(p.key, {domain = p.domain, selector = p.selector})
+ p.key = lutil.template(p.key, {domain = p.domain, selector = p.selector})
local sret, _ = sign_func(task, p)
return sret
else
diff --git a/src/plugins/lua/metadata_exporter.lua b/src/plugins/lua/metadata_exporter.lua
index f0b3f175d..2268c86f5 100644
--- a/src/plugins/lua/metadata_exporter.lua
+++ b/src/plugins/lua/metadata_exporter.lua
@@ -22,6 +22,7 @@ end
-- A plugin that pushes metadata (or whole messages) to external services
local redis_params
+local lutil = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_tcp = require "rspamd_tcp"
local rspamd_util = require "rspamd_util"
@@ -136,18 +137,6 @@ local function get_general_metadata(task, flatten, no_content)
return r
end
-local function simple_template(tmpl, keys)
- local lpeg = require "lpeg"
-
- local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
- local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
- local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
-
- local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
-
- return lpeg.match(template_grammar, tmpl)
-end
-
local formatters = {
default = function(task)
return task:get_content()
@@ -158,7 +147,7 @@ local formatters = {
meta.mail_to = rule.mail_to or settings.mail_to
meta.our_message_id = rspamd_util.random_hex(12) .. '@rspamd'
meta.date = rspamd_util.time_to_string(rspamd_util.get_time())
- return simple_template(rule.email_template or settings.email_template, meta)
+ return lutil.template(rule.email_template or settings.email_template, meta)
end,
json = function(task)
return ucl.to_format(get_general_metadata(task), 'json-compact')