From d3eb6739d3aa8f542c21679fbbfe44f94cc253cc Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 21 Feb 2017 16:00:40 +0000 Subject: [PATCH] [Fix] Multiple fixes for fann module --- src/plugins/lua/fann_redis.lua | 107 ++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index a9c6a41cb..a0953b00c 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -39,6 +39,7 @@ local fanns = { -- key1 - prefix for fann -- key2 - fann suffix (settings id) -- key3 - spam or ham +-- key4 - maximum trains -- returns 1 or 0: 1 - allow learn, 0 - not allow learn local redis_lua_script_can_train = [[ local prefix = KEYS[1] .. KEYS[2] @@ -46,6 +47,7 @@ local redis_lua_script_can_train = [[ if locked then return 0 end local nspam = 0 local nham = 0 + local lim = tonumber(KEYS[4]) local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2]) if not exists or exists == 0 then @@ -58,9 +60,9 @@ local redis_lua_script_can_train = [[ if ret then nham = tonumber(ret) end if KEYS[3] == 'spam' then - if nham + 1 >= nspam then return tostring(nspam + 1) end + if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end else - if nspam + 1 >= nham then return tostring(nham + 1) end + if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end end return tostring(0) @@ -344,20 +346,20 @@ local function gen_fann_prefix(id) end end -local function is_fann_valid(ann) +local function is_fann_valid(prefix, ann) if ann then local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() if n ~= ann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', ann:get_inputs(), n) + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', prefix, ann:get_inputs(), n) return false end local layers = ann:get_layers() if not layers or #layers ~= nlayers then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', - #layers) + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', + prefix, #layers) return false end @@ -399,6 +401,7 @@ end local function create_train_fann(n, id) id = tostring(id) + local prefix = gen_fann_prefix(id) if not fanns[id] then fanns[id] = {} end @@ -406,13 +409,13 @@ local function create_train_fann(n, id) if fanns[id].fann then if n ~= fanns[id].fann:get_inputs() or (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then - rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', id, + rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix, fanns[id].version) fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) fanns[id].fann = nil elseif fanns[id].version % max_usages == 0 then -- Forget last fann - rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id, + rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) else @@ -426,8 +429,10 @@ end local function load_or_invalidate_fann(data, id, ev_base) local ver = data[2] + local prefix = gen_fann_prefix(id) + if not ver or not tonumber(ver) then - rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id) + rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix) return end @@ -435,38 +440,38 @@ local function load_or_invalidate_fann(data, id, ev_base) local ann if err or not ann_data then - rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err) + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) return else ann = rspamd_fann.load_data(ann_data) end - if is_fann_valid(ann) then + if is_fann_valid(prefix, ann) then fanns[id].fann = ann - rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis', - id, ver) + rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', + prefix, ver) fanns[id].version = tonumber(ver) else local function redis_invalidate_cb(_err, _data) if _err then - rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err) + rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) if string.match(_err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) end elseif type(_data) == 'string' then - rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err) + rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) fanns[id].version = 0 end end -- Invalidate ANN - rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', id) + rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_invalidate_cb, --callback 'EVALSHA', -- command - {redis_maybe_invalidate_sha, 1, gen_fann_prefix(id)} + {redis_maybe_invalidate_sha, 1, prefix} ) end end @@ -493,7 +498,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts, local function learn_vec_cb(err) if err then - rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err) + rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err) end end @@ -517,7 +522,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts, ) else if err then - rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err) + rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err) if string.match(err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) end @@ -531,7 +536,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts, true, -- is write can_train_cb, --callback 'EVALSHA', -- command - {redis_can_train_sha, '3', gen_fann_prefix(nil), suffix, k} -- arguments + {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments ) end end @@ -540,25 +545,26 @@ local function train_fann(_, ev_base, elt) local spam_elts = {} local ham_elts = {} elt = tostring(elt) + local prefix = gen_fann_prefix(elt) local function redis_unlock_cb(err) if err then rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s', - gen_fann_prefix(elt), err) + prefix, err) end end local function redis_save_cb(err) if err then rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s', - gen_fann_prefix(elt), err) + prefix, err) redis_make_request(ev_base, rspamd_config, nil, false, -- is write redis_unlock_cb, --callback 'DEL', -- command - {gen_fann_prefix(elt) .. '_locked'} + {prefix .. '_locked'} ) if string.match(err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) @@ -570,18 +576,18 @@ local function train_fann(_, ev_base, elt) learning_spawned = false if errcode ~= 0 then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', - gen_fann_prefix(elt), errmsg) + prefix, errmsg) redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_unlock_cb, --callback 'DEL', -- command - {gen_fann_prefix(elt) .. '_locked'} + {prefix .. '_locked'} ) else rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', - gen_fann_prefix(elt), train_mse) + prefix, train_mse) local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train @@ -592,7 +598,7 @@ local function train_fann(_, ev_base, elt) true, -- is write redis_save_cb, --callback 'EVALSHA', -- command - {redis_save_unlock_sha, '2', gen_fann_prefix(elt), ann_data} + {redis_save_unlock_sha, '2', prefix, ann_data} ) end end @@ -600,14 +606,14 @@ local function train_fann(_, ev_base, elt) local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', - gen_fann_prefix(elt), err) + prefix, err) redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_unlock_cb, --callback 'DEL', -- command - {gen_fann_prefix(elt) .. '_locked'} + {prefix .. '_locked'} ) else -- Decompress and convert to numbers each training vector @@ -643,25 +649,25 @@ local function train_fann(_, ev_base, elt) -- Invalidate ANN as it is definitely invalid local function redis_invalidate_cb(_err, _data) if _err then - rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', elt, _err) + rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then - rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', elt, _err) + rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) fanns[elt].version = 0 end end -- Invalidate ANN - rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', elt) + rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_invalidate_cb, --callback 'EVALSHA', -- command - {redis_locked_invalidate_sha, 1, gen_fann_prefix(elt)} + {redis_locked_invalidate_sha, 1, prefix} ) else learning_spawned = true - rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt) + rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, {max_epochs = max_epoch, desired_mse = mse}) end @@ -671,14 +677,14 @@ local function train_fann(_, ev_base, elt) local function redis_spam_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', - gen_fann_prefix(elt), err) + prefix, err) redis_make_request(ev_base, rspamd_config, nil, true, -- is write redis_unlock_cb, --callback 'DEL', -- command - {gen_fann_prefix(elt) .. '_locked'} + {prefix .. '_locked'} ) else -- Decompress and convert to numbers each training vector @@ -692,7 +698,7 @@ local function train_fann(_, ev_base, elt) false, -- is write redis_ham_cb, --callback 'LRANGE', -- command - {gen_fann_prefix(elt) .. '_ham', '0', '-1'} + {prefix .. '_ham', '0', '-1'} ) end end @@ -700,7 +706,7 @@ local function train_fann(_, ev_base, elt) local function redis_lock_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - gen_fann_prefix(elt), err) + prefix, err) if string.match(err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) end @@ -712,7 +718,7 @@ local function train_fann(_, ev_base, elt) false, -- is write redis_spam_cb, --callback 'LRANGE', -- command - {gen_fann_prefix(elt) .. '_spam', '0', '-1'} + {prefix .. '_spam', '0', '-1'} ) rspamd_config:add_periodic(ev_base, 30.0, @@ -720,10 +726,10 @@ local function train_fann(_, ev_base, elt) local function redis_lock_extend_cb(_err, _) if _err then rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', - gen_fann_prefix(elt), _err) + prefix, _err) else rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', - gen_fann_prefix(elt)) + prefix) end end if learning_spawned then @@ -733,7 +739,7 @@ local function train_fann(_, ev_base, elt) true, -- is write redis_lock_extend_cb, --callback 'INCRBY', -- command - {gen_fann_prefix(elt) .. '_locked', '30'} + {prefix .. '_locked', '30'} ) else return false -- do not plan any more updates @@ -742,13 +748,13 @@ local function train_fann(_, ev_base, elt) return true end ) - rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', elt) + rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix) else - rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', elt) + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix) end end if learning_spawned then - rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN') + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) return end redis_make_request(ev_base, @@ -757,7 +763,7 @@ local function train_fann(_, ev_base, elt) true, -- is write redis_lock_cb, --callback 'EVALSHA', -- command - {redis_maybe_lock_sha, '4', gen_fann_prefix(elt), tostring(os.time()), + {redis_maybe_lock_sha, '4', prefix, tostring(os.time()), tostring(lock_expire), rspamd_util.get_hostname()} ) end @@ -769,13 +775,14 @@ local function maybe_train_fanns(cfg, ev_base) elseif type(data) == 'table' then fun.each(function(elt) elt = tostring(elt) + local prefix = gen_fann_prefix(elt) local redis_len_cb = function(_err, _data) if _err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, _err) + rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err) elseif _data and type(_data) == 'number' or type(_data) == 'string' then if tonumber(_data) and tonumber(_data) > max_trains then rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', - elt, tonumber(_data), max_trains) + prefix, tonumber(_data), max_trains) train_fann(cfg, ev_base, elt) end end @@ -787,7 +794,7 @@ local function maybe_train_fanns(cfg, ev_base) false, -- is write redis_len_cb, --callback 'LLEN', -- command - {gen_fann_prefix(elt) .. '_spam'} + {prefix .. '_spam'} ) end, data) -- 2.39.5