|
|
@@ -0,0 +1,296 @@ |
|
|
|
--[[ |
|
|
|
Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org> |
|
|
|
Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru> |
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
you may not use this file except in compliance with the License. |
|
|
|
You may obtain a copy of the License at |
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
See the License for the specific language governing permissions and |
|
|
|
limitations under the License. |
|
|
|
]] -- |
|
|
|
|
|
|
|
if confighelp then |
|
|
|
return |
|
|
|
end |
|
|
|
|
|
|
|
local N = 'bayes_expiry' |
|
|
|
local logger = require "rspamd_logger" |
|
|
|
local mempool = require "rspamd_mempool" |
|
|
|
local util = require "rspamd_util" |
|
|
|
local lutil = require "lua_util" |
|
|
|
local lredis = require "lua_redis" |
|
|
|
|
|
|
|
local pool = mempool.create() |
|
|
|
local settings = { |
|
|
|
interval = 604800, |
|
|
|
statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'), |
|
|
|
variables = { |
|
|
|
ot_bayes_ttl = 31536000, -- one year |
|
|
|
ot_min_age = 7776000, -- 90 days |
|
|
|
ot_min_count = 5, |
|
|
|
}, |
|
|
|
symbols = {}, |
|
|
|
timeout = 60, |
|
|
|
} |
|
|
|
|
|
|
|
local VAR_NAME = 'bayes_expired' |
|
|
|
local EXPIRE_SCRIPT_TMPL = [[local result = {} |
|
|
|
local OT_BAYES_TTL = ${ot_bayes_ttl} |
|
|
|
local OT_MIN_AGE = ${ot_min_age} |
|
|
|
local OT_MIN_COUNT = ${ot_min_count} |
|
|
|
local symbol = ARGV[1] |
|
|
|
local prefixes = redis.call('SMEMBERS', symbol .. '_keys') |
|
|
|
for _, pfx in ipairs(prefixes) do |
|
|
|
local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*') |
|
|
|
local cursor, data = res[1], res[2] |
|
|
|
while data do |
|
|
|
local key_name = table.remove(data) |
|
|
|
if key_name then |
|
|
|
local h, s = redis.call('HMGET', key_name, 'H', 'S') |
|
|
|
if (h or s) then |
|
|
|
if not s then s = 0 else s = tonumber(s) end |
|
|
|
if not h then h = 0 else h = tonumber(h) end |
|
|
|
if s < OT_MIN_COUNT and h < OT_MIN_COUNT then |
|
|
|
local ttl = redis.call('TTL', key_name) |
|
|
|
if ttl > 0 then |
|
|
|
local age = OT_BAYES_TTL - ttl |
|
|
|
if age > OT_MIN_AGE then |
|
|
|
table.insert(result, key_name) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
else |
|
|
|
if cursor == "0" then |
|
|
|
data = nil |
|
|
|
else |
|
|
|
local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*') |
|
|
|
cursor, data = res[1], res[2] |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
return table.concat(result, string.char(31))]] |
|
|
|
|
|
|
|
local function configure_bayes_expiry() |
|
|
|
local opts = rspamd_config:get_all_opt(N) |
|
|
|
if not type(opts) == 'table' then return false end |
|
|
|
for k, v in pairs(opts) do |
|
|
|
settings[k] = v |
|
|
|
end |
|
|
|
if not settings.symbols[1] then |
|
|
|
logger.warn('No symbols configured, not enabling expiry') |
|
|
|
return false |
|
|
|
end |
|
|
|
return true |
|
|
|
end |
|
|
|
|
|
|
|
if not configure_bayes_expiry() then return end |
|
|
|
|
|
|
|
local function get_redis_params(ev_base, symbol) |
|
|
|
local redis_params |
|
|
|
local copts = rspamd_config:get_all_opt('classifier') |
|
|
|
if not type(copts) == 'table' then |
|
|
|
logger.errx(ev_base, "Couldn't get classifier configuration") |
|
|
|
return |
|
|
|
end |
|
|
|
if type(copts.backend) == 'table' then |
|
|
|
redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) |
|
|
|
end |
|
|
|
if redis_params then return redis_params end |
|
|
|
if type(copts.statfile) == 'table' then |
|
|
|
for _, stf in ipairs(copts.statfile) do |
|
|
|
if stf.name == symbol then |
|
|
|
redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
if redis_params then return redis_params end |
|
|
|
redis_params = lredis.rspamd_parse_redis_server(nil, copts, false) |
|
|
|
redis_params.timeout = settings.timeout |
|
|
|
return redis_params |
|
|
|
end |
|
|
|
|
|
|
|
rspamd_config:add_on_load(function (_, ev_base, worker) |
|
|
|
local processed_symbols, expire_script_sha |
|
|
|
-- Exit unless we're the first 'normal' worker |
|
|
|
if not (worker:get_name() == 'normal' and worker:get_index() == 0) then return end |
|
|
|
-- Persist mempool variable to statefile on shutdown |
|
|
|
rspamd_config:register_finish_script(function () |
|
|
|
local stamp = pool:get_variable(VAR_NAME, 'double') |
|
|
|
if not stamp then |
|
|
|
logger.warnx(ev_base, 'No last bayes expiry to persist to disk') |
|
|
|
return |
|
|
|
end |
|
|
|
local f, err = io.open(settings['statefile'], 'w') |
|
|
|
if err then |
|
|
|
logger.errx(ev_base, 'Unable to write statefile to disk: %s', err) |
|
|
|
return |
|
|
|
end |
|
|
|
if f then |
|
|
|
f:write(pool:get_variable(VAR_NAME, 'double')) |
|
|
|
f:close() |
|
|
|
end |
|
|
|
end) |
|
|
|
local expire_symbol |
|
|
|
local function load_scripts(redis_params, cont, p1, p2) |
|
|
|
local function load_script_cb(err, data) |
|
|
|
if err then |
|
|
|
logger.errx(ev_base, 'Error loading script: %s', err) |
|
|
|
else |
|
|
|
if type(data) == 'string' then |
|
|
|
expire_script_sha = data |
|
|
|
logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha) |
|
|
|
if type(cont) == 'function' then |
|
|
|
cont(p1, p2) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables) |
|
|
|
local ret = lredis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
load_script_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', scripttxt} |
|
|
|
) |
|
|
|
if not ret then |
|
|
|
logger.errx(ev_base, 'Error loading script') |
|
|
|
end |
|
|
|
end |
|
|
|
local function continue_expire() |
|
|
|
for _, symbol in ipairs(settings.symbols) do |
|
|
|
if not processed_symbols[symbol] then |
|
|
|
local redis_params = get_redis_params(ev_base, symbol) |
|
|
|
if not redis_params then |
|
|
|
processed_symbols[symbol] = true |
|
|
|
logger.errx(ev_base, "Couldn't get redis params") |
|
|
|
else |
|
|
|
load_scripts(redis_params, expire_symbol, redis_params, symbol) |
|
|
|
break |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
expire_symbol = function(redis_params, symbol) |
|
|
|
local function del_keys_cb(err, data) |
|
|
|
if err then |
|
|
|
logger.errx(ev_base, 'Redis request failed: %s', err) |
|
|
|
end |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
end |
|
|
|
local function get_keys_cb(err, data) |
|
|
|
if err then |
|
|
|
logger.errx(ev_base, 'Redis request failed: %s', err) |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
else |
|
|
|
if type(data) == 'string' then |
|
|
|
if data == "" then |
|
|
|
data = {} |
|
|
|
else |
|
|
|
data = lutil.rspamd_str_split(data, string.char(31)) |
|
|
|
end |
|
|
|
end |
|
|
|
if type(data) == 'table' then |
|
|
|
if not data[1] then |
|
|
|
logger.warnx(ev_base, 'No keys to delete: %s', symbol) |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
else |
|
|
|
local ret = lredis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
del_keys_cb, --callback |
|
|
|
'DEL', -- command |
|
|
|
data |
|
|
|
) |
|
|
|
if not ret then |
|
|
|
logger.errx(ev_base, 'Redis request failed') |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
end |
|
|
|
end |
|
|
|
else |
|
|
|
logger.warnx(ev_base, 'No keys to delete: %s', symbol) |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
local ret = lredis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
false, -- is write |
|
|
|
get_keys_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{expire_script_sha, 0, symbol} |
|
|
|
) |
|
|
|
if not ret then |
|
|
|
logger.errx(ev_base, 'Redis request failed') |
|
|
|
processed_symbols[symbol] = true |
|
|
|
continue_expire() |
|
|
|
end |
|
|
|
end |
|
|
|
local function begin_expire(time) |
|
|
|
local stamp = time or util.get_time() |
|
|
|
pool:set_variable(VAR_NAME, stamp) |
|
|
|
processed_symbols = {} |
|
|
|
continue_expire() |
|
|
|
end |
|
|
|
-- Expire tokens at regular intervals |
|
|
|
local function schedule_regular_expiry() |
|
|
|
rspamd_config:add_periodic(ev_base, settings['interval'], function () |
|
|
|
begin_expire() |
|
|
|
return true |
|
|
|
end) |
|
|
|
end |
|
|
|
-- Expire tokens and reschedule expiry |
|
|
|
local function schedule_intermediate_expiry(when) |
|
|
|
rspamd_config:add_periodic(ev_base, when, function () |
|
|
|
begin_expire() |
|
|
|
schedule_regular_expiry() |
|
|
|
return false |
|
|
|
end) |
|
|
|
end |
|
|
|
-- Try read statefile on startup |
|
|
|
local stamp |
|
|
|
local f, err = io.open(settings['statefile'], 'r') |
|
|
|
if err then |
|
|
|
logger.warnx(ev_base, 'Failed to open statefile: %s', err) |
|
|
|
end |
|
|
|
if f then |
|
|
|
io.input(f) |
|
|
|
stamp = tonumber(io.read()) |
|
|
|
pool:set_variable(VAR_NAME, stamp) |
|
|
|
end |
|
|
|
local time = util.get_time() |
|
|
|
if not stamp then |
|
|
|
logger.debugm(N, ev_base, 'No state found - expiring stats immediately') |
|
|
|
begin_expire(time) |
|
|
|
schedule_regular_expiry() |
|
|
|
return |
|
|
|
end |
|
|
|
local delta = stamp - time + settings['interval'] |
|
|
|
if delta <= 0 then |
|
|
|
logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately') |
|
|
|
begin_expire(time) |
|
|
|
schedule_regular_expiry() |
|
|
|
return |
|
|
|
end |
|
|
|
logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta) |
|
|
|
schedule_intermediate_expiry(delta) |
|
|
|
end) |