]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Neural: Moar fixes
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 8 Jul 2019 13:21:05 +0000 (14:21 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 8 Jul 2019 13:21:05 +0000 (14:21 +0100)
src/plugins/lua/neural.lua

index 7b6c2fa5f81b3c1341081a5edbc9f799b0af16d2..0375d57cdc991037b947fbd5b42e7593809867fc 100644 (file)
@@ -104,14 +104,14 @@ local redis_lua_script_can_store_train_vec = [[
   if ret then nham = tonumber(ret) end
 
   if KEYS[2] == 'spam' then
-    if nham <= lim and nham + 1 >= nspam then
-      return tostring(nspam + 1)
+    if nspam <= lim then
+      return tostring(nspam)
     else
       return tostring(-(nspam))
     end
   else
-    if nspam <= lim and nspam + 1 >= nham then
-      return tostring(nham + 1)
+    if nham <= lim then
+      return tostring(nham)
     else
       return tostring(-(nham))
     end
@@ -127,8 +127,9 @@ local redis_can_store_train_vec_id = nil
 -- key2 - number of elements to leave
 local redis_lua_script_maybe_invalidate = [[
   local card = redis.call('ZCARD', KEYS[1])
-  if card > tonumber(KEYS[2]) then
-    local to_delete = redis.call('ZRANGE', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
+  local lim = tonumber(KEYS[2])
+  if card > lim then
+    local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
     for _,k in ipairs(to_delete) do
       local tb = cjson.decode(k)
       redis.call('DEL', tb.redis_key)
@@ -136,7 +137,7 @@ local redis_lua_script_maybe_invalidate = [[
       redis.call('DEL', tb.redis_key .. '_spam')
       redis.call('DEL', tb.redis_key .. '_ham')
     end
-    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1))))
+    redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
     return to_delete
   else
     return {}
@@ -152,17 +153,17 @@ local redis_maybe_invalidate_id = nil
 -- key4 - hostname
 local redis_lua_script_maybe_lock = [[
   local locked = redis.call('HGET', KEYS[1], 'lock')
+  local now = tonumber(KEYS[2])
   if locked then
     locked = tonumber(locked)
-    now = tonumber(KEYS[2])
-    expire = tonumber(KEYS[3])
+    local expire = tonumber(KEYS[3])
     if now > locked and (now - locked) < expire then
       return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
     end
   end
   redis.call('HSET', KEYS[1], 'lock', tostring(now))
   redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
-  return true
+  return 1
 ]]
 local redis_maybe_lock_id = nil
 
@@ -178,6 +179,8 @@ local redis_lua_script_save_unlock = [[
   local now = tonumber(KEYS[6])
   redis.call('ZADD', KEYS[2], now, KEYS[4])
   redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+  redis.call('DEL', KEYS[1] .. '_spam')
+  edis.call('DEL', KEYS[1] .. '_ham')
   redis.call('HDEL', KEYS[1], 'lock')
   redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
   return 1
@@ -267,7 +270,7 @@ local function new_ann_profile(task, rule, set, version)
   }
 
   local ucl = require "ucl"
-  local profile_serialized = ucl.to_format(profile, 'json-compact')
+  local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
   local function add_cb(err, _)
     if err then
@@ -322,7 +325,8 @@ local function ann_scores_filter(task)
       score = out[1]
 
       local symscore = string.format('%.3f', score)
-      rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
+      lua_util.debugm(N, task, '%s:%s ann score: %s',
+          rule.prefix, set.name, symscore)
 
       if score > 0 then
         local result = score
@@ -348,26 +352,44 @@ end
 
 local function ann_push_task_result(rule, task, verdict, score, set)
   local train_opts = rule.train
-
-
   local learn_spam, learn_ham
+  local skip_reason = 'unknown'
 
   if train_opts.autotrain then
-    if verdict == 'passthrough' or verdict == 'uncertain' then
+    if verdict == 'passthrough' then
       lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
           verdict, score)
     end
 
-    if train_opts['spam_score'] then
-      learn_spam = score >= train_opts['spam_score']
+    if train_opts.spam_score then
+      learn_spam = score >= train_opts.spam_score
+
+      if not learn_spam then
+        skip_reason = string.format('score < spam_score: %f < %f',
+            score, train_opts.spam_score)
+      end
     else
       learn_spam = verdict == 'spam' or verdict == 'junk'
+
+      if not learn_spam then
+        skip_reason = string.format('verdict: %s',
+            verdict)
+      end
     end
 
-    if train_opts['ham_score'] then
-      learn_ham = score <= train_opts['ham_score']
+    if train_opts.ham_score then
+      learn_ham = score <= train_opts.ham_score
+      if not learn_ham then
+        skip_reason = string.format('score > ham_score: %f < %f',
+            score, train_opts.ham_score)
+      end
     else
       learn_ham = verdict == 'ham'
+
+      if not learn_ham then
+        skip_reason = string.format('verdict: %s',
+            verdict)
+      end
     end
   else
     -- Train by request header
@@ -378,6 +400,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         learn_spam = true
       elseif hdr:lower() == 'ham' then
         learn_ham = true
+      else
+        skip_reason = string.format('no explicit header')
       end
     end
   end
@@ -387,18 +411,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     local learn_type
     if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
 
-    local function learn_vec_cb(err)
-      if err then
-        rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
-            rule.prefix, set.name, err)
-      else
-        rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
-          rule.prefix, set.name, learn_type)
-      end
-    end
-
     local function can_train_cb(err, data)
-      if not err and tonumber(data) > 0 then
+      if not err and tonumber(data) >= 0 then
         local coin = math.random()
         if coin < 1.0 - train_opts.train_prob then
           rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
@@ -408,6 +422,17 @@ local function ann_push_task_result(rule, task, verdict, score, set)
 
         local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
 
+        local function learn_vec_cb(_err)
+          if _err then
+            rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+                rule.prefix, set.name, _err)
+          else
+            lua_util.debugm(N, task,
+                "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed",
+                rule.prefix, set.name, learn_type, #vec, #str)
+          end
+        end
+
         lua_redis.redis_make_request(task,
             rule.redis,
             nil,
@@ -422,7 +447,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
               rule.prefix, set.name, err)
         elseif tonumber(data) < 0 then
           rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
-            rule.prefix, set.name, learn_type, -tonumber(data))
+              rule.prefix, set.name, learn_type, -tonumber(data))
         end
       end
     end
@@ -436,6 +461,9 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         {task = task, is_write = true},
         can_train_cb,
         { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
+  else
+    lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
+        skip_reason)
   end
 end
 
@@ -481,6 +509,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
               {ann_key, 'lock', '30'}
           )
         else
+          lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
           return false -- do not plan any more updates
         end
 
@@ -537,7 +566,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       return out
     end
 
-    rule.learning_spawned = true
+    set.learning_spawned = true
 
     local function redis_save_cb(err)
       if err then
@@ -559,7 +588,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     end
 
     local function ann_trained(err, data)
-      rule.learning_spawned = false
+      set.learning_spawned = false
       if err then
         rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
             rule.prefix, set.name, err)
@@ -598,7 +627,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         }
 
         local ucl = require "ucl"
-        local profile_serialized = ucl.to_format(profile, 'json-compact')
+        local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
         lua_redis.exec_redis_script(redis_save_unlock_id,
             {ev_base = ev_base, is_write = true},
@@ -695,7 +724,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
         ann_key, err)
-    elseif type(data) == 'boolean' and data then
+    elseif type(data) == 'number' and data == 1 then
       -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
@@ -752,47 +781,52 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
           ann_key, err)
     else
-      local _err,ann_data = rspamd_util.zstd_decompress(data[1])
-      local ann
-
-      if _err or not ann_data then
-        rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
-            rule.prefix .. ':' .. set.name, ann_key, _err)
-        return
-      else
-        ann = rspamd_kann.load(ann_data)
-
-        if ann then
-          set.ann = {
-            ann = ann,
-            version = profile.version,
-            symbols = profile.symbols,
-            distance = min_diff,
-            redis_key = profile.redis_key
-          }
+      if type(data) == 'string' then
+        local _err,ann_data = rspamd_util.zstd_decompress(data)
+        local ann
 
-          local ucl = require "ucl"
-          local profile_serialized = ucl.to_format(profile, 'json-compact')
-
-          local function rank_cb(_, _)
-            -- TODO: maybe add some logging
-          end
-          -- Also update rank for the loaded ANN to avoid removal
-          lua_redis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              rule.redis,
-              nil,
-              true, -- is write
-              rank_cb, --callback
-              'ZADD', -- command
-              {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
-          )
-          rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
-              rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
+        if _err or not ann_data then
+          rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+              rule.prefix .. ':' .. set.name, ann_key, _err)
+          return
         else
-          rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
-              rule.prefix .. ':' .. set.name, ann_key)
+          ann = rspamd_kann.load(ann_data)
+
+          if ann then
+            set.ann = {
+              ann = ann,
+              version = profile.version,
+              symbols = profile.symbols,
+              distance = min_diff,
+              redis_key = profile.redis_key
+            }
+
+            local ucl = require "ucl"
+            local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+            local function rank_cb(_, _)
+              -- TODO: maybe add some logging
+            end
+            -- Also update rank for the loaded ANN to avoid removal
+            lua_redis.redis_make_request_taskless(ev_base,
+                rspamd_config,
+                rule.redis,
+                nil,
+                true, -- is write
+                rank_cb, --callback
+                'ZADD', -- command
+                {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
+            )
+            rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                rule.prefix, set.name, ann_key, #ann_data, profile.version)
+          else
+            rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s',
+                rule.prefix, set.name, ann_key)
+          end
         end
+      else
+        lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s',
+            rule.prefix, set.name, ann_key)
       end
     end
   end
@@ -803,8 +837,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       false, -- is write
       data_cb, --callback
       'HGET', -- command
-      {ann_key, 'ann'}, -- arguments
-      {opaque_data = true}
+      {ann_key, 'ann'} -- arguments
   )
 end
 
@@ -900,23 +933,46 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
     -- We have our ANN and that's train vectors, check if we can learn
     local ann_key = sel_elt.redis_key
 
-    lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
-    local redis_len_cb = function(err, data)
-      if err then
-        rspamd_logger.errx(rspamd_config,
-            'cannot get FANN trains %s from redis: %s', ann_key, err)
-      elseif data and type(data) == 'number' or type(data) == 'string' then
-        if tonumber(data) and tonumber(data) >= rule.train.max_trains then
-          rspamd_logger.infox(rspamd_config,
-              'need to learn ANN %s after %s learn vectors (%s required)',
-              ann_key, tonumber(data), rule.train.max_trains)
-          do_train_ann(worker, ev_base, rule, set, ann_key)
-        else
-          rspamd_logger.debugm(N, rspamd_config,
-              'no need to learn ANN %s %s learn vectors (%s required)',
-              ann_key, tonumber(data), rule.train.max_trains)
+    lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
+        ann_key)
+
+    -- Create continuation closure
+    local redis_len_cb_gen = function(cont_cb)
+      return function(err, data)
+        if err then
+          rspamd_logger.errx(rspamd_config,
+              'cannot get ANN trains %s from redis: %s', ann_key, err)
+        elseif data and type(data) == 'number' or type(data) == 'string' then
+          if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+            cont_cb()
+          else
+            rspamd_logger.debugm(N, rspamd_config,
+                'no need to learn ANN %s %s learn vectors (%s required)',
+                ann_key, tonumber(data), rule.train.max_trains)
+          end
         end
       end
+
+    end
+
+    local function initiate_train()
+      rspamd_logger.infox(rspamd_config,
+          'need to learn ANN %s after %s learn vectors (%s required)',
+          ann_key, tonumber(data), rule.train.max_trains)
+      do_train_ann(worker, ev_base, rule, set, ann_key)
+    end
+
+    -- Spam vector is OK, check ham vector length
+    local function check_ham_len()
+      lua_redis.redis_make_request_taskless(ev_base,
+          rspamd_config,
+          rule.redis,
+          nil,
+          false, -- is write
+          redis_len_cb_gen(initiate_train), --callback
+          'LLEN', -- command
+          {ann_key .. '_ham'}
+      )
     end
 
     lua_redis.redis_make_request_taskless(ev_base,
@@ -924,7 +980,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
         rule.redis,
         nil,
         false, -- is write
-        redis_len_cb, --callback
+        redis_len_cb_gen(check_ham_len), --callback
         'LLEN', -- command
         {ann_key .. '_spam'}
     )
@@ -1005,14 +1061,22 @@ local function cleanup_anns(rule, cfg, ev_base)
 end
 
 local function ann_push_vector(task)
-  if task:has_flag('skip') then return end
-  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
+  if task:has_flag('skip') then
+    lua_util.debugm(N, task, 'do not push data for skipped task')
+    return
+  end
+  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
+    lua_util.debugm(N, task, 'do not push data for manual scan')
+    return
+  end
   local verdict,score = lua_util.get_task_verdict(task)
   for _,rule in pairs(settings.rules) do
     local set = get_rule_settings(task, rule)
 
     if set then
       ann_push_task_result(rule, task, verdict, score, set)
+    else
+      lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
     end
 
   end
@@ -1064,7 +1128,7 @@ local function process_rules_settings()
 
     if rule.default then
       local default_settings = {
-        symbols = lua_util.keys(lua_settings.default_symbols()),
+        symbols = lua_settings.default_symbols(),
         name = 'default'
       }
 
@@ -1099,7 +1163,7 @@ local function process_rules_settings()
 
       if nelt then
         rule.settings[s] = nelt
-        lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols',
+        lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s',
             nelt.name, rule.prefix)
       end
     end