From 6576616d3fe319148a29dc6d0ef0222b24384fd0 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 16 Sep 2017 15:56:28 +0100 Subject: [PATCH] [Fix] Fix ANN checks --- src/plugins/lua/fann_redis.lua | 44 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 2473fb290..f7ec65d30 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -43,6 +43,7 @@ local default_options = { max_trains = 1000, max_epoch = 1000, max_usages = 10, + max_iterations = 25, -- Torch style mse = 0.001, autotrain = true, }, @@ -331,19 +332,7 @@ local function is_fann_valid(rule, prefix, ann) meta_functions.rspamd_count_metatokens() if torch then - local nlayers = #ann - if nlayers ~= rule.nlayers then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', - prefix, nlayers) - return false - end - - local inp = ann:get(1):nElement() - if n ~= inp then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', prefix, inp, n) - return false - end + return true else if n ~= ann:get_inputs() then rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. @@ -364,12 +353,13 @@ local function is_fann_valid(rule, prefix, ann) end local function fann_scores_filter(task) - for _,rule in ipairs(settings.rules) do - local id = rule.prefix .. '0' + + for _,rule in pairs(settings.rules) do + local id = '0' if rule.use_settings then local sid = task:get_settings_id() if sid then - id = rule.prefix .. tostring(sid) + id = tostring(sid) end end if rule.per_user then @@ -481,6 +471,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) end if is_fann_valid(rule, prefix, ann) then + if not fanns[id] then fanns[id] = {} end fanns[id].fann = ann rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', prefix, ver) @@ -627,6 +618,8 @@ local function train_fann(rule, _, ev_base, elt, worker) if string.match(err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) end + else + rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix) end end @@ -666,7 +659,7 @@ local function train_fann(rule, _, ev_base, elt, worker) true, -- is write redis_save_cb, --callback 'EVALSHA', -- command - {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)} + {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} ) end end @@ -686,8 +679,8 @@ local function train_fann(rule, _, ev_base, elt, worker) {prefix .. '_locked'} ) else - rspamd_logger.infox(rspamd_config, 'trained ANN %s', - prefix) + rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes', + prefix, #data) local ann_data local f = torch.MemoryFile(torch.CharStorage():string(tostring(data))) ann_data = rspamd_util.zstd_compress(f:storage():string()) @@ -703,7 +696,7 @@ local function train_fann(rule, _, ev_base, elt, worker) true, -- is write redis_save_cb, --callback 'EVALSHA', -- command - {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)} + {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} ) end end @@ -780,6 +773,8 @@ local function train_fann(rule, _, ev_base, elt, worker) local trainer = nn.StochasticGradient(fanns[elt].fann_train, criterion) trainer.learning_rate = 0.01 + trainer.verbose = false + trainer.maxIteration = rule.train.max_iterations trainer.hookIteration = function(self, iteration, currentError) rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s", iteration, currentError) @@ -980,18 +975,23 @@ end local function check_fanns(rule, _, ev_base) local function members_cb(err, data) if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) + rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', + err) elseif type(data) == 'table' then fun.each(function(elt) elt = tostring(elt) local redis_update_cb = function(_err, _data) if _err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err) + rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', + elt, _err) if string.match(_err, 'NOSCRIPT') then load_scripts(rspamd_config, ev_base, nil) end elseif _data and type(_data) == 'table' then load_or_invalidate_fann(rule, _data, elt, ev_base) + else + rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s', + type(_data), elt) end end -- 2.39.5