if ret then nham = tonumber(ret) end
if KEYS[3] == 'spam' then
- if nham + 1 >= nspam then return tostring(nspam) end
+ if nham + 1 >= nspam then return tostring(nspam + 1) end
else
- if nspam + 1 >= nham then return tostring(nham) end
+ if nspam + 1 >= nham then return tostring(nham + 1) end
end
return tostring(0)
return false
]]
-local redis_fann_maybe_load_sha = nil
+local redis_maybe_load_sha = nil
+
+-- Lua script (executed inside redis) to invalidate an ANN stored in redis
+-- Uses the following keys
+-- key1 - prefix for keys; the script appends the suffixes `_locked`,
+--        `_version`, `_spam`, `_ham` and `_data` to it
+-- Returns false without modifying anything when `<prefix>_locked` is set.
+-- Otherwise it sets the lock, resets `<prefix>_version` to '0', deletes the
+-- `_spam`, `_ham` and `_data` keys, removes the lock again and returns 1.
+local redis_lua_script_maybe_invalidate = [[
+ local locked = redis.call('GET', KEYS[1] .. '_locked')
+ if locked then return false end
+ redis.call('SET', KEYS[1] .. '_locked', '1')
+ redis.call('SET', KEYS[1] .. '_version', '0')
+ redis.call('DEL', KEYS[1] .. '_spam')
+ redis.call('DEL', KEYS[1] .. '_ham')
+ redis.call('DEL', KEYS[1] .. '_data')
+ redis.call('DEL', KEYS[1] .. '_locked')
+ return 1
+]]
+-- SHA of the script above; stays nil until it is assigned from the reply to
+-- the `SCRIPT LOAD` request issued elsewhere in this file
+local redis_maybe_invalidate_sha = nil
local redis_params
redis_params = rspamd_parse_redis_server('fann_scores')
-local fann_prefix = 'RF'
+local fann_prefix = 'RFANN'
local max_trains = 1000
local max_epoch = 100
local use_settings = false
end
end
-local function is_fann_valid(id)
- if data[id].fann then
+-- Check that a loaded ANN is usable with the current configuration.
+-- @param ann ANN object (may be nil); get_inputs() and get_layers() are called on it
+-- @return true when `ann` is set, its number of inputs equals
+--   rspamd_config:get_symbols_count() + count_metatokens() and it has exactly
+--   5 layers; false on a mismatch (the reason is logged); nothing (nil) when
+--   `ann` itself is nil
+-- NOTE(review): when get_layers() returns nil, the `#layers` argument of the
+-- log call below looks like it raises before `return false` is reached - confirm
+local function is_fann_valid(ann)
+ if ann then
local n = rspamd_config:get_symbols_count() + count_metatokens()
- if n ~= data[id].fann:get_inputs() then
+ if n ~= ann:get_inputs() then
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', data[id].fann:get_inputs(), n)
- data[id].fann = nil
+ ' is found in the cache', ann:get_inputs(), n)
+ return false
end
- local layers = data[id].fann:get_layers()
+ local layers = ann:get_layers()
if not layers or #layers ~= 5 then
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
#layers)
- data[id].fann = nil
+ return false
end
- end
- if data[id].fann then
return true
end
-
- return false
end
local function fann_scores_filter(task)
data[id].epoch = 0
end
+-- Decompress and load an ANN blob received from redis and store it in the
+-- outer `data` table (the one other functions in this file index as
+-- `data[id]`). When the blob cannot be decompressed/loaded, or the ANN fails
+-- `is_fann_valid`, ask redis to invalidate the stored ANN instead.
+-- @param ann_blob zstd-compressed ANN as received from redis
+-- @param id ANN id: index into `data` and, prefixed with `fann_prefix`, the
+--   key prefix handed to the redis invalidate script
+-- @param ev_base event base forwarded to `redis_make_request`
+local function load_or_invalidate_fann(ann_blob, id, ev_base)
+  -- The first parameter is deliberately not named `data`: that name would
+  -- shadow the outer `data` table that is updated below.
+  local err, ann_data = rspamd_util.zstd_decompress(ann_blob)
+  local ann
+
+  if err or not ann_data then
+    rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
+  else
+    ann = rspamd_fann.load_data(ann_data)
+  end
+
+  if is_fann_valid(ann) then
+    -- assumes data[id] has already been initialised elsewhere - TODO confirm
+    data[id].fann = ann
+    return
+  end
+
+  if not redis_maybe_invalidate_sha then
+    -- The invalidate script is loaded asynchronously; without its SHA the
+    -- EVALSHA arguments below would start with a nil hole.
+    rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s',
+      id, 'invalidate script is not loaded yet')
+    return
+  end
+
+  local function redis_invalidate_cb(inv_err, reply)
+    if inv_err then
+      rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, inv_err)
+    elseif type(reply) == 'string' or type(reply) == 'number' then
+      -- The script returns the integer 1 on success, so a numeric reply is
+      -- accepted in addition to a string one.
+      rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, reply)
+    end
+  end
+
+  -- Invalidate ANN
+  rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', id)
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    true, -- is write
+    redis_invalidate_cb, --callback
+    'EVALSHA', -- command
+    {redis_maybe_invalidate_sha, 1, fann_prefix .. id}
+  )
+end
+
local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
local fname = gen_fann_prefix(id)
local k
if learn_spam then k = 'spam' else k = 'ham' end
+ -- Callback for the redis request that stores a train vector:
+ -- nothing to do on success, failures are logged.
+ local function learn_vec_cb(err, data)
+   if not err then return end
+   rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+ end
+
local function can_train_cb(err, data)
+ rspamd_logger.errx('data: %s, err: %s', data, err)
if not err and tonumber(data) > 0 then
local learn_data = symbols_to_fann_vector(
map(function(r) return r[1] end, results),
)
-- Add filtered meta tokens
each(function(e) table.insert(learn_data, e) end, extra)
- local str = table.concat(learn_data, ';')
+ local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
- learn_cb, --callback
+ learn_vec_cb, --callback
'LPUSH', -- command
{fname .. '_' .. k, str} -- arguments
)
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
each(function(i, elt)
- local redis_load_cb = function(err, data)
- if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
- elseif type(data) == 'string' then
- --load_fann(data, elt)
- end
- end
local redis_update_cb = function(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
- elseif data then
- redis_make_request(ev_base,
- rspamd_config,
- nil,
- false, -- is write
- redis_load_cb, --callback
- 'GET', -- command
- {fann_prefix, fann_prefix .. elt .. '_data'} -- arguments
- )
+ elseif data and type(data) == 'string' then
+ load_or_invalidate_fann(data, elt, ev_base)
end
end
false, -- is write
redis_update_cb, --callback
'EVALSHA', -- command
- {redis_fann_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
+ {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
)
end,
data)
end
end
- if not redis_fann_maybe_load_sha then
+ if not redis_maybe_load_sha then
-- Plan new event early
return 1.0
end
if err or not data or type(data) ~= 'string' then
rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
else
- redis_fann_maybe_load_sha = tostring(data)
+ redis_maybe_load_sha = tostring(data)
rspamd_config:add_periodic(ev_base, 0.0,
function(cfg, ev_base)
'SCRIPT', -- command
{'LOAD', redis_lua_script_maybe_load} -- arguments
)
+
+ -- Callback for `SCRIPT LOAD` of the invalidate script: keep the returned
+ -- SHA for later EVALSHA calls, log anything that is not a string reply.
+ local function maybe_invalidate_sha_cb(err, data)
+   if not err and type(data) == 'string' then
+     redis_maybe_invalidate_sha = tostring(data)
+   else
+     rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
+   end
+ end
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ maybe_invalidate_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
+ )
end)
end