end
local N = 'bayes_expiry'
+local E = {}
local logger = require "rspamd_logger"
-local mempool = require "rspamd_mempool"
-local util = require "rspamd_util"
local lutil = require "lua_util"
local lredis = require "lua_redis"
-local pool = mempool.create()
local settings = {
- interval = 604800,
- statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'),
- variables = {
- ot_bayes_ttl = 31536000, -- one year
- ot_min_age = 7776000, -- 90 days
- ot_min_count = 5,
- },
- symbols = {},
- timeout = 60,
+ interval = 60, -- one iteration step per minute
+ count = 1000, -- check up to 1000 keys on each iteration
+ threshold = 10, -- require at least 10 occurrences to increase expire
+ epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon
+ common_ttl_divisor = 100, -- how should we discriminate common elements
+ significant_factor = 3.0 / 4.0, -- which tokens should we update
+ classifiers = {},
}
-local VAR_NAME = 'bayes_expired'
-local EXPIRE_SCRIPT_TMPL = [[local result = {}
-local OT_BAYES_TTL = ${ot_bayes_ttl}
-local OT_MIN_AGE = ${ot_min_age}
-local OT_MIN_COUNT = ${ot_min_count}
-local symbol = ARGV[1]
-local prefixes = redis.call('SMEMBERS', symbol .. '_keys')
-for _, pfx in ipairs(prefixes) do
- local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*')
- local cursor, data = res[1], res[2]
- while data do
- local key_name = table.remove(data)
- if key_name then
- local h, s = redis.call('HMGET', key_name, 'H', 'S')
- if (h or s) then
- if not s then s = 0 else s = tonumber(s) end
- if not h then h = 0 else h = tonumber(h) end
- if s < OT_MIN_COUNT and h < OT_MIN_COUNT then
- local ttl = redis.call('TTL', key_name)
- if ttl > 0 then
- local age = OT_BAYES_TTL - ttl
- if age > OT_MIN_AGE then
- table.insert(result, key_name)
- end
- end
+local template = {
+
+}
+
+local function check_redis_classifier(cls, cfg)
+ -- Skip old classifiers
+ if cls.new_schema then
+ local symbol_spam, symbol_ham
+ local expiry = (cls.expiry or cls.expire)
+ -- Load symbols from statfiles
+ local statfiles = cls.statfile
+ for _,stf in ipairs(statfiles) do
+ local symbol = stf.symbol or 'undefined'
+
+ local spam
+ if stf.spam then
+ spam = stf.spam
+ else
+ if string.match(symbol:upper(), 'SPAM') then
+ spam = true
+ else
+ spam = false
end
end
- else
- if cursor == "0" then
- data = nil
+
+ if spam then
+ symbol_spam = symbol
else
- local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*')
- cursor, data = res[1], res[2]
+ symbol_ham = symbol
+ end
+ end
+
+ if not symbol_spam or not symbol_ham or not expiry then
+ return
+ end
+ -- Now try to load redis_params if needed
+
+ local redis_params = {}
+ if not lredis.try_load_redis_servers(cls, rspamd_config, redis_params) then
+ if not lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, redis_params) then
+ if not lredis.try_load_redis_servers(cfg['redis'] or E, rspamd_config, redis_params) then
+ return false
+ end
end
end
+
+ table.insert(settings.classifiers, {
+ symbol_spam = symbol_spam,
+ symbol_ham = symbol_ham,
+ redis_params = redis_params,
+ expiry = expiry
+ })
end
end
-return table.concat(result, string.char(31))]]
-local function configure_bayes_expiry()
- local opts = rspamd_config:get_all_opt(N)
- if not type(opts) == 'table' then return false end
- for k, v in pairs(opts) do
- settings[k] = v
+-- Check classifiers and try find the appropriate ones
+local obj = rspamd_config:get_ucl()
+
+local classifier = obj.classifier
+
+if classifier then
+ if classifier[1] then
+ for _,cls in ipairs(classifier) do
+ if cls.bayes then cls = cls.bayes end
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, obj)
+ end
+ end
+ else
+ if classifier.bayes then
+
+ classifier = classifier.bayes
+ if classifier[1] then
+ for _,cls in ipairs(classifier) do
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, obj)
+ end
+ end
+ else
+ if classifier.backend and classifier.backend == 'redis' then
+ check_redis_classifier(classifier, obj)
+ end
+ end
+ end
end
- if not settings.symbols[1] then
- logger.warn('No symbols configured, not enabling expiry')
- return false
+end
+
+
+local opts = rspamd_config:get_all_opt(N)
+
+if opts then
+ for k,v in pairs(opts) do
+ settings[k] = v
end
- return true
end
-if not configure_bayes_expiry() then
- lutil.disable_module(N, 'config')
- return
+-- Fill template
+template.count = settings.count
+template.threshold = settings.threshold
+template.common_ttl_divisor = settings.common_ttl_divisor
+template.epsilon_common = settings.epsilon_common
+template.significant_factor = settings.significant_factor
+
+for k,v in pairs(template) do
+ template[k] = tostring(v)
end
-local function get_redis_params(ev_base, symbol)
- local redis_params
- local copts = rspamd_config:get_all_opt('classifier')
- if not type(copts) == 'table' then
- logger.errx(ev_base, "Couldn't get classifier configuration")
- return
- end
- if type(copts.backend) == 'table' then
- redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+-- Arguments:
+-- [1] = symbol pattern
+-- [2] = expire value
+-- [3] = cursor
+-- returns new cursor
+local expiry_script = [[
+ local ret = redis.call('SCAN', KEYS[3], 'MATCH', KEYS[1], 'COUNT', '${count}')
+ local next = ret[1]
+ local keys = ret[2]
+ local nelts = 0
+ local extended = 0
+ local discriminated = 0
+
+ for _,key in ipairs(keys) do
+ local values = redis.call('HMGET', key, 'H', 'S')
+ local ham = tonumber(values[1]) or 0
+ local spam = tonumber(values[2]) or 0
+
+ if ham > ${threshold} or spam > ${threshold} then
+ local total = ham + spam
+
+ if total > 0 then
+ if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
+ redis.replicate_commands()
+ redis.call('EXPIRE', key, KEYS[2])
+ extended = extended + 1
+ elseif math.abs(ham - spam) <= total * ${epsilon_common} then
+ local ttl = redis.call('TTL', key)
+ redis.replicate_commands()
+ redis.call('EXPIRE', key, tonumber(ttl) / ${common_ttl_divisor})
+ discriminated = discriminated + 1
+ end
+ end
+ end
+ nelts = nelts + 1
end
- if redis_params then return redis_params end
- if type(copts.statfile) == 'table' then
- for _, stf in ipairs(copts.statfile) do
- if stf.name == symbol then
- redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true)
+
+ return {next, nelts, extended, discriminated}
+]]
+
+local cur = 0
+
+local function expire_step(cls, ev_base, worker)
+
+ local function redis_step_cb(err, data)
+ if err then
+ logger.errx(rspamd_config, 'cannot perform expiry step: %s', err)
+ elseif type(data) == 'table' then
+ local next,nelts,extended,discriminated = tonumber(data[1]), tonumber(data[2]),
+ tonumber(data[3]),tonumber(data[4])
+
+ if next ~= 0 then
+ logger.infox(rspamd_config, 'executed expiry step for bayes: %s items checked, %s extended, %s discriminated',
+ nelts, extended, discriminated)
+ else
+ logger.infox(rspamd_config, 'executed final expiry step for bayes: %s items checked, %s extended, %s discriminated',
+ nelts, extended, discriminated)
end
+
+ cur = next
end
end
- if redis_params then return redis_params end
- redis_params = lredis.rspamd_parse_redis_server(nil, copts, false)
- redis_params.timeout = settings.timeout
- return redis_params
+ lredis.exec_redis_script(cls.script,
+ {ev_base = ev_base, is_write = true},
+ redis_step_cb,
+ {'RS*_*', cls.expiry, cur}
+ )
end
rspamd_config:add_on_load(function (_, ev_base, worker)
- local processed_symbols, expire_script_sha
-- Exit unless we're the first 'controller' worker
if not (worker:get_name() == 'controller' and worker:get_index() == 0) then return end
- -- Persist mempool variable to statefile on shutdown
- rspamd_config:register_finish_script(function ()
- local stamp = pool:get_variable(VAR_NAME, 'double')
- if not stamp then
- logger.warnx(ev_base, 'No last bayes expiry to persist to disk')
- return
- end
- local f, err = io.open(settings['statefile'], 'w')
- if err then
- logger.errx(ev_base, 'Unable to write statefile to disk: %s', err)
- return
- end
- if f then
- f:write(pool:get_variable(VAR_NAME, 'double'))
- f:close()
- end
- end)
- local expire_symbol
- local function load_scripts(redis_params, cont, p1, p2)
- local function load_script_cb(err, data)
- if err then
- logger.errx(ev_base, 'Error loading script: %s', err)
- else
- if type(data) == 'string' then
- expire_script_sha = data
- logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha)
- if type(cont) == 'function' then
- cont(p1, p2)
- end
- end
+
+ local unique_redis_params = {}
+ -- Push redis script to all unique redis servers
+ for _,cls in ipairs(settings.classifiers) do
+ local seen = false
+ for _,rp in ipairs(unique_redis_params) do
+ if lutil.table_cmp(rp, cls.redis_params) then
+ seen = true
end
end
- local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables)
- local ret = lredis.redis_make_request_taskless(ev_base,
- rspamd_config,
- redis_params,
- nil,
- true, -- is write
- load_script_cb, --callback
- 'SCRIPT', -- command
- {'LOAD', scripttxt}
- )
- if not ret then
- logger.errx(ev_base, 'Error loading script')
- end
- end
- local function continue_expire()
- for _, symbol in ipairs(settings.symbols) do
- if not processed_symbols[symbol] then
- local redis_params = get_redis_params(ev_base, symbol)
- if not redis_params then
- processed_symbols[symbol] = true
- logger.errx(ev_base, "Couldn't get redis params")
- else
- load_scripts(redis_params, expire_symbol, redis_params, symbol)
- break
- end
- end
+
+ if not seen then
+ table.insert(unique_redis_params, cls.redis_params)
end
end
- expire_symbol = function(redis_params, symbol)
- local function del_keys_cb(err, data)
- if err then
- logger.errx(ev_base, 'Redis request failed: %s', err)
- end
- processed_symbols[symbol] = true
- continue_expire()
- end
- local function get_keys_cb(err, data)
- if err then
- logger.errx(ev_base, 'Redis request failed: %s', err)
- processed_symbols[symbol] = true
- continue_expire()
- else
- if type(data) == 'string' then
- if data == "" then
- data = {}
- else
- data = lutil.rspamd_str_split(data, string.char(31))
- end
- end
- if type(data) == 'table' then
- if not data[1] then
- logger.warnx(ev_base, 'No keys to delete: %s', symbol)
- processed_symbols[symbol] = true
- continue_expire()
- else
- local ret = lredis.redis_make_request_taskless(ev_base,
- rspamd_config,
- redis_params,
- nil,
- true, -- is write
- del_keys_cb, --callback
- 'DEL', -- command
- data
- )
- if not ret then
- logger.errx(ev_base, 'Redis request failed')
- processed_symbols[symbol] = true
- continue_expire()
- end
- end
- else
- logger.warnx(ev_base, 'No keys to delete: %s', symbol)
- processed_symbols[symbol] = true
- continue_expire()
- end
+
+ for _,rp in ipairs(unique_redis_params) do
+ local script_id = lredis.add_redis_script(lutil.template(expiry_script,
+ template), rp)
+
+ for _,cls in ipairs(settings.classifiers) do
+ if lutil.table_cmp(rp, cls.redis_params) then
+ cls.script = script_id
end
end
- local ret = lredis.redis_make_request_taskless(ev_base,
- rspamd_config,
- redis_params,
- nil,
- false, -- is write
- get_keys_cb, --callback
- 'EVALSHA', -- command
- {expire_script_sha, 0, symbol}
- )
- if not ret then
- logger.errx(ev_base, 'Redis request failed')
- processed_symbols[symbol] = true
- continue_expire()
- end
- end
- local function begin_expire(time)
- local stamp = time or util.get_time()
- pool:set_variable(VAR_NAME, stamp)
- processed_symbols = {}
- continue_expire()
end
+
-- Expire tokens at regular intervals
- local function schedule_regular_expiry()
+ for _,cls in ipairs(settings.classifiers) do
rspamd_config:add_periodic(ev_base, settings['interval'], function ()
- begin_expire()
+ expire_step(cls, ev_base, worker)
return true
end)
end
- -- Expire tokens and reschedule expiry
- local function schedule_intermediate_expiry(when)
- rspamd_config:add_periodic(ev_base, when, function ()
- begin_expire()
- schedule_regular_expiry()
- return false
- end)
- end
- -- Try read statefile on startup
- local stamp
- local f, err = io.open(settings['statefile'], 'r')
- if err then
- logger.warnx(ev_base, 'Failed to open statefile: %s', err)
- end
- if f then
- io.input(f)
- stamp = tonumber(io.read())
- pool:set_variable(VAR_NAME, stamp)
- end
- local time = util.get_time()
- if not stamp then
- logger.debugm(N, ev_base, 'No state found - expiring stats immediately')
- begin_expire(time)
- schedule_regular_expiry()
- return
- end
- local delta = stamp - time + settings['interval']
- if delta <= 0 then
- logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately')
- begin_expire(time)
- schedule_regular_expiry()
- return
- end
- logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta)
- schedule_intermediate_expiry(delta)
end)