--[[ Copyright (c) 2017, Andrew Lewis Copyright (c) 2017, Vsevolod Stakhov Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ]] -- if confighelp then return end local N = 'bayes_expiry' local logger = require "rspamd_logger" local mempool = require "rspamd_mempool" local util = require "rspamd_util" local lutil = require "lua_util" local lredis = require "lua_redis" local pool = mempool.create() local settings = { interval = 604800, statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'bayes_expired'), variables = { ot_bayes_ttl = 31536000, -- one year ot_min_age = 7776000, -- 90 days ot_min_count = 5, }, symbols = {}, timeout = 60, } local VAR_NAME = 'bayes_expired' local EXPIRE_SCRIPT_TMPL = [[local result = {} local OT_BAYES_TTL = ${ot_bayes_ttl} local OT_MIN_AGE = ${ot_min_age} local OT_MIN_COUNT = ${ot_min_count} local symbol = ARGV[1] local prefixes = redis.call('SMEMBERS', symbol .. '_keys') for _, pfx in ipairs(prefixes) do local res = redis.call('SCAN', '0', 'MATCH', pfx .. '_*') local cursor, data = res[1], res[2] while data do local key_name = table.remove(data) if key_name then local h, s = redis.call('HMGET', key_name, 'H', 'S') if (h or s) then if not s then s = 0 else s = tonumber(s) end if not h then h = 0 else h = tonumber(h) end if s < OT_MIN_COUNT and h < OT_MIN_COUNT then local ttl = redis.call('TTL', key_name) if ttl > 0 then local age = OT_BAYES_TTL - ttl if age > OT_MIN_AGE then table.insert(result, key_name) end end end end else if cursor == "0" then data = nil else local res = redis.call('SCAN', tostring(cursor), 'MATCH', pfx .. '_*') cursor, data = res[1], res[2] end end end end return table.concat(result, string.char(31))]] local function configure_bayes_expiry() local opts = rspamd_config:get_all_opt(N) if not type(opts) == 'table' then return false end for k, v in pairs(opts) do settings[k] = v end if not settings.symbols[1] then logger.warn('No symbols configured, not enabling expiry') return false end return true end if not configure_bayes_expiry() then return end local function get_redis_params(ev_base, symbol) local redis_params local copts = rspamd_config:get_all_opt('classifier') if not type(copts) == 'table' then logger.errx(ev_base, "Couldn't get classifier configuration") return end if type(copts.backend) == 'table' then redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) end if redis_params then return redis_params end if type(copts.statfile) == 'table' then for _, stf in ipairs(copts.statfile) do if stf.name == symbol then redis_params = lredis.rspamd_parse_redis_server(nil, copts.backend, true) end end end if redis_params then return redis_params end redis_params = lredis.rspamd_parse_redis_server(nil, copts, false) redis_params.timeout = settings.timeout return redis_params end rspamd_config:add_on_load(function (_, ev_base, worker) local processed_symbols, expire_script_sha -- Exit unless we're the first 'controller' worker if not (worker:get_name() == 'controller' and worker:get_index() == 0) then return end -- Persist mempool variable to statefile on shutdown rspamd_config:register_finish_script(function () local stamp = pool:get_variable(VAR_NAME, 'double') if not stamp then logger.warnx(ev_base, 'No last bayes expiry to persist to disk') return end local f, err = io.open(settings['statefile'], 'w') if err then logger.errx(ev_base, 'Unable to write statefile to disk: %s', err) return end if f then f:write(pool:get_variable(VAR_NAME, 'double')) f:close() end end) local expire_symbol local function load_scripts(redis_params, cont, p1, p2) local function load_script_cb(err, data) if err then logger.errx(ev_base, 'Error loading script: %s', err) else if type(data) == 'string' then expire_script_sha = data logger.debugm(N, ev_base, 'expire_script_sha: %s', expire_script_sha) if type(cont) == 'function' then cont(p1, p2) end end end end local scripttxt = lutil.template(EXPIRE_SCRIPT_TMPL, settings.variables) local ret = lredis.redis_make_request_taskless(ev_base, rspamd_config, redis_params, nil, true, -- is write load_script_cb, --callback 'SCRIPT', -- command {'LOAD', scripttxt} ) if not ret then logger.errx(ev_base, 'Error loading script') end end local function continue_expire() for _, symbol in ipairs(settings.symbols) do if not processed_symbols[symbol] then local redis_params = get_redis_params(ev_base, symbol) if not redis_params then processed_symbols[symbol] = true logger.errx(ev_base, "Couldn't get redis params") else load_scripts(redis_params, expire_symbol, redis_params, symbol) break end end end end expire_symbol = function(redis_params, symbol) local function del_keys_cb(err, data) if err then logger.errx(ev_base, 'Redis request failed: %s', err) end processed_symbols[symbol] = true continue_expire() end local function get_keys_cb(err, data) if err then logger.errx(ev_base, 'Redis request failed: %s', err) processed_symbols[symbol] = true continue_expire() else if type(data) == 'string' then if data == "" then data = {} else data = lutil.rspamd_str_split(data, string.char(31)) end end if type(data) == 'table' then if not data[1] then logger.warnx(ev_base, 'No keys to delete: %s', symbol) processed_symbols[symbol] = true continue_expire() else local ret = lredis.redis_make_request_taskless(ev_base, rspamd_config, redis_params, nil, true, -- is write del_keys_cb, --callback 'DEL', -- command data ) if not ret then logger.errx(ev_base, 'Redis request failed') processed_symbols[symbol] = true continue_expire() end end else logger.warnx(ev_base, 'No keys to delete: %s', symbol) processed_symbols[symbol] = true continue_expire() end end end local ret = lredis.redis_make_request_taskless(ev_base, rspamd_config, redis_params, nil, false, -- is write get_keys_cb, --callback 'EVALSHA', -- command {expire_script_sha, 0, symbol} ) if not ret then logger.errx(ev_base, 'Redis request failed') processed_symbols[symbol] = true continue_expire() end end local function begin_expire(time) local stamp = time or util.get_time() pool:set_variable(VAR_NAME, stamp) processed_symbols = {} continue_expire() end -- Expire tokens at regular intervals local function schedule_regular_expiry() rspamd_config:add_periodic(ev_base, settings['interval'], function () begin_expire() return true end) end -- Expire tokens and reschedule expiry local function schedule_intermediate_expiry(when) rspamd_config:add_periodic(ev_base, when, function () begin_expire() schedule_regular_expiry() return false end) end -- Try read statefile on startup local stamp local f, err = io.open(settings['statefile'], 'r') if err then logger.warnx(ev_base, 'Failed to open statefile: %s', err) end if f then io.input(f) stamp = tonumber(io.read()) pool:set_variable(VAR_NAME, stamp) end local time = util.get_time() if not stamp then logger.debugm(N, ev_base, 'No state found - expiring stats immediately') begin_expire(time) schedule_regular_expiry() return end local delta = stamp - time + settings['interval'] if delta <= 0 then logger.debugm(N, ev_base, 'Last expiry is too old - expiring stats immediately') begin_expire(time) schedule_regular_expiry() return end logger.debugm(N, ev_base, 'Scheduling next expiry in %s seconds', delta) schedule_intermediate_expiry(delta) end)