]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Further fixes to ANN module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 14:33:26 +0000 (15:33 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 14:33:26 +0000 (15:33 +0100)
src/plugins/lua/fann_redis.lua

index 09c20ebf9c214a28a892951a11528d2c9833438d..2473fb290a92c8137fb1611a47c7d8fdd414584f 100644 (file)
@@ -98,13 +98,13 @@ local redis_lua_script_can_train = [[
     if nham <= lim and nham + 1 >= nspam then
       return tostring(nspam + 1)
     else
-      return tostring(-(nham + 1))
+      return tostring(-(nspam))
     end
   else
     if nspam <= lim and nspam + 1 >= nham then
       return tostring(nham + 1)
     else
-      return tostring(-(nspam + 1))
+      return tostring(-(nham))
     end
   end
 
@@ -411,9 +411,11 @@ local function create_fann(n, nlayers)
     -- We ignore number of layers so far when using torch
     local ann = nn.Sequential()
     local nhidden = math.floor((n + 1) / 2)
+    ann:add(nn.NaN(nn.Identity()))
     ann:add(nn.Linear(n, nhidden))
     ann:add(nn.PReLU())
     ann:add(nn.Linear(nhidden, 1))
+    ann:add(nn.Tanh())
 
     return ann
   else
@@ -429,7 +431,6 @@ local function create_fann(n, nlayers)
 end
 
 local function create_train_fann(rule, n, id)
-  id = rule.prefix .. tostring(id)
   local prefix = gen_fann_prefix(rule, id)
   if not fanns[id] then
     fanns[id] = {}
@@ -439,6 +440,7 @@ local function create_train_fann(rule, n, id)
     if not is_fann_valid(rule, prefix, fanns[id].fann) then
       fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
+      rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
     elseif fanns[id].version % rule.train.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
@@ -446,9 +448,11 @@ local function create_train_fann(rule, n, id)
       fanns[id].fann_train = create_fann(n, rule.nlayers)
     else
       fanns[id].fann_train = fanns[id].fann
+      rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
     end
   else
     fanns[id].fann_train = create_fann(n, rule.nlayers)
+    rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
     fanns[id].version = 0
   end
 end
@@ -764,12 +768,12 @@ local function train_fann(rule, _, ev_base, elt, worker)
           local dataset = {}
           fun.each(function(s)
             table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
-          end, spam_elts)
+          end, fun.filter(filt, spam_elts))
           fun.each(function(s)
             table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
-          end, ham_elts)
+          end, fun.filter(filt, ham_elts))
           -- Needed for torch
-          dataset.size = function(tbl) return #tbl end
+          dataset.size = function() return #dataset end
 
           local function train_torch()
             local criterion = nn.MSECriterion()
@@ -922,6 +926,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
       fun.each(function(elt)
         elt = tostring(elt)
         local prefix = gen_fann_prefix(rule, elt)
+        rspamd_logger.infox(cfg, "check ANN %s", prefix)
         local redis_len_cb = function(_err, _data)
           if _err then
             rspamd_logger.errx(rspamd_config,
@@ -932,6 +937,10 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
                 'need to learn ANN %s after %s learn vectors (%s required)',
                 prefix, tonumber(_data), rule.train.max_trains)
               train_fann(rule, cfg, ev_base, elt, worker)
+            else
+              rspamd_logger.infox(rspamd_config,
+                'no need to learn ANN %s %s learn vectors (%s required)',
+                prefix, tonumber(_data), rule.train.max_trains)
             end
           end
         end
@@ -1082,17 +1091,15 @@ else
     return copy
   end
   local function override_defaults(def, override)
-    for k,v in pairs(def) do
-      if override[k] then
-        if def[k] then
-          if type(override[k]) == 'table' then
-            override_defaults(def[k], override[k])
-          else
-            def[k] = override[k]
-          end
+    for k,v in pairs(override) do
+      if def[k] then
+        if type(override[k]) == 'table' then
+          override_defaults(def[k], override[k])
         else
           def[k] = override[k]
         end
+      else
+        def[k] = override[k]
       end
     end
   end
@@ -1108,6 +1115,9 @@ else
     if not def_rules.name then
       def_rules.name = k
     end
+    if def_rules.train.max_train then
+      def_rules.train.max_trains = def_rules.train.max_train
+    end
     rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
     settings.rules[k] = def_rules
     rspamd_config:set_metric_symbol({
@@ -1144,7 +1154,7 @@ else
           check_fanns(rule, cfg, ev_base)
       end)
 
-      if worker:get_name() == 'normal' then
+      if worker:get_name() == 'controller' and worker:get_index() == 0 then
         -- We also want to train neural nets when they have enough data
         rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)