]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add training and saving ANN logic
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 19:52:40 +0000 (20:52 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 6 Jul 2019 19:52:40 +0000 (20:52 +0100)
src/plugins/lua/neural.lua

index c2ffb3e15f9dc7e9658c105508d9c971dba7b3cf..4b4cc7354594eac2628e25441d4040cc39511435 100644 (file)
@@ -83,13 +83,13 @@ if not opts then
 end
 
 
--- Lua script to train a row
+-- Lua script that checks if we can store a new training vector
 -- Uses the following keys:
 -- key1 - ann key
 -- key2 - spam or ham
 -- key3 - maximum trains
 -- returns 1 or 0: 1 - allow learn, 0 - not allow learn
-local redis_lua_script_can_train = [[
+local redis_lua_script_can_store_train_vec = [[
   local prefix = KEYS[1]
   local locked = redis.call('HGET', prefix, 'lock')
   if locked then return 0 end
@@ -119,7 +119,7 @@ local redis_lua_script_can_train = [[
 
   return tostring(0)
 ]]
-local redis_can_train_id = nil
+local redis_can_store_train_vec_id = nil
 
 -- Lua script to invalidate ANNs by rank
 -- Uses the following keys
@@ -144,20 +144,6 @@ local redis_lua_script_maybe_invalidate = [[
 ]]
 local redis_maybe_invalidate_id = nil
 
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
-local redis_lua_script_locked_invalidate = [[
-  redis.call('SET', KEYS[1] .. '_version', '0')
-  redis.call('DEL', KEYS[1] .. '_spam')
-  redis.call('DEL', KEYS[1] .. '_ham')
-  redis.call('DEL', KEYS[1] .. '_data')
-  redis.call('DEL', KEYS[1] .. '_locked')
-  redis.call('DEL', KEYS[1] .. '_hostname')
-  return 1
-]]
-local redis_locked_invalidate_id = nil
-
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
 -- key1 - prefix for keys
@@ -165,15 +151,18 @@ local redis_locked_invalidate_id = nil
 -- key3 - key expire
 -- key4 - hostname
 local redis_lua_script_maybe_lock = [[
-  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  local locked = redis.call('HGET', KEYS[1], 'lock')
   if locked then
-    if tonumber(KEYS[2]) < tonumber(locked) then
-      return false
+    locked = tonumber(locked)
+    now = tonumber(KEYS[2])
+    expire = tonumber(KEYS[3])
+    if now > locked and (now - locked) < expire then
+      return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname')}
     end
   end
-  redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3])))
-  redis.call('SET', KEYS[1] .. '_hostname', KEYS[4])
-  return 1
+  redis.call('HSET', KEYS[1], 'lock', tostring(now))
+  redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+  return true
 ]]
 local redis_maybe_lock_id = nil
 
@@ -198,12 +187,10 @@ local redis_save_unlock_id = nil
 local redis_params
 
 local function load_scripts(params)
-  redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
+  redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
     params)
   redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
     params)
-  redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
-    params)
   redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
     params)
   redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
@@ -244,14 +231,13 @@ local function new_ann_key(rule, set)
 end
 
 -- Creates and stores ANN profile in Redis
-local function new_ann_profile(task, rule, set)
+local function new_ann_profile(task, rule, set, version)
   local ann_key = new_ann_key(rule, set)
 
-
   local profile = {
     symbols = set.symbols,
     redis_key = ann_key,
-    version = 0,
+    version = version or 0,
     digest = set.digest,
     distance = 0 -- Since we are using our own profile
   }
@@ -269,7 +255,7 @@ local function new_ann_profile(task, rule, set)
     end
   end
 
-  lua_redis.redis_make_request_taskless(ev_base,
+  lua_redis.redis_make_request(task,
       rspamd_config,
       rule.redis,
       nil,
@@ -335,10 +321,10 @@ local function ann_scores_filter(task)
 
       if score > 0 then
         local result = score
-        task:insert_result(rule.symbol_spam, result, symscore, id)
+        task:insert_result(rule.symbol_spam, result, symscore)
       else
         local result = -(score)
-        task:insert_result(rule.symbol_ham, result, symscore, id)
+        task:insert_result(rule.symbol_ham, result, symscore)
       end
     end
   end
@@ -391,10 +377,11 @@ local function ann_train_callback(rule, task, score, required_score, set)
 
     local function learn_vec_cb(err)
       if err then
-        rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
+        rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+            rule.prefix, set.name, err)
       else
-        rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
-          rule['name'], learn_type, vec_len)
+        rspamd_logger.infox(task, "trained ANN rule %s:%s, save %s vector",
+          rule.prefix, set.name, learn_type)
       end
     end
 
@@ -433,181 +420,236 @@ local function ann_train_callback(rule, task, score, required_score, set)
       set.ann = new_ann_profile(task, rule, set)
     end
     -- Check if we can learn
-    lua_redis.exec_redis_script(redis_can_train_id,
+    lua_redis.exec_redis_script(redis_can_store_train_vec_id,
         {task = task, is_write = true},
         can_train_cb,
         { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
   end
 end
 
-local function train_ann(rule, _, ev_base, elt, worker)
-  local spam_elts = {}
-  local ham_elts = {}
-  elt = tostring(elt)
-  local prefix = gen_ann_prefix(rule, elt)
+--- Offline training logic
 
-  local function redis_unlock_cb(err)
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+  return function (err)
     if err then
-      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
-        prefix, err)
+      rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+          rule.prefix, set.name, ann_key, err)
+    else
+      lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+          rule.prefix, set.name, ann_key)
     end
   end
+end
 
-  local function redis_save_cb(err)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
-        prefix, err)
-      lua_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        false, -- is write
-        redis_unlock_cb, --callback
-        'DEL', -- command
-        {prefix .. '_locked'}
-      )
-    else
-      rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+  rspamd_config:add_periodic(ev_base, 30.0,
+      function()
+        local function redis_lock_extend_cb(_err, _)
+          if _err then
+            rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+                ann_key, _err)
+          else
+            rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+                ann_key)
+          end
+        end
+
+        if set.learning_spawned then
+          lua_redis.redis_make_request_taskless(ev_base,
+              rspamd_config,
+              rule.redis,
+              nil,
+              true, -- is write
+              redis_lock_extend_cb, --callback
+              'HINCRBY', -- command
+              {ann_key, 'lock', '30'}
+          )
+        else
+          return false -- do not plan any more updates
+        end
+
+        return true
+      end
+  )
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
+  -- Check training data sanity
+  -- Now we need to join inputs and create the appropriate test vectors
+  local n = #set.symbols +
+      meta_functions.rspamd_count_metatokens()
+
+  -- Now we can train ann
+  local train_ann = create_ann(n, 3)
+
+  if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
+    -- Invalidate ANN as it is definitely invalid
+    -- TODO: add invalidation
+    assert(false)
+  else
+    local inputs, outputs = {}, {}
+
+    -- Make training set by joining vectors
+    -- KANN automatically shuffles those samples
+    -- 1.0 is used for spam and -1.0 is used for ham
+    -- It implies that output layer can express that (e.g. tanh output)
+    for _,e in ipairs(spam_vec) do
+      inputs[#inputs + 1] = e
+      outputs[#outputs + 1] = {1.0}
+    end
+    for _,e in ipairs(ham_vec) do
+      inputs[#inputs + 1] = e
+      outputs[#outputs + 1] = {-1.0}
     end
-  end
 
-  local function ann_trained(err, data)
-    rule.learning_spawned = false
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
-          prefix, err)
-      lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          redis_unlock_cb, --callback
-          'DEL', -- command
-          {prefix .. '_locked'}
-      )
-    else
-      rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
-          prefix, #data)
-      local ann_data = rspamd_util.zstd_compress(data)
-      rule.anns[elt].ann_train = rspamd_kann.load(data)
-      rule.anns[elt].version = rule.anns[elt].version + 1
-      rule.anns[elt].ann = rule.anns[elt].ann_train
-      rule.anns[elt].ann_train = nil
-      lua_redis.exec_redis_script(redis_save_unlock_id,
-        {ev_base = ev_base, is_write = true},
-        redis_save_cb,
-        {prefix, tostring(ann_data), tostring(rule.ann_expire)})
+    -- Called in child process
+    local function train()
+      train_ann:train1(inputs, outputs, {
+        lr = rule.train.learning_rate,
+        max_epoch = rule.train.max_iterations,
+        cb = function(iter, train_cost, _)
+          if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
+            rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s",
+                rule.prefix, set.name,
+                iter, train_cost)
+          end
+        end
+      })
+
+      local out = train_ann:save()
+      return out
     end
+
+    rule.learning_spawned = true
+
+    local function redis_save_cb(err)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+            rule.prefix, set.name, ann_key, err)
+        lua_redis.redis_make_request_taskless(ev_base,
+            rspamd_config,
+            rule.redis,
+            nil,
+            false, -- is write
+            gen_unlock_cb(rule, set, ann_key), --callback
+            'HDEL', -- command
+            {ann_key, 'lock'}
+        )
+      else
+        rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+            rule.prefix, set.name, ann_key)
+      end
+    end
+
+    local function ann_trained(err, data)
+      rule.learning_spawned = false
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+            rule.prefix, set.name, err)
+        lua_redis.redis_make_request_taskless(ev_base,
+            rspamd_config,
+            rule.redis,
+            nil,
+            true, -- is write
+            gen_unlock_cb(rule, set, ann_key), --callback
+            'HDEL', -- command
+            {ann_key, 'lock'}
+        )
+      else
+        rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes',
+            rule.prefix, set.name, #data)
+        local ann_data = rspamd_util.zstd_compress(data)
+        if not set.ann then
+          set.ann = {
+            symbols = set.symbols,
+            distance = 0,
+            digest = set.digest,
+            redis_key = ann_key,
+          }
+        end
+        -- Deserialise ANN from the child process
+        ann_trained = rspamd_kann.load(data)
+        set.ann.version = (set.ann.version or 0) + 1
+        set.ann.ann = ann_trained
+
+        lua_redis.exec_redis_script(redis_save_unlock_id,
+            {ev_base = ev_base, is_write = true},
+            redis_save_cb,
+            {ann_key, tostring(ann_data), tostring(rule.ann_expire)})
+      end
+    end
+
+    worker:spawn_process{
+      func = train,
+      on_complete = ann_trained,
+    }
   end
+  -- Spawn learn and register lock extension
+  set.learning_spawned = true
+  register_lock_extender(rule, set, ev_base, ann_key)
+end
+
+-- Utility to extract and split saved training vectors to a table of tables
+local function process_training_vectors(data)
+  return fun.totable(fun.map(function(tok)
+    local _,str = rspamd_util.zstd_decompress(tok)
+    return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';')))
+  end, data))
+end
+
+-- This function does the following:
+-- * Tries to lock ANN
+-- * Loads spam and ham vectors
+-- * Spawn learning process
+local function do_train_ann(worker, ev_base, rule, set, ann_key)
+  local spam_elts = {}
+  local ham_elts = {}
 
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-        prefix, err)
+        ann_key, err)
+      -- Unlock on error
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
         true, -- is write
-        redis_unlock_cb, --callback
-        'DEL', -- command
-        {prefix .. '_locked'}
+          gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL', -- command
+        {ann_key, 'lock'}
       )
     else
       -- Decompress and convert to numbers each training vector
-      ham_elts = fun.totable(fun.map(function(tok)
-        local _,str = rspamd_util.zstd_decompress(tok)
-        return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
-      end, data))
-
-      -- Now we need to join inputs and create the appropriate test vectors
-      local n = rspamd_config:get_symbols_count() +
-          meta_functions.rspamd_count_metatokens()
-
-      -- Now we can train ann
-      if not rule.anns[elt] or not rule.anns[elt].ann_train then
-        -- Create ann if it does not exist
-        create_train_ann(rule, n, elt)
-      end
-
-      if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
-        -- Invalidate ANN as it is definitely invalid
-        local function redis_invalidate_cb(_err, _data)
-          if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
-          elseif type(_data) == 'string' then
-            rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-            rule.anns[elt].version = 0
-          end
-        end
-        -- Invalidate ANN
-        rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
-        lua_redis.exec_redis_script(redis_locked_invalidate_id,
-          {ev_base = ev_base, is_write = true},
-          redis_invalidate_cb,
-          {prefix})
-      else
-        local inputs, outputs = {}, {}
-
-        for _,e in ipairs(spam_elts) do
-          if e == e then
-            inputs[#inputs + 1] = e
-            outputs[#outputs + 1] = {1.0}
-          end
-        end
-        for _,e in ipairs(ham_elts) do
-          if e == e then
-            inputs[#inputs + 1] = e
-            outputs[#outputs + 1] = {0.0}
-          end
-        end
-
-
-        local function train()
-          rule.anns[elt].ann_train:train1(inputs, outputs, {
-            lr = rule.train.learning_rate,
-            max_epoch = rule.train.max_iterations,
-            cb = function(iter, train_cost, _)
-              if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then
-                rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
-                    iter, train_cost)
-              end
-            end
-          })
-
-          local out = rule.anns[elt].ann_train:save()
-          return out
-        end
-
-        rule.learning_spawned = true
-
-        worker:spawn_process{
-          func = train,
-          on_complete = ann_trained,
-        }
-      end
+      ham_elts = process_training_vectors(data)
+      spawn_train(worker, ev_base, rule, set, ann_key, ham_elts, spam_elts)
     end
   end
 
+  -- Spam vectors received
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-        prefix, err)
+        ann_key, err)
+      -- Unlock ANN on error
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
         true, -- is write
-        redis_unlock_cb, --callback
-        'DEL', -- command
-        {prefix .. '_locked'}
+          gen_unlock_cb(rule, set, ann_key), --callback
+        'HDEL', -- command
+        {ann_key, 'lock'}
       )
     else
       -- Decompress and convert to numbers each training vector
-      spam_elts = fun.totable(fun.map(function(tok)
-        local _,str = rspamd_util.zstd_decompress(tok)
-        return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
-      end, data))
+      spam_elts = process_training_vectors(data)
+      -- Now get ham vectors...
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
@@ -615,17 +657,17 @@ local function train_ann(rule, _, ev_base, elt, worker)
         false, -- is write
         redis_ham_cb, --callback
         'LRANGE', -- command
-        {prefix .. '_ham', '0', '-1'}
+        {ann_key .. '_ham', '0', '-1'}
       )
     end
   end
 
   local function redis_lock_cb(err, data)
     if err then
-      rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-        prefix, err)
-    elseif type(data) == 'number' then
-      -- Can train ANN
+      rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
+        ann_key, err)
+    elseif type(data) == 'boolean' and data then
+      -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
       lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
@@ -633,105 +675,38 @@ local function train_ann(rule, _, ev_base, elt, worker)
         false, -- is write
         redis_spam_cb, --callback
         'LRANGE', -- command
-        {prefix .. '_spam', '0', '-1'}
+        {ann_key .. '_spam', '0', '-1'}
       )
 
-      rspamd_config:add_periodic(ev_base, 30.0,
-        function(_, _)
-          local function redis_lock_extend_cb(_err, _)
-            if _err then
-              rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-                prefix, _err)
-            else
-              rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
-                prefix)
-            end
-          end
-          if rule.learning_spawned then
-            lua_redis.redis_make_request_taskless(ev_base,
-              rspamd_config,
-              rule.redis,
-              nil,
-              true, -- is write
-              redis_lock_extend_cb, --callback
-              'INCRBY', -- command
-              {prefix .. '_locked', '30'}
-            )
-          else
-            return false -- do not plan any more updates
-          end
-
-          return true
-        end
-      )
-      rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix)
+      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
+        rule.prefix, set.name, ann_key)
     else
-      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
+      local lock_tm = tonumber(data[1])
+      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
+          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+          data[2], os.date('%c', lock_tm))
     end
   end
-  if rule.learning_spawned then
-    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
+
+  -- Check if we are already learning this network
+  if set.learning_spawned then
+    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
+        ann_key)
     return
   end
+
+  -- Call Redis script that tries to acquire a lock
+  -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
+  -- ANN is locked by another host (or a process, meh)
   lua_redis.exec_redis_script(redis_maybe_lock_id,
     {ev_base = ev_base, is_write = true},
     redis_lock_cb,
-    {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()})
-end
-
-local function maybe_train_anns(rule, cfg, ev_base, worker)
-  local function members_cb(err, data)
-    if err then
-      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
-    elseif type(data) == 'table' then
-      fun.each(function(elt)
-        elt = tostring(elt)
-        local prefix = gen_ann_prefix(rule, elt)
-        rspamd_logger.infox(cfg, "check ANN %s", prefix)
-        local redis_len_cb = function(_err, _data)
-          if _err then
-            rspamd_logger.errx(rspamd_config,
-              'cannot get FANN trains %s from redis: %s', prefix, _err)
-          elseif _data and type(_data) == 'number' or type(_data) == 'string' then
-            if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then
-              rspamd_logger.infox(rspamd_config,
-                'need to learn ANN %s after %s learn vectors (%s required)',
-                prefix, tonumber(_data), rule.train.max_trains)
-              train_ann(rule, cfg, ev_base, elt, worker)
-            else
-              rspamd_logger.infox(rspamd_config,
-                'no need to learn ANN %s %s learn vectors (%s required)',
-                prefix, tonumber(_data), rule.train.max_trains)
-            end
-          end
-        end
-
-        lua_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_len_cb, --callback
-          'LLEN', -- command
-          {prefix .. '_spam'}
-        )
-      end,
-      data)
-    end
-  end
-
-  -- First we need to get all anns stored in our Redis
-  lua_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    rule.redis,
-    nil,
-    false, -- is write
-    members_cb, --callback
-    'SMEMBERS', -- command
-    {gen_ann_prefix(rule, nil)} -- arguments
-  )
-
-  return rule.watch_interval
+      {
+        ann_key,
+        tostring(os.time()),
+        tostring(rule.watch_interval * 2),
+        rspamd_util.get_hostname()
+    })
 end
 
 -- This function loads new ann from Redis
@@ -808,7 +783,8 @@ end
 -- for some specific rule + some specific setting
 -- This function tries to load more fresh or more specific ANNs in lieu of
 -- the existing ones.
-local function process_existing_ann(rule, ev_base, set, profiles)
+-- Use this function to load ANNs as `callback` parameter for `check_anns` function
+local function process_existing_ann(_, ev_base, rule, set, profiles)
   local my_symbols = set.symbols
   local min_diff = math.huge
   local sel_elt
@@ -872,6 +848,60 @@ local function process_existing_ann(rule, ev_base, set, profiles)
   end
 end
 
+
+-- This function checks all profiles and selects if we can train our
+-- ANN. By our we mean that it has exactly the same symbols in profile.
+-- Use this function to train ANN as `callback` parameter for `check_anns` function
+local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
+  local my_symbols = set.symbols
+  local sel_elt
+
+  for _,elt in fun.iter(profiles) do
+    if elt and elt.symbols then
+      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+      -- Check distance
+      if dist == 0 then
+        sel_elt = elt
+        break
+      end
+    end
+  end
+
+  if sel_elt then
+    -- We have our ANN and that's train vectors, check if we can learn
+    local ann_key = sel_elt.redis_key
+
+    lua_util.debugm(N, rspamd_config, "check ANN %s", ann_key)
+    local redis_len_cb = function(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config,
+            'cannot get FANN trains %s from redis: %s', ann_key, err)
+      elseif data and type(data) == 'number' or type(data) == 'string' then
+        if tonumber(data) and tonumber(data) >= rule.train.max_trains then
+          rspamd_logger.infox(rspamd_config,
+              'need to learn ANN %s after %s learn vectors (%s required)',
+              ann_key, tonumber(data), rule.train.max_trains)
+          do_train_ann(worker, ev_base, rule, set, ann_key)
+        else
+          rspamd_logger.debugm(N, rspamd_config,
+              'no need to learn ANN %s %s learn vectors (%s required)',
+              ann_key, tonumber(data), rule.train.max_trains)
+        end
+      end
+    end
+
+    lua_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        false, -- is write
+        redis_len_cb, --callback
+        'LLEN', -- command
+        {ann_key .. '_spam'}
+    )
+  end
+end
+
 -- Used to deserialise ANN element from a list
 local function load_ann_profile(element)
   local ucl = require "ucl"
@@ -888,14 +918,14 @@ local function load_ann_profile(element)
 end
 
 -- Function to check or load ANNs from Redis
-local function check_anns(rule, cfg, ev_base)
+local function check_anns(worker, rule, cfg, ev_base, process_callback)
   for _,set in pairs(rule.settings) do
     local function members_cb(err, data)
       if err then
         rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
             err)
       elseif type(data) == 'table' then
-        process_existing_ann(rule, ev_base, set, fun.map(load_ann_profile, data))
+        process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
       end
     end
 
@@ -1119,7 +1149,7 @@ for _,rule in pairs(settings.rules) do
   rspamd_config:add_on_load(function(cfg, ev_base, worker)
     rspamd_config:add_periodic(ev_base, 0.0,
         function(_, _)
-          return check_anns(rule, cfg, ev_base)
+          return check_anns(worker, cfg, ev_base, rule, process_existing_ann)
         end)
 
     if worker:is_primary_controller() then
@@ -1128,7 +1158,7 @@ for _,rule in pairs(settings.rules) do
           function(_, _)
             -- Clean old ANNs
             cleanup_anns(rule, cfg, ev_base)
-            return maybe_train_anns(rule, cfg, ev_base, worker)
+            return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann)
           end)
     end
   end)