]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] More issues in fann_redis
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 14:59:55 +0000 (14:59 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 14:59:55 +0000 (14:59 +0000)
src/plugins/lua/fann_redis.lua

index 361d82303efc4f72bad6dde8c44176e4c73d8f30..ad8a0f79af7cd8620205c96a0ebc9f2ae1683006 100644 (file)
@@ -105,9 +105,16 @@ local redis_maybe_invalidate_sha = nil
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
 -- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
 local redis_lua_script_maybe_lock = [[
   local locked = redis.call('GET', KEYS[1] .. '_locked')
-  if locked then return false end
+  if locked then
+    if tonumber(KEYS[2]) < tonumber(locked) then
+      return false
+    end
+  end
+  redis.call('SET', KEYS[1] .. '_locked', tostring(tonumber(KEYS[2]) + tonumber(KEYS[3])))
   return 1
 ]]
 local redis_maybe_lock_sha = nil
@@ -136,6 +143,8 @@ local use_settings = false
 local watch_interval = 60.0
 local mse = 0.0001
 local nlayers = 4
+local lock_expire = 600
+local learning_spawned = false
 
 local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
   if not ev_base or not redis_params or not callback or not command then
@@ -285,13 +294,14 @@ local function load_or_invalidate_fann(data, id, ev_base)
   local ann
 
   if err or not ann_data then
-    rspamd_logger.errx('cannot decompress ann: %s', err)
+    rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
   else
     ann = rspamd_fann.load_data(ann_data)
   end
 
   if is_fann_valid(ann) then
     fanns[id].fann = ann
+    rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
   else
     local function redis_invalidate_cb(err, data)
       if err then
@@ -387,7 +397,7 @@ local function train_fann(cfg, ev_base, elt)
     end
   end
 
-  local function redis_save_unlock_sha(err, data)
+  local function redis_save_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
         fann_prefix .. elt, err)
@@ -409,7 +419,8 @@ local function train_fann(cfg, ev_base, elt)
     else
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
         fann_prefix .. elt, train_mse)
-      local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
+      learning_spawned = false
+      local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
       fanns[elt].fann_train = nil
@@ -448,13 +459,11 @@ local function train_fann(cfg, ev_base, elt)
       local outputs = {}
 
       each(function(sample)
-        table.insert(inputs, totable(sample))
+        table.insert(inputs, totable(sample[1]))
         table.insert(outputs, {1.0})
-      end, spam_elts)
-      each(function(sample)
-        table.insert(inputs, totable(sample))
+        table.insert(inputs, totable(sample[2]))
         table.insert(outputs, {-1.0})
-      end, ham_elts)
+      end, zip(spam_elts, ham_elts))
 
       -- Now we can train fann
       local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
@@ -463,6 +472,7 @@ local function train_fann(cfg, ev_base, elt)
         create_train_fann(n, elt)
       end
 
+      learning_spawned = true
       fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
         {max_epochs = max_epoch, desired_mse = mse})
     end
@@ -513,13 +523,15 @@ local function train_fann(cfg, ev_base, elt)
       )
     end
   end
+  if learning_spawned then return end
   redis_make_request(ev_base,
     rspamd_config,
     nil,
     true, -- is write
     redis_lock_cb, --callback
     'EVALSHA', -- command
-    {redis_maybe_lock_sha, '1', fann_prefix .. elt}
+    {redis_maybe_lock_sha, '3', fann_prefix .. elt, tostring(os.time()),
+      tostring(lock_expire)}
   )
 end
 
@@ -535,6 +547,8 @@ local function maybe_train_fanns(cfg, ev_base)
             rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
           elseif data and type(data) == 'number' or type(data) == 'string' then
             if tonumber(data) and tonumber(data) > max_trains then
+              rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
+                tonumber(data), max_trains)
               train_fann(cfg, ev_base, elt)
             end
           end