]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Various fixes to fann_redis instantiation
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 12:28:45 +0000 (13:28 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 12:28:45 +0000 (13:28 +0100)
src/plugins/lua/fann_redis.lua

index ac6f78772dfd5f5044b6beba87beab80bda48472..09c20ebf9c214a28a892951a11528d2c9833438d 100644 (file)
@@ -43,12 +43,12 @@ local default_options = {
     max_trains = 1000,
     max_epoch = 1000,
     max_usages = 10,
-    use_settings = false,
-    per_user = false,
-    watch_interval = 60.0,
     mse = 0.001,
     autotrain = true,
   },
+  use_settings = false,
+  per_user = false,
+  watch_interval = 60.0,
   nlayers = 4,
   lock_expire = 600,
   learning_spawned = false,
@@ -58,8 +58,7 @@ local default_options = {
 }
 
 local settings = {
-  rules = {
-  }
+  rules = {}
 }
 
 -- ANNs indexed by settings id
@@ -96,9 +95,17 @@ local redis_lua_script_can_train = [[
   if ret then nham = tonumber(ret) end
 
   if KEYS[3] == 'spam' then
-    if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end
+    if nham <= lim and nham + 1 >= nspam then
+      return tostring(nspam + 1)
+    else
+      return tostring(-(nham + 1))
+    end
   else
-    if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end
+    if nspam <= lim and nspam + 1 >= nham then
+      return tostring(nham + 1)
+    else
+      return tostring(-(nspam + 1))
+    end
   end
 
   return tostring(0)
@@ -312,8 +319,7 @@ local function gen_fann_prefix(rule, id)
     tprefix = 't';
   end
   if id then
-    return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id),
-      rule.prefix .. id
+    return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
   else
     return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
   end
@@ -433,7 +439,7 @@ local function create_train_fann(rule, n, id)
     if not is_fann_valid(rule, prefix, fanns[id].fann) then
       fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
-    elseif fanns[id].version % rule.max_usages == 0 then
+    elseif fanns[id].version % rule.train.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
         fanns[id].version)
@@ -507,7 +513,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
 
   local learn_spam, learn_ham
 
-  if rule.autotrain then
+  if train_opts.autotrain then
     if train_opts['spam_score'] then
       learn_spam = score >= train_opts['spam_score']
     else
@@ -531,13 +537,18 @@ local function fann_train_callback(rule, task, score, required_score, id)
     end
   end
 
+
   if learn_spam or learn_ham then
     local k
+    local vec_len = 0
     if learn_spam then k = 'spam' else k = 'ham' end
 
     local function learn_vec_cb(err)
       if err then
         rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
+      else
+        rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
+          rule['name'], k, vec_len)
       end
     end
 
@@ -548,6 +559,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
         -- Add filtered meta tokens
         fun.each(function(e) table.insert(fann_data, e) end, mt)
         local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
+        vec_len = #str
 
         rspamd_redis.redis_make_request(task,
           rule.redis,
@@ -559,10 +571,13 @@ local function fann_train_callback(rule, task, score, required_score, id)
         )
       else
         if err then
-          rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err)
+          rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
           if string.match(err, 'NOSCRIPT') then
             load_scripts(rspamd_config, task:get_ev_base(), nil)
           end
+        elseif tonumber(data) < 0 then
+          rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s",
+            fname, k, -tonumber(data))
         end
       end
     end
@@ -574,7 +589,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
       can_train_cb, --callback
       'EVALSHA', -- command
       {redis_can_train_sha, '4', gen_fann_prefix(rule, nil),
-        suffix, k, tostring(rule.max_trains)} -- arguments
+        suffix, k, tostring(train_opts.max_trains)} -- arguments
     )
   end
 end
@@ -722,7 +737,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
         create_train_fann(rule, n, elt)
       end
 
-      if #spam_elts + #ham_elts < rule.max_trains / 2 then
+      if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
         -- Invalidate ANN as it is definitely invalid
         local function redis_invalidate_cb(_err, _data)
           if _err then
@@ -912,10 +927,10 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
             rspamd_logger.errx(rspamd_config,
               'cannot get FANN trains %s from redis: %s', prefix, _err)
           elseif _data and type(_data) == 'number' or type(_data) == 'string' then
-            if tonumber(_data) and tonumber(_data) >= rule.max_trains then
+            if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then
               rspamd_logger.infox(rspamd_config,
                 'need to learn ANN %s after %s learn vectors (%s required)',
-                prefix, tonumber(_data), rule.max_trains)
+                prefix, tonumber(_data), rule.train.max_trains)
               train_fann(rule, cfg, ev_base, elt, worker)
             end
           end
@@ -1011,8 +1026,7 @@ end
 
 local function ann_push_vector(task)
   local scores = task:get_metric_score()
-
-  for _,rule in ipairs(settings.rules) do
+  for _,rule in pairs(settings.rules) do
     local sid = "0"
     if rule.use_settings then
       sid = tostring(task:get_settings_id())
@@ -1053,32 +1067,64 @@ else
     callback = fann_scores_filter
   })
 
+  local function deepcopy(orig)
+    local orig_type = type(orig)
+    local copy
+    if orig_type == 'table' then
+      copy = {}
+      for orig_key, orig_value in next, orig, nil do
+        copy[deepcopy(orig_key)] = deepcopy(orig_value)
+      end
+      setmetatable(copy, deepcopy(getmetatable(orig)))
+    else -- number, string, boolean, etc
+      copy = orig
+    end
+    return copy
+  end
+  local function override_defaults(def, override)
+    for k,v in pairs(def) do
+      if override[k] then
+        if def[k] then
+          if type(override[k]) == 'table' then
+            override_defaults(def[k], override[k])
+          else
+            def[k] = override[k]
+          end
+        else
+          def[k] = override[k]
+        end
+      end
+    end
+  end
   for k,r in pairs(rules) do
-    rules[k] = default_options
-    rules[k]['redis'] = redis_params
-    local cur = rules[k]
+    local def_rules = deepcopy(default_options)
+    def_rules['redis'] = redis_params
     -- Override defaults
-    for sk,v in pairs(r) do
-      cur[sk] = v
+    override_defaults(def_rules, r)
+
+    if not def_rules.prefix then
+      def_rules.prefix = k
     end
-    if not cur.prefix then
-      cur.prefix = k
+    if not def_rules.name then
+      def_rules.name = k
     end
+    rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
+    settings.rules[k] = def_rules
     rspamd_config:set_metric_symbol({
-      name = cur.symbol_spam,
+      name = def_rules.symbol_spam,
       score = 3.0,
       description = 'Neural network SPAM',
       group = 'fann'
     })
 
     rspamd_config:set_metric_symbol({
-      name = cur.symbol_ham,
+      name = def_rules.symbol_ham,
       score = -2.0,
       description = 'Neural network HAM',
       group = 'fann'
     })
     rspamd_config:register_symbol({
-      name = cur.symbol_ham,
+      name = def_rules.symbol_ham,
       type = 'virtual,nostat',
       parent = id
     })
@@ -1086,13 +1132,11 @@ else
 
   rspamd_config:register_symbol({
     name = 'FANN_VECTOR_PUSH',
-    type = 'postfilter,nostat',
+    type = 'idempotent,nostat',
     priority = 5,
     callback = ann_push_vector
   })
 
-  settings.rules = rules
-
   -- Add training scripts
   for _,rule in pairs(settings.rules) do
     rspamd_config:add_on_load(function(cfg, ev_base, worker)