]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] More fixes to fann_redis
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 15:48:36 +0000 (15:48 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 15:48:47 +0000 (15:48 +0000)
src/plugins/lua/fann_redis.lua

index 08216fe5ebecbe18615f45cab736afc7720f5a29..ab4d74afade2e5d28dca11b8a8cf1418c45b643a 100644 (file)
@@ -404,6 +404,7 @@ local function train_fann(_, ev_base, elt)
   end
 
   local function ann_trained(errcode, errmsg, train_mse)
+    learning_spawned = false
     if errcode ~= 0 then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
         fann_prefix .. elt, errmsg)
@@ -418,7 +419,6 @@ local function train_fann(_, ev_base, elt)
     else
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
         fann_prefix .. elt, train_mse)
-      learning_spawned = false
       local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
@@ -457,10 +457,10 @@ local function train_fann(_, ev_base, elt)
       local inputs = {}
       local outputs = {}
 
-      fun.each(function(sample)
-        table.insert(inputs, fun.totable(sample[1]))
+      fun.each(function(spam_sample, ham_sample)
+        table.insert(inputs, fun.totable(spam_sample))
         table.insert(outputs, {1.0})
-        table.insert(inputs, fun.totable(sample[2]))
+        table.insert(inputs, fun.totable(ham_sample))
         table.insert(outputs, {-1.0})
       end, fun.zip(spam_elts, ham_elts))
 
@@ -472,6 +472,7 @@ local function train_fann(_, ev_base, elt)
       end
 
       learning_spawned = true
+      rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt)
       fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
         {max_epochs = max_epoch, desired_mse = mse})
     end
@@ -520,9 +521,15 @@ local function train_fann(_, ev_base, elt)
         'LRANGE', -- command
         {fann_prefix .. elt .. '_spam', '0', '-1'}
       )
+      rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', elt)
+    else
+      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', elt)
     end
   end
-  if learning_spawned then return end
+  if learning_spawned then
+    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN')
+    return
+  end
   redis_make_request(ev_base,
     rspamd_config,
     nil,
@@ -547,7 +554,7 @@ local function maybe_train_fanns(cfg, ev_base)
           elseif _data and type(_data) == 'number' or type(_data) == 'string' then
             if tonumber(_data) and tonumber(_data) > max_trains then
               rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
-                tonumber(_data), max_trains)
+                elt, tonumber(_data), max_trains)
               train_fann(cfg, ev_base, elt)
             end
           end