]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Neural: Add sampling when storing training vectors
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 18 Oct 2019 16:08:44 +0000 (17:08 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 18 Oct 2019 16:08:44 +0000 (17:08 +0100)
src/plugins/lua/neural.lua

index 0b93cd4a74a17e4d84cf8bff8896b413672f9b83..7acb0eca3ed404567f84d06e4b5192d8e654cb38 100644 (file)
@@ -97,7 +97,7 @@ end
 -- key1 - ann key
 -- key2 - spam or ham
 -- key3 - maximum trains
--- returns 1 or 0: 1 - allow learn, 0 - not allow learn
+-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_store_train_vec = [[
   local prefix = KEYS[1]
   local locked = redis.call('HGET', prefix, 'lock')
@@ -114,19 +114,33 @@ local redis_lua_script_can_store_train_vec = [[
 
   if KEYS[2] == 'spam' then
     if nspam <= lim then
-      return tostring(nspam)
-    else
-      return tostring(-(nspam))
+      if nspam > nham then
+        -- Apply sampling
+        local skip_rate = 1.0 - nham / (nspam + 1)
+        if math.random() < skip_rate then
+          return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+        end
+      end
+      return {tostring(nspam),'can learn'}
+    else -- Enough learns
+      return {tostring(-(nspam)),'too many spam samples'}
     end
   else
     if nham <= lim then
-      return tostring(nham)
+      if nsham > nspam then
+        -- Apply sampling
+        local skip_rate = 1.0 - nspam / (nham + 1)
+        if math.random() < skip_rate then
+          return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+        end
+      end
+      return {tostring(nham),'can learn'}
     else
-      return tostring(-(nham))
+      return {tostring(-(nham)),'too many ham samples'}
     end
   end
 
-  return tostring(0)
+  return {tostring(0),'bad input'}
 ]]
 local redis_can_store_train_vec_id = nil
 
@@ -416,45 +430,50 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
 
     local function can_train_cb(err, data)
-      if not err and tonumber(data) >= 0 then
-        local coin = math.random()
-        if coin < 1.0 - train_opts.train_prob then
-          rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
-          return
-        end
-        local vec = result_to_vector(task, set)
+      if not err and type(data) == 'table' then
+        local nsamples,reason = tonumber(data[1]),data[2]
 
-        local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
-        local target_key = set.ann.redis_key .. '_' .. learn_type
+        if nsamples > 0 then
+          local coin = math.random()
 
-        local function learn_vec_cb(_err)
-          if _err then
-            rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
-                rule.prefix, set.name, _err)
-          else
-            lua_util.debugm(N, task,
-                "add train data for ANN rule " ..
-                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
-                rule.prefix, set.name, learn_type, #vec, target_key, #str)
+          if coin < 1.0 - train_opts.train_prob then
+            rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+            return
           end
-        end
 
-        lua_redis.redis_make_request(task,
-            rule.redis,
-            nil,
-            true, -- is write
-            learn_vec_cb, --callback
-            'LPUSH', -- command
-            { target_key, str } -- arguments
-        )
-      else
-        if err then
-          rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
-              rule.prefix, set.name, err)
-        elseif tonumber(data) < 0 then
-          rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
-              rule.prefix, set.name, learn_type, -tonumber(data))
+          local vec = result_to_vector(task, set)
+
+          local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+          local target_key = set.ann.redis_key .. '_' .. learn_type
+
+          local function learn_vec_cb(_err)
+            if _err then
+              rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+                  rule.prefix, set.name, _err)
+            else
+              lua_util.debugm(N, task,
+                  "add train data for ANN rule " ..
+                      "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+                  rule.prefix, set.name, learn_type, #vec, target_key, #str)
+            end
+          end
+
+          lua_redis.redis_make_request(task,
+              rule.redis,
+              nil,
+              true, -- is write
+              learn_vec_cb, --callback
+              'LPUSH', -- command
+              { target_key, str } -- arguments
+          )
+        else
+          -- Negative result returned
+          rspamd_logger.infox(task, "cannot learn ANN %s:%s: %s (%s vectors stored)",
+              rule.prefix, set.name, learn_type, reason, -tonumber(nsamples))
         end
+      else
+        rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+            rule.prefix, set.name, err)
       end
     end
 
@@ -466,7 +485,11 @@ local function ann_push_task_result(rule, task, verdict, score, set)
     lua_redis.exec_redis_script(redis_can_store_train_vec_id,
         {task = task, is_write = true},
         can_train_cb,
-        { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
+        {
+          set.ann.redis_key,
+          learn_type,
+          tostring(train_opts.max_trains),
+        })
   else
     lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
         skip_reason)
@@ -1128,6 +1151,13 @@ local function ann_push_vector(task)
     return
   end
 
+  if score ~= score then
+    lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
+        verdict)
+
+    return
+  end
+
   for _,rule in pairs(settings.rules) do
     local set = get_rule_settings(task, rule)