From acfd05e29d02e800719bae4d7a2c4fa3c17dc245 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 22 Nov 2016 13:13:36 +0000 Subject: [PATCH] [Fix] Invalidate ANN if training data is incorrect --- src/plugins/lua/fann_redis.lua | 59 +++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index aa4efd4c6..e2ac770dc 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -103,6 +103,19 @@ local redis_lua_script_maybe_invalidate = [[ ]] local redis_maybe_invalidate_sha = nil +-- Lua script to invalidate ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +local redis_lua_script_locked_invalidate = [[ + redis.call('SET', KEYS[1] .. '_version', '0') + redis.call('DEL', KEYS[1] .. '_spam') + redis.call('DEL', KEYS[1] .. '_ham') + redis.call('DEL', KEYS[1] .. '_data') + redis.call('DEL', KEYS[1] .. '_locked') + return 1 +]] +local redis_locked_invalidate_sha = nil + -- Lua script to invalidate ANN from redis -- Uses the following keys -- key1 - prefix for keys @@ -511,10 +524,32 @@ local function train_fann(_, ev_base, elt) create_train_fann(n, elt) end - learning_spawned = true - rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt) - fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, - {max_epochs = max_epoch, desired_mse = mse}) + if #inputs < max_trains / 2 then + -- Invalidate ANN as it is definitely invalid + local function redis_invalidate_cb(_err, _data) + if _err then + rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err) + elseif type(_data) == 'string' then + rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err) + fanns[id].version = 0 + end + end + -- Invalidate ANN + rspamd_logger.infox('invalidate ANN %s: training data is invalid') + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + redis_invalidate_cb, --callback + 'EVALSHA', -- command + {redis_locked_invalidate_sha, 1, gen_fann_prefix(id)} + ) + else + learning_spawned = true + rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt) + fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, + {max_epochs = max_epoch, desired_mse = mse}) + end end end @@ -686,6 +721,22 @@ local function load_scripts(cfg, ev_base, on_load_cb) {'LOAD', redis_lua_script_maybe_invalidate} -- arguments ) + local function locked_invalidate_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis locked invalidate script: %s', err) + else + redis_locked_invalidate_sha = tostring(data) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + locked_invalidate_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_locked_invalidate} -- arguments + ) + local function maybe_lock_sha_cb(err, data) if err or not data or type(data) ~= 'string' then rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err) -- 2.39.5