diff options
Diffstat (limited to 'src/plugins/lua/bayes_expiry.lua')
-rw-r--r-- | src/plugins/lua/bayes_expiry.lua | 420 |
1 files changed, 186 insertions, 234 deletions
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua index d922f3f55..af955465d 100644 --- a/src/plugins/lua/bayes_expiry.lua +++ b/src/plugins/lua/bayes_expiry.lua @@ -20,280 +20,232 @@ if confighelp then end local N = 'bayes_expiry' +local E = {} local logger = require "rspamd_logger" -local mempool = require "rspamd_mempool" -local util = require "rspamd_util" local lutil = require "lua_util" local lredis = require "lua_redis" -local pool = mempool.create() local settings = { - interval = 604800, - statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'), - variables = { - ot_bayes_ttl = 31536000, -- one year - ot_min_age = 7776000, -- 90 days - ot_min_count = 5, - }, - symbols = {}, - timeout = 60, + interval = 60, -- one iteration step per minute + count = 1000, -- check up to 1000 keys on each iteration + threshold = 10, -- require at least 10 occurrences to increase expire + epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon + common_ttl_divisor = 100, -- how should we discriminate common elements + significant_factor = 3.0 / 4.0, -- which tokens should we update + classifiers = {}, } -local VAR_NAME = 'bayes_expired' -local EXPIRE_SCRIPT_TMPL = [[local result = {} -local OT_BAYES_TTL = ${ot_bayes_ttl} -local OT_MIN_AGE = ${ot_min_age} -local OT_MIN_COUNT = ${ot_min_count} -local symbol = ARGV[1] -local prefixes = redis.call('SMEMBERS', symbol .. '_keys') -for _, pfx in ipairs(prefixes) do - local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*') - local cursor, data = res[1], res[2] - while data do - local key_name = table.remove(data) - if key_name then - local h, s = redis.call('HMGET', key_name, 'H', 'S') - if (h or s) then - if not s then s = 0 else s = tonumber(s) end - if not h then h = 0 else h = tonumber(h) end - if s < OT_MIN_COUNT and h < OT_MIN_COUNT then - local ttl = redis.call('TTL', key_name) - if ttl > 0 then - local age = OT_BAYES_TTL - ttl - if age > OT_MIN_AGE then - table.insert(result, key_name) - end - end +local template = { + +} + +local function check_redis_classifier(cls, cfg) + -- Skip old classifiers + if cls.new_schema then + local symbol_spam, symbol_ham + local expiry = (cls.expiry or cls.expire) + -- Load symbols from statfiles + local statfiles = cls.statfile + for _,stf in ipairs(statfiles) do + local symbol = stf.symbol or 'undefined' + + local spam + if stf.spam then + spam = stf.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false end end - else - if cursor == "0" then - data = nil + + if spam then + symbol_spam = symbol else - local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*') - cursor, data = res[1], res[2] + symbol_ham = symbol + end + end + + if not symbol_spam or not symbol_ham or not expiry then + return + end + -- Now try to load redis_params if needed + + local redis_params = {} + if not lredis.try_load_redis_servers(cls, rspamd_config, redis_params) then + if not lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, redis_params) then + if not lredis.try_load_redis_servers(cfg['redis'] or E, rspamd_config, redis_params) then + return false + end end end + + table.insert(settings.classifiers, { + symbol_spam = symbol_spam, + symbol_ham = symbol_ham, + redis_params = redis_params, + expiry = expiry + }) end end -return table.concat(result, string.char(31))]] -local function configure_bayes_expiry() - local opts = rspamd_config:get_all_opt(N) - if not type(opts) == 'table' then return false end - for k, v in pairs(opts) do - settings[k] = v +-- Check classifiers and try find the appropriate ones +local obj = rspamd_config:get_ucl() + +local classifier = obj.classifier + +if classifier then + if classifier[1] then + for _,cls in ipairs(classifier) do + if cls.bayes then cls = cls.bayes end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for _,cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, obj) + end + end + end end - if not settings.symbols[1] then - logger.warn('No symbols configured, not enabling expiry') - return false +end + + +local opts = rspamd_config:get_all_opt(N) + +if opts then + for k,v in pairs(opts) do + settings[k] = v end - return true end -if not configure_bayes_expiry() then - lutil.disable_module(N, 'config') - return +-- Fill template +template.count = settings.count +template.threshold = settings.threshold +template.common_ttl_divisor = settings.common_ttl_divisor +template.epsilon_common = settings.epsilon_common +template.significant_factor = settings.significant_factor + +for k,v in pairs(template) do + template[k] = tostring(v) end -local function get_redis_params(ev_base, symbol) - local redis_params - local copts = rspamd_config:get_all_opt('classifier') - if not type(copts) == 'table' then - logger.errx(ev_base, "Couldn't get classifier configuration") - return - end - if type(copts.backend) == 'table' then - redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) +-- Arguments: +-- [1] = symbol pattern +-- [2] = expire value +-- [3] = cursor +-- returns new cursor +local expiry_script = [[ + local ret = redis.call('SCAN', KEYS[3], 'MATCH', KEYS[1], 'COUNT', '${count}') + local next = ret[1] + local keys = ret[2] + local nelts = 0 + local extended = 0 + local discriminated = 0 + + for _,key in ipairs(keys) do + local values = redis.call('HMGET', key, 'H', 'S') + local ham = tonumber(values[1]) or 0 + local spam = tonumber(values[2]) or 0 + + if ham > ${threshold} or spam > ${threshold} then + local total = ham + spam + + if total > 0 then + if ham / total > ${significant_factor} or spam / total > ${significant_factor} then + redis.replicate_commands() + redis.call('EXPIRE', key, KEYS[2]) + extended = extended + 1 + elseif math.abs(ham - spam) <= total * ${epsilon_common} then + local ttl = redis.call('TTL', key) + redis.replicate_commands() + redis.call('EXPIRE', key, tonumber(ttl) / ${common_ttl_divisor}) + discriminated = discriminated + 1 + end + end + end + nelts = nelts + 1 end - if redis_params then return redis_params end - if type(copts.statfile) == 'table' then - for _, stf in ipairs(copts.statfile) do - if stf.name == symbol then - redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) + + return {next, nelts, extended, discriminated} +]] + +local cur = 0 + +local function expire_step(cls, ev_base, worker) + + local function redis_step_cb(err, data) + if err then + logger.errx(rspamd_config, 'cannot perform expiry step: %s', err) + elseif type(data) == 'table' then + local next,nelts,extended,discriminated = tonumber(data[1]), tonumber(data[2]), + tonumber(data[3]),tonumber(data[4]) + + if next ~= 0 then + logger.infox(rspamd_config, 'executed expiry step for bayes: %s items checked, %s extended, %s discriminated', + nelts, extended, discriminated) + else + logger.infox(rspamd_config, 'executed final expiry step for bayes: %s items checked, %s extended, %s discriminated', + nelts, extended, discriminated) end + + cur = next end end - if redis_params then return redis_params end - redis_params = lredis.rspamd_parse_redis_server(nil, copts, false) - redis_params.timeout = settings.timeout - return redis_params + lredis.exec_redis_script(cls.script, + {ev_base = ev_base, is_write = true}, + redis_step_cb, + {'RS*_*', cls.expiry, cur} + ) end rspamd_config:add_on_load(function (_, ev_base, worker) - local processed_symbols, expire_script_sha -- Exit unless we're the first 'controller' worker if not (worker:get_name() == 'controller' and worker:get_index() == 0) then return end - -- Persist mempool variable to statefile on shutdown - rspamd_config:register_finish_script(function () - local stamp = pool:get_variable(VAR_NAME, 'double') - if not stamp then - logger.warnx(ev_base, 'No last bayes expiry to persist to disk') - return - end - local f, err = io.open(settings['statefile'], 'w') - if err then - logger.errx(ev_base, 'Unable to write statefile to disk: %s', err) - return - end - if f then - f:write(pool:get_variable(VAR_NAME, 'double')) - f:close() - end - end) - local expire_symbol - local function load_scripts(redis_params, cont, p1, p2) - local function load_script_cb(err, data) - if err then - logger.errx(ev_base, 'Error loading script: %s', err) - else - if type(data) == 'string' then - expire_script_sha = data - logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha) - if type(cont) == 'function' then - cont(p1, p2) - end - end + + local unique_redis_params = {} + -- Push redis script to all unique redis servers + for _,cls in ipairs(settings.classifiers) do + local seen = false + for _,rp in ipairs(unique_redis_params) do + if lutil.table_cmp(rp, cls.redis_params) then + seen = true end end - local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables) - local ret = lredis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - load_script_cb, --callback - 'SCRIPT', -- command - {'LOAD', scripttxt} - ) - if not ret then - logger.errx(ev_base, 'Error loading script') - end - end - local function continue_expire() - for _, symbol in ipairs(settings.symbols) do - if not processed_symbols[symbol] then - local redis_params = get_redis_params(ev_base, symbol) - if not redis_params then - processed_symbols[symbol] = true - logger.errx(ev_base, "Couldn't get redis params") - else - load_scripts(redis_params, expire_symbol, redis_params, symbol) - break - end - end + + if not seen then + table.insert(unique_redis_params, cls.redis_params) end end - expire_symbol = function(redis_params, symbol) - local function del_keys_cb(err, data) - if err then - logger.errx(ev_base, 'Redis request failed: %s', err) - end - processed_symbols[symbol] = true - continue_expire() - end - local function get_keys_cb(err, data) - if err then - logger.errx(ev_base, 'Redis request failed: %s', err) - processed_symbols[symbol] = true - continue_expire() - else - if type(data) == 'string' then - if data == "" then - data = {} - else - data = lutil.rspamd_str_split(data, string.char(31)) - end - end - if type(data) == 'table' then - if not data[1] then - logger.warnx(ev_base, 'No keys to delete: %s', symbol) - processed_symbols[symbol] = true - continue_expire() - else - local ret = lredis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - del_keys_cb, --callback - 'DEL', -- command - data - ) - if not ret then - logger.errx(ev_base, 'Redis request failed') - processed_symbols[symbol] = true - continue_expire() - end - end - else - logger.warnx(ev_base, 'No keys to delete: %s', symbol) - processed_symbols[symbol] = true - continue_expire() - end + + for _,rp in ipairs(unique_redis_params) do + local script_id = lredis.add_redis_script(lutil.template(expiry_script, + template), rp) + + for _,cls in ipairs(settings.classifiers) do + if lutil.table_cmp(rp, cls.redis_params) then + cls.script = script_id end end - local ret = lredis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - false, -- is write - get_keys_cb, --callback - 'EVALSHA', -- command - {expire_script_sha, 0, symbol} - ) - if not ret then - logger.errx(ev_base, 'Redis request failed') - processed_symbols[symbol] = true - continue_expire() - end - end - local function begin_expire(time) - local stamp = time or util.get_time() - pool:set_variable(VAR_NAME, stamp) - processed_symbols = {} - continue_expire() end + -- Expire tokens at regular intervals - local function schedule_regular_expiry() + for _,cls in ipairs(settings.classifiers) do rspamd_config:add_periodic(ev_base, settings['interval'], function () - begin_expire() + expire_step(cls, ev_base, worker) return true end) end - -- Expire tokens and reschedule expiry - local function schedule_intermediate_expiry(when) - rspamd_config:add_periodic(ev_base, when, function () - begin_expire() - schedule_regular_expiry() - return false - end) - end - -- Try read statefile on startup - local stamp - local f, err = io.open(settings['statefile'], 'r') - if err then - logger.warnx(ev_base, 'Failed to open statefile: %s', err) - end - if f then - io.input(f) - stamp = tonumber(io.read()) - pool:set_variable(VAR_NAME, stamp) - end - local time = util.get_time() - if not stamp then - logger.debugm(N, ev_base, 'No state found - expiring stats immediately') - begin_expire(time) - schedule_regular_expiry() - return - end - local delta = stamp - time + settings['interval'] - if delta <= 0 then - logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately') - begin_expire(time) - schedule_regular_expiry() - return - end - logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta) - schedule_intermediate_expiry(delta) end) |