-- key1 - prefix for fann
-- key2 - fann suffix (settings id)
-- key3 - spam or ham
+-- key4 - maximum trains
-- returns 0 when learning is not allowed; otherwise a non-zero value (the incremented spam/ham count, as a string)
local redis_lua_script_can_train = [[
local prefix = KEYS[1] .. KEYS[2]
if locked then return 0 end
local nspam = 0
local nham = 0
+ local lim = tonumber(KEYS[4])
local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
if not exists or exists == 0 then
if ret then nham = tonumber(ret) end
if KEYS[3] == 'spam' then
- if nham + 1 >= nspam then return tostring(nspam + 1) end
+ if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end
else
- if nspam + 1 >= nham then return tostring(nham + 1) end
+ if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end
end
return tostring(0)
end
end
-local function is_fann_valid(ann)
+local function is_fann_valid(prefix, ann)
if ann then
local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
if n ~= ann:get_inputs() then
- rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', ann:get_inputs(), n)
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', prefix, ann:get_inputs(), n)
return false
end
local layers = ann:get_layers()
if not layers or #layers ~= nlayers then
- rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
- #layers)
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+ prefix, #layers)
return false
end
local function create_train_fann(n, id)
id = tostring(id)
+ local prefix = gen_fann_prefix(id)
if not fanns[id] then
fanns[id] = {}
end
if fanns[id].fann then
if n ~= fanns[id].fann:get_inputs() or
(fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
- rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', id,
+ rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix,
fanns[id].version)
fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
fanns[id].fann = nil
elseif fanns[id].version % max_usages == 0 then
-- Forget last fann
- rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id,
+ rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
fanns[id].version)
fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
else
local function load_or_invalidate_fann(data, id, ev_base)
local ver = data[2]
+ local prefix = gen_fann_prefix(id)
+
if not ver or not tonumber(ver) then
- rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id)
+ rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
return
end
local ann
if err or not ann_data then
- rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err)
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
ann = rspamd_fann.load_data(ann_data)
end
- if is_fann_valid(ann) then
+ if is_fann_valid(prefix, ann) then
fanns[id].fann = ann
- rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis',
- id, ver)
+ rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
+ prefix, ver)
fanns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
+ rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
if string.match(_err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
elseif type(_data) == 'string' then
- rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+ rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
fanns[id].version = 0
end
end
-- Invalidate ANN
- rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', id)
+ rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_invalidate_cb, --callback
'EVALSHA', -- command
- {redis_maybe_invalidate_sha, 1, gen_fann_prefix(id)}
+ {redis_maybe_invalidate_sha, 1, prefix}
)
end
end
local function learn_vec_cb(err)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
end
end
)
else
if err then
- rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
true, -- is write
can_train_cb, --callback
'EVALSHA', -- command
- {redis_can_train_sha, '3', gen_fann_prefix(nil), suffix, k} -- arguments
+ {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments
)
end
end
local spam_elts = {}
local ham_elts = {}
elt = tostring(elt)
+ local prefix = gen_fann_prefix(elt)
local function redis_unlock_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
- gen_fann_prefix(elt), err)
+ prefix, err)
end
end
local function redis_save_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
- gen_fann_prefix(elt), err)
+ prefix, err)
redis_make_request(ev_base,
rspamd_config,
nil,
false, -- is write
redis_unlock_cb, --callback
'DEL', -- command
- {gen_fann_prefix(elt) .. '_locked'}
+ {prefix .. '_locked'}
)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
learning_spawned = false
if errcode ~= 0 then
rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
- gen_fann_prefix(elt), errmsg)
+ prefix, errmsg)
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_unlock_cb, --callback
'DEL', -- command
- {gen_fann_prefix(elt) .. '_locked'}
+ {prefix .. '_locked'}
)
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
- gen_fann_prefix(elt), train_mse)
+ prefix, train_mse)
local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
fanns[elt].version = fanns[elt].version + 1
fanns[elt].fann = fanns[elt].fann_train
true, -- is write
redis_save_cb, --callback
'EVALSHA', -- command
- {redis_save_unlock_sha, '2', gen_fann_prefix(elt), ann_data}
+ {redis_save_unlock_sha, '2', prefix, ann_data}
)
end
end
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
- gen_fann_prefix(elt), err)
+ prefix, err)
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_unlock_cb, --callback
'DEL', -- command
- {gen_fann_prefix(elt) .. '_locked'}
+ {prefix .. '_locked'}
)
else
-- Decompress and convert to numbers each training vector
-- Invalidate ANN as it is definitely invalid
local function redis_invalidate_cb(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', elt, _err)
+ rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
- rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', elt, _err)
+ rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
fanns[elt].version = 0
end
end
-- Invalidate ANN
- rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', elt)
+ rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_invalidate_cb, --callback
'EVALSHA', -- command
- {redis_locked_invalidate_sha, 1, gen_fann_prefix(elt)}
+ {redis_locked_invalidate_sha, 1, prefix}
)
else
learning_spawned = true
- rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt)
+ rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
{max_epochs = max_epoch, desired_mse = mse})
end
local function redis_spam_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
- gen_fann_prefix(elt), err)
+ prefix, err)
redis_make_request(ev_base,
rspamd_config,
nil,
true, -- is write
redis_unlock_cb, --callback
'DEL', -- command
- {gen_fann_prefix(elt) .. '_locked'}
+ {prefix .. '_locked'}
)
else
-- Decompress and convert to numbers each training vector
false, -- is write
redis_ham_cb, --callback
'LRANGE', -- command
- {gen_fann_prefix(elt) .. '_ham', '0', '-1'}
+ {prefix .. '_ham', '0', '-1'}
)
end
end
local function redis_lock_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- gen_fann_prefix(elt), err)
+ prefix, err)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
false, -- is write
redis_spam_cb, --callback
'LRANGE', -- command
- {gen_fann_prefix(elt) .. '_spam', '0', '-1'}
+ {prefix .. '_spam', '0', '-1'}
)
rspamd_config:add_periodic(ev_base, 30.0,
local function redis_lock_extend_cb(_err, _)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
- gen_fann_prefix(elt), _err)
+ prefix, _err)
else
rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
- gen_fann_prefix(elt))
+ prefix)
end
end
if learning_spawned then
true, -- is write
redis_lock_extend_cb, --callback
'INCRBY', -- command
- {gen_fann_prefix(elt) .. '_locked', '30'}
+ {prefix .. '_locked', '30'}
)
else
return false -- do not plan any more updates
return true
end
)
- rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', elt)
+ rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix)
else
- rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', elt)
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
end
end
if learning_spawned then
- rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN')
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
return
end
redis_make_request(ev_base,
true, -- is write
redis_lock_cb, --callback
'EVALSHA', -- command
- {redis_maybe_lock_sha, '4', gen_fann_prefix(elt), tostring(os.time()),
+ {redis_maybe_lock_sha, '4', prefix, tostring(os.time()),
tostring(lock_expire), rspamd_util.get_hostname()}
)
end
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
+ local prefix = gen_fann_prefix(elt)
local redis_len_cb = function(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, _err)
+ rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err)
elseif _data and type(_data) == 'number' or type(_data) == 'string' then
if tonumber(_data) and tonumber(_data) > max_trains then
rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
- elt, tonumber(_data), max_trains)
+ prefix, tonumber(_data), max_trains)
train_fann(cfg, ev_base, elt)
end
end
false, -- is write
redis_len_cb, --callback
'LLEN', -- command
- {gen_fann_prefix(elt) .. '_spam'}
+ {prefix .. '_spam'}
)
end,
data)