]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Invalidate ANN if training data is incorrect
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 22 Nov 2016 13:13:36 +0000 (13:13 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 22 Nov 2016 13:13:36 +0000 (13:13 +0000)
src/plugins/lua/fann_redis.lua

index aa4efd4c634ef31f280788032dfe3ea2a8e6ad79..e2ac770dca98d0114e8d69397fbe1b4eb9617b9c 100644 (file)
@@ -103,6 +103,19 @@ local redis_lua_script_maybe_invalidate = [[
 ]]
 local redis_maybe_invalidate_sha = nil
 
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_locked_invalidate = [[
+  redis.call('SET', KEYS[1] .. '_version', '0')
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('DEL', KEYS[1] .. '_data')
+  redis.call('DEL', KEYS[1] .. '_locked')
+  return 1
+]]
+local redis_locked_invalidate_sha = nil
+
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
 -- key1 - prefix for keys
@@ -511,10 +524,32 @@ local function train_fann(_, ev_base, elt)
         create_train_fann(n, elt)
       end
 
-      learning_spawned = true
-      rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt)
-      fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
-        {max_epochs = max_epoch, desired_mse = mse})
+      if #inputs < max_trains / 2 then
+        -- Invalidate ANN as it is definitely invalid
+        local function redis_invalidate_cb(_err, _data)
+          if _err then
+            rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
+          elseif type(_data) == 'string' then
+            rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+            fanns[id].version = 0
+          end
+        end
+        -- Invalidate ANN
+        rspamd_logger.infox('invalidate ANN %s: training data is invalid')
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          true, -- is write
+          redis_invalidate_cb, --callback
+          'EVALSHA', -- command
+          {redis_locked_invalidate_sha, 1, gen_fann_prefix(id)}
+        )
+      else
+        learning_spawned = true
+        rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt)
+        fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
+          {max_epochs = max_epoch, desired_mse = mse})
+      end
     end
   end
 
@@ -686,6 +721,22 @@ local function load_scripts(cfg, ev_base, on_load_cb)
     {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
   )
 
+  local function locked_invalidate_sha_cb(err, data)
+    if err or not data or type(data) ~= 'string' then
+      rspamd_logger.errx(cfg, 'cannot save redis locked invalidate script: %s', err)
+    else
+      redis_locked_invalidate_sha = tostring(data)
+    end
+  end
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    true, -- is write
+    locked_invalidate_sha_cb, --callback
+    'SCRIPT', -- command
+    {'LOAD', redis_lua_script_locked_invalidate} -- arguments
+  )
+
   local function maybe_lock_sha_cb(err, data)
     if err or not data or type(data) ~= 'string' then
       rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err)