local ver = 0
local ret = redis.call('GET', KEYS[1] .. '_version')
if ret then ver = tonumber(ret) end
- if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_data') end
+ if ver > tonumber(KEYS[2]) then
+ return {redis.call('GET', KEYS[1] .. '_data'), ret}
+ end
return false
]]
local fann_prefix = 'RFANN'
local max_trains = 1000
-local max_epoch = 100
+local max_epoch = 1000
+local max_usages = 10
local use_settings = false
local watch_interval = 60.0
local mse = 0.0001
end
if fanns[id].fann then
- fanns[id].fann_train = fanns[id].fann
- fanns[id].fann = nil
+ if fanns[id].version % max_usages == 0 then
+ -- Forget last fann
+ rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id,
+ fanns[id].version)
+ fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann = nil
+ else
+ fanns[id].fann_train = fanns[id].fann
+ fanns[id].fann = nil
+ end
else
fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
fanns[id].version = 0
end
local function load_or_invalidate_fann(data, id, ev_base)
- local err,ann_data = rspamd_util.zstd_decompress(data)
+ local ver = data[2]
+ if not ver or not tonumber(ver) then
+ rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id)
+ return
+ end
+
+ local err,ann_data = rspamd_util.zstd_decompress(data[1])
local ann
if err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err)
+ return
else
ann = rspamd_fann.load_data(ann_data)
end
if is_fann_valid(ann) then
fanns[id].fann = ann
- rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
+ rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis',
+ id, ver)
+ fanns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
elseif type(_data) == 'string' then
rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+ fanns[id].version = 0
end
end
-- Invalidate ANN
if opts['train']['max_epoch'] then
max_epoch = opts['train']['max_epoch']
end
+ if opts['train']['max_usages'] then
+ max_usages = opts['train']['max_usages']
+ end
+ if opts['train']['mse'] then
+ mse = opts['train']['mse']
+ end
local ret = cfg:register_worker_script("log_helper",
function(score, req_score, results, cf, _id, extra, ev_base)
-- fun.map (snd x) (fun.filter (fst x == module_id) extra)