]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Rework train data pushing part
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 12:39:46 +0000 (13:39 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 12:39:46 +0000 (13:39 +0100)
src/plugins/lua/neural.lua

index fdb138321f97eae90f00e29ba67be4ea0df496eb..f40778a7b1c26e7ace836bf10f4d4b43091860db 100644 (file)
@@ -65,7 +65,7 @@ local default_options = {
 -- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
 -- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
 -- * version - version of ANN loaded from redis
--- * ann_key - name of ANN key in Redis
+-- * redis_key - name of ANN key in Redis
 -- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
 -- * distance - distance between set.symbols and set.ann.symbols
 -- * ann - kann object
@@ -85,31 +85,25 @@ end
 
 -- Lua script to train a row
 -- Uses the following keys:
--- key1 - prefix for fann
--- key2 - fann suffix (settings id)
--- key3 - spam or ham
--- key4 - maximum trains
+-- key1 - ann key
+-- key2 - spam or ham
+-- key3 - maximum trains
 -- returns 1 or 0: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_train = [[
-  local prefix = KEYS[1] .. KEYS[2]
-  local locked = redis.call('GET', prefix .. '_locked')
+  local prefix = KEYS[1]
+  local locked = redis.call('HGET', prefix, 'lock')
   if locked then return 0 end
   local nspam = 0
   local nham = 0
-  local lim = tonumber(KEYS[4])
+  local lim = tonumber(KEYS[3])
   lim = lim + lim * 0.1
 
-  local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
-  if not exists or tonumber(exists) == 0 then
-    redis.call('SADD', KEYS[1], KEYS[2])
-  end
-
   local ret = redis.call('LLEN', prefix .. '_spam')
   if ret then nspam = tonumber(ret) end
   ret = redis.call('LLEN', prefix .. '_ham')
   if ret then nham = tonumber(ret) end
 
-  if KEYS[3] == 'spam' then
+  if KEYS[2] == 'spam' then
     if nham <= lim and nham + 1 >= nspam then
       return tostring(nspam + 1)
     else
@@ -155,6 +149,9 @@ local redis_lua_script_maybe_invalidate = [[
     for _,k in ipairs(to_delete) do
       local tb = cjson.decode(k)
       redis.call('DEL', tb.ann_key)
+      -- Also train vectors
+      redis.call('DEL', tb.ann_key .. '_spam')
+      redis.call('DEL', tb.ann_key .. '_ham')
     end
     redis.call('ZREMRANGEBYRANK', KEYS[1], 0, (-(tonumber(KEYS[2] - 1)))
     return to_delete
@@ -328,9 +325,8 @@ local function create_ann(n, nlayers)
 end
 
 
-local function ann_train_callback(rule, task, score, required_score, id)
-  local train_opts = rule['train']
-  local fname,suffix = gen_ann_prefix(rule, id)
+local function ann_train_callback(rule, task, score, required_score, set)
+  local train_opts = rule.train
 
   local learn_spam, learn_ham
 
@@ -360,16 +356,15 @@ local function ann_train_callback(rule, task, score, required_score, id)
 
 
   if learn_spam or learn_ham then
-    local k
-    local vec_len = 0
-    if learn_spam then k = 'spam' else k = 'ham' end
+    local learn_type
+    if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
 
     local function learn_vec_cb(err)
       if err then
         rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
       else
         rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
-          rule['name'], k, vec_len)
+          rule['name'], learn_type, vec_len)
       end
     end
 
@@ -380,42 +375,36 @@ local function ann_train_callback(rule, task, score, required_score, id)
           rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
           return
         end
-        local ann_data = task:get_symbols_tokens()
-        local mt = meta_functions.rspamd_gen_metatokens(task)
-        -- Add filtered meta tokens
-        fun.each(function(e) table.insert(ann_data, e) end, mt)
-        -- Check NaNs in train data
-        if fun.all(function(e) return e == e end, ann_data) then
-          local str = rspamd_util.zstd_compress(table.concat(ann_data, ';'))
-          vec_len = #str
-
-          lua_redis.redis_make_request(task,
+        local vec = result_to_vector(task, set)
+
+        local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+
+        lua_redis.redis_make_request(task,
             rule.redis,
             nil,
             true, -- is write
             learn_vec_cb, --callback
             'LPUSH', -- command
-            {fname .. '_' .. k, str} -- arguments
-          )
-        else
-          rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values",
-            fun.length(fun.filter(function(e) return e ~= e end, ann_data)))
-        end
-
+            { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments
+        )
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
         elseif tonumber(data) < 0 then
-          rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s",
-            fname, k, -tonumber(data))
+          rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
+            rule.prefix, set.name, learn_type, -tonumber(data))
         end
       end
     end
 
+    if not set.ann then
+      -- Need to create or load a profile corresponding to the current configuration
+    end
+    -- Check if we can learn
     lua_redis.exec_redis_script(redis_can_train_id,
-      {task = task, is_write = true},
-      can_train_cb,
-      {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
+        {task = task, is_write = true},
+        can_train_cb,
+        { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
   end
 end
 
@@ -742,7 +731,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             ann = ann,
             version = profile.version,
             symbols = profile.symbols,
-            distance = min_diff
+            distance = min_diff,
+            redis_key = profile.ann_key
           }
 
           rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
@@ -909,15 +899,12 @@ local function ann_push_vector(task)
   if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
   local scores = task:get_metric_score()
   for _,rule in pairs(settings.rules) do
-    local sid = "0"
-    if rule.use_settings then
-      sid = tostring(task:get_settings_id())
-    end
-    if rule.per_user then
-      local r = task:get_principal_recipient()
-      sid = sid .. r
+    local sid = task:get_settings_id() or -1
+
+    if rule.settings[sid] then
+      ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid])
     end
-    ann_train_callback(rule, task, scores[1], scores[2], sid)
+
   end
 end