From d203a22d5c410bb0b04df0ce10e1293e169a5a68 Mon Sep 17 00:00:00 2001
From: Vsevolod Stakhov
Date: Sat, 16 Sep 2017 15:33:26 +0100
Subject: [PATCH] [Fix] Further fixes to ANN module

---
 src/plugins/lua/fann_redis.lua | 40 +++++++++++++++++++++-------------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 09c20ebf9..2473fb290 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -98,13 +98,13 @@ local redis_lua_script_can_train = [[
     if nham <= lim and nham + 1 >= nspam then
       return tostring(nspam + 1)
     else
-      return tostring(-(nham + 1))
+      return tostring(-(nspam))
     end
   else
     if nspam <= lim and nspam + 1 >= nham then
       return tostring(nham + 1)
     else
-      return tostring(-(nspam + 1))
+      return tostring(-(nham))
     end
   end

@@ -411,9 +411,11 @@ local function create_fann(n, nlayers)
     -- We ignore number of layers so far when using torch
     local ann = nn.Sequential()
     local nhidden = math.floor((n + 1) / 2)
+    ann:add(nn.NaN(nn.Identity()))
     ann:add(nn.Linear(n, nhidden))
     ann:add(nn.PReLU())
     ann:add(nn.Linear(nhidden, 1))
+    ann:add(nn.Tanh())

     return ann
   else
@@ -429,7 +431,6 @@ local function create_fann(n, nlayers)
 end

 local function create_train_fann(rule, n, id)
-  id = rule.prefix .. tostring(id)
   local prefix = gen_fann_prefix(rule, id)
   if not fanns[id] then
     fanns[id] = {}
@@ -439,6 +440,7 @@
     if not is_fann_valid(rule, prefix, fanns[id].fann) then
       fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
+      rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
     elseif fanns[id].version % rule.train.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
@@ -446,9 +448,11 @@
       fanns[id].fann_train = create_fann(n, rule.nlayers)
     else
       fanns[id].fann_train = fanns[id].fann
+      rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
     end
   else
     fanns[id].fann_train = create_fann(n, rule.nlayers)
+    rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
     fanns[id].version = 0
   end
 end
@@ -764,12 +768,12 @@ local function train_fann(rule, _, ev_base, elt, worker)
     local dataset = {}
     fun.each(function(s)
       table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
-    end, spam_elts)
+    end, fun.filter(filt, spam_elts))
     fun.each(function(s)
       table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
-    end, ham_elts)
+    end, fun.filter(filt, ham_elts))
     -- Needed for torch
-    dataset.size = function(tbl) return #tbl end
+    dataset.size = function() return #dataset end

     local function train_torch()
       local criterion = nn.MSECriterion()
@@ -922,6 +926,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
   fun.each(function(elt)
     elt = tostring(elt)
     local prefix = gen_fann_prefix(rule, elt)
+    rspamd_logger.infox(cfg, "check ANN %s", prefix)
     local redis_len_cb = function(_err, _data)
       if _err then
         rspamd_logger.errx(rspamd_config,
@@ -932,6 +937,10 @@
             'need to learn ANN %s after %s learn vectors (%s required)',
             prefix, tonumber(_data), rule.train.max_trains)
           train_fann(rule, cfg, ev_base, elt, worker)
+        else
+          rspamd_logger.infox(rspamd_config,
+            'no need to learn ANN %s %s learn vectors (%s required)',
+            prefix, tonumber(_data), rule.train.max_trains)
         end
       end
     end
@@ -1082,17 +1091,15 @@ else
     return copy
   end
   local function override_defaults(def, override)
-    for k,v in pairs(def) do
-      if override[k] then
-        if def[k] then
-          if type(override[k]) == 'table' then
-            override_defaults(def[k], override[k])
-          else
-            def[k] = override[k]
-          end
+    for k,v in pairs(override) do
+      if def[k] then
+        if type(override[k]) == 'table' then
+          override_defaults(def[k], override[k])
         else
           def[k] = override[k]
         end
+      else
+        def[k] = override[k]
       end
     end
   end
@@ -1108,6 +1115,9 @@
     if not def_rules.name then
       def_rules.name = k
     end
+    if def_rules.train.max_train then
+      def_rules.train.max_trains = def_rules.train.max_train
+    end
     rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
     settings.rules[k] = def_rules
     rspamd_config:set_metric_symbol({
@@ -1144,7 +1154,7 @@
         check_fanns(rule, cfg, ev_base)
       end)

-      if worker:get_name() == 'normal' then
+      if worker:get_name() == 'controller' and worker:get_index() == 0 then
         -- We also want to train neural nets when they have enough data
         rspamd_config:add_periodic(ev_base, 0.0, function(_, _)

-- 
2.39.5
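
Note (not part of the patch): a minimal standalone sketch of the merge semantics the corrected override_defaults implements. Iterating over the override table rather than the defaults means settings with no counterpart in the defaults are no longer silently dropped. Only the function body is taken from the patch; the def_rules and user_cfg tables below are hypothetical.

local function override_defaults(def, override)
  for k,v in pairs(override) do
    if def[k] then
      if type(override[k]) == 'table' then
        override_defaults(def[k], override[k])
      else
        def[k] = override[k]
      end
    else
      -- keys present only in the override are now copied as well
      def[k] = override[k]
    end
  end
end

-- hypothetical defaults and user configuration
local def_rules = { train = { max_usages = 10, max_trains = 1000 } }
local user_cfg  = { train = { max_trains = 500 }, nlayers = 2 }

override_defaults(def_rules, user_cfg)
print(def_rules.train.max_trains) -- 500 (overridden)
print(def_rules.train.max_usages) -- 10  (default kept)
print(def_rules.nlayers)          -- 2   (override-only key, previously lost)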