]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Multiple fixes for fann module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Feb 2017 16:00:40 +0000 (16:00 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Feb 2017 16:00:40 +0000 (16:00 +0000)
src/plugins/lua/fann_redis.lua

index a9c6a41cb74baa9f736e49ed5d2c383881a73a69..a0953b00cd932083d393c60299d6e97ac1285a63 100644 (file)
@@ -39,6 +39,7 @@ local fanns = {
 -- key1 - prefix for fann
 -- key2 - fann suffix (settings id)
 -- key3 - spam or ham
+-- key4 - maximum trains
 -- returns 1 or 0: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_train = [[
   local prefix = KEYS[1] .. KEYS[2]
@@ -46,6 +47,7 @@ local redis_lua_script_can_train = [[
   if locked then return 0 end
   local nspam = 0
   local nham = 0
+  local lim = tonumber(KEYS[4])
 
   local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
   if not exists or exists == 0 then
@@ -58,9 +60,9 @@ local redis_lua_script_can_train = [[
   if ret then nham = tonumber(ret) end
 
   if KEYS[3] == 'spam' then
-    if nham + 1 >= nspam then return tostring(nspam + 1) end
+    if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end
   else
-    if nspam + 1 >= nham then return tostring(nham + 1) end
+    if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end
   end
 
   return tostring(0)
@@ -344,20 +346,20 @@ local function gen_fann_prefix(id)
   end
 end
 
-local function is_fann_valid(ann)
+local function is_fann_valid(prefix, ann)
   if ann then
     local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
 
     if n ~= ann:get_inputs() then
-      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', ann:get_inputs(), n)
+      rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+      ' is found in the cache', prefix, ann:get_inputs(), n)
       return false
     end
     local layers = ann:get_layers()
 
     if not layers or #layers ~= nlayers then
-      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
-        #layers)
+      rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+        prefix, #layers)
       return false
     end
 
@@ -399,6 +401,7 @@ end
 
 local function create_train_fann(n, id)
   id = tostring(id)
+  local prefix = gen_fann_prefix(id)
   if not fanns[id] then
     fanns[id] = {}
   end
@@ -406,13 +409,13 @@ local function create_train_fann(n, id)
   if fanns[id].fann then
     if n ~= fanns[id].fann:get_inputs() or
       (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
-      rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', id,
+      rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix,
         fanns[id].version)
       fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
       fanns[id].fann = nil
     elseif fanns[id].version % max_usages == 0 then
       -- Forget last fann
-      rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id,
+      rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
         fanns[id].version)
       fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
     else
@@ -426,8 +429,10 @@ end
 
 local function load_or_invalidate_fann(data, id, ev_base)
   local ver = data[2]
+  local prefix = gen_fann_prefix(id)
+
   if not ver or not tonumber(ver) then
-    rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id)
+    rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
     return
   end
 
@@ -435,38 +440,38 @@ local function load_or_invalidate_fann(data, id, ev_base)
   local ann
 
   if err or not ann_data then
-    rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err)
+    rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
     return
   else
     ann = rspamd_fann.load_data(ann_data)
   end
 
-  if is_fann_valid(ann) then
+  if is_fann_valid(prefix, ann) then
     fanns[id].fann = ann
-    rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis',
-      id, ver)
+    rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
+      prefix, ver)
     fanns[id].version = tonumber(ver)
   else
     local function redis_invalidate_cb(_err, _data)
       if _err then
-        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
+        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
         if string.match(_err, 'NOSCRIPT') then
           load_scripts(rspamd_config, ev_base, nil)
         end
       elseif type(_data) == 'string' then
-        rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+        rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
         fanns[id].version = 0
       end
     end
     -- Invalidate ANN
-    rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', id)
+    rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
     redis_make_request(ev_base,
       rspamd_config,
       nil,
       true, -- is write
       redis_invalidate_cb, --callback
       'EVALSHA', -- command
-      {redis_maybe_invalidate_sha, 1, gen_fann_prefix(id)}
+      {redis_maybe_invalidate_sha, 1, prefix}
     )
   end
 end
@@ -493,7 +498,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
 
     local function learn_vec_cb(err)
       if err then
-        rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+        rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
       end
     end
 
@@ -517,7 +522,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
         )
       else
         if err then
-          rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err)
+          rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err)
           if string.match(err, 'NOSCRIPT') then
             load_scripts(rspamd_config, ev_base, nil)
           end
@@ -531,7 +536,7 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
       true, -- is write
       can_train_cb, --callback
       'EVALSHA', -- command
-      {redis_can_train_sha, '3', gen_fann_prefix(nil), suffix, k} -- arguments
+      {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments
     )
   end
 end
@@ -540,25 +545,26 @@ local function train_fann(_, ev_base, elt)
   local spam_elts = {}
   local ham_elts = {}
   elt = tostring(elt)
+  local prefix = gen_fann_prefix(elt)
 
   local function redis_unlock_cb(err)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s from redis: %s',
-        gen_fann_prefix(elt), err)
+        prefix, err)
     end
   end
 
   local function redis_save_cb(err)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
-        gen_fann_prefix(elt), err)
+        prefix, err)
       redis_make_request(ev_base,
         rspamd_config,
         nil,
         false, -- is write
         redis_unlock_cb, --callback
         'DEL', -- command
-        {gen_fann_prefix(elt) .. '_locked'}
+        {prefix .. '_locked'}
       )
       if string.match(err, 'NOSCRIPT') then
         load_scripts(rspamd_config, ev_base, nil)
@@ -570,18 +576,18 @@ local function train_fann(_, ev_base, elt)
     learning_spawned = false
     if errcode ~= 0 then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
-        gen_fann_prefix(elt), errmsg)
+        prefix, errmsg)
       redis_make_request(ev_base,
         rspamd_config,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
         'DEL', -- command
-        {gen_fann_prefix(elt) .. '_locked'}
+        {prefix .. '_locked'}
       )
     else
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
-        gen_fann_prefix(elt), train_mse)
+        prefix, train_mse)
       local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
@@ -592,7 +598,7 @@ local function train_fann(_, ev_base, elt)
         true, -- is write
         redis_save_cb, --callback
         'EVALSHA', -- command
-        {redis_save_unlock_sha, '2', gen_fann_prefix(elt), ann_data}
+        {redis_save_unlock_sha, '2', prefix, ann_data}
       )
     end
   end
@@ -600,14 +606,14 @@ local function train_fann(_, ev_base, elt)
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
-        gen_fann_prefix(elt), err)
+        prefix, err)
       redis_make_request(ev_base,
         rspamd_config,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
         'DEL', -- command
-        {gen_fann_prefix(elt) .. '_locked'}
+        {prefix .. '_locked'}
       )
     else
       -- Decompress and convert to numbers each training vector
@@ -643,25 +649,25 @@ local function train_fann(_, ev_base, elt)
         -- Invalidate ANN as it is definitely invalid
         local function redis_invalidate_cb(_err, _data)
           if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', elt, _err)
+            rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
           elseif type(_data) == 'string' then
-            rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', elt, _err)
+            rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
             fanns[elt].version = 0
           end
         end
         -- Invalidate ANN
-        rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', elt)
+        rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
         redis_make_request(ev_base,
           rspamd_config,
           nil,
           true, -- is write
           redis_invalidate_cb, --callback
           'EVALSHA', -- command
-          {redis_locked_invalidate_sha, 1, gen_fann_prefix(elt)}
+          {redis_locked_invalidate_sha, 1, prefix}
         )
       else
         learning_spawned = true
-        rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt)
+        rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
         fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
           {max_epochs = max_epoch, desired_mse = mse})
       end
@@ -671,14 +677,14 @@ local function train_fann(_, ev_base, elt)
   local function redis_spam_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
-        gen_fann_prefix(elt), err)
+        prefix, err)
       redis_make_request(ev_base,
         rspamd_config,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
         'DEL', -- command
-        {gen_fann_prefix(elt) .. '_locked'}
+        {prefix .. '_locked'}
       )
     else
       -- Decompress and convert to numbers each training vector
@@ -692,7 +698,7 @@ local function train_fann(_, ev_base, elt)
         false, -- is write
         redis_ham_cb, --callback
         'LRANGE', -- command
-        {gen_fann_prefix(elt) .. '_ham', '0', '-1'}
+        {prefix .. '_ham', '0', '-1'}
       )
     end
   end
@@ -700,7 +706,7 @@ local function train_fann(_, ev_base, elt)
   local function redis_lock_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-        gen_fann_prefix(elt), err)
+        prefix, err)
       if string.match(err, 'NOSCRIPT') then
         load_scripts(rspamd_config, ev_base, nil)
       end
@@ -712,7 +718,7 @@ local function train_fann(_, ev_base, elt)
         false, -- is write
         redis_spam_cb, --callback
         'LRANGE', -- command
-        {gen_fann_prefix(elt) .. '_spam', '0', '-1'}
+        {prefix .. '_spam', '0', '-1'}
       )
 
       rspamd_config:add_periodic(ev_base, 30.0,
@@ -720,10 +726,10 @@ local function train_fann(_, ev_base, elt)
           local function redis_lock_extend_cb(_err, _)
             if _err then
               rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
-                gen_fann_prefix(elt), _err)
+                prefix, _err)
             else
               rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
-                gen_fann_prefix(elt))
+                prefix)
             end
           end
           if learning_spawned then
@@ -733,7 +739,7 @@ local function train_fann(_, ev_base, elt)
               true, -- is write
               redis_lock_extend_cb, --callback
               'INCRBY', -- command
-              {gen_fann_prefix(elt) .. '_locked', '30'}
+              {prefix .. '_locked', '30'}
             )
           else
             return false -- do not plan any more updates
@@ -742,13 +748,13 @@ local function train_fann(_, ev_base, elt)
           return true
         end
       )
-      rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', elt)
+      rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', prefix)
     else
-      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', elt)
+      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
     end
   end
   if learning_spawned then
-    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN')
+    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
     return
   end
   redis_make_request(ev_base,
@@ -757,7 +763,7 @@ local function train_fann(_, ev_base, elt)
     true, -- is write
     redis_lock_cb, --callback
     'EVALSHA', -- command
-    {redis_maybe_lock_sha, '4', gen_fann_prefix(elt), tostring(os.time()),
+    {redis_maybe_lock_sha, '4', prefix, tostring(os.time()),
       tostring(lock_expire), rspamd_util.get_hostname()}
   )
 end
@@ -769,13 +775,14 @@ local function maybe_train_fanns(cfg, ev_base)
     elseif type(data) == 'table' then
       fun.each(function(elt)
         elt = tostring(elt)
+        local prefix = gen_fann_prefix(elt)
         local redis_len_cb = function(_err, _data)
           if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, _err)
+            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err)
           elseif _data and type(_data) == 'number' or type(_data) == 'string' then
             if tonumber(_data) and tonumber(_data) > max_trains then
               rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
-                elt, tonumber(_data), max_trains)
+                prefix, tonumber(_data), max_trains)
               train_fann(cfg, ev_base, elt)
             end
           end
@@ -787,7 +794,7 @@ local function maybe_train_fanns(cfg, ev_base)
           false, -- is write
           redis_len_cb, --callback
           'LLEN', -- command
-          {gen_fann_prefix(elt) .. '_spam'}
+          {prefix .. '_spam'}
         )
       end,
       data)