]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Add redis storage feature to fann_redis
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 20:48:53 +0000 (21:48 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 5 Nov 2016 20:48:53 +0000 (21:48 +0100)
src/plugins/lua/fann_redis.lua

index f55454bf643dae2ecc0ace53b34d8ffece4a7610..aabf465ce6e7ba7cb487584566d826bb38e2f99d 100644 (file)
@@ -30,9 +30,7 @@ local module_log_id = 0x100
 -- ANNs indexed by settings id
 local data = {
   ['0'] = {
-    fann_mtime = 0,
-    ntrains = 0,
-    epoch = 0,
+    version = 0,
   }
 }
 
@@ -108,6 +106,20 @@ local redis_lua_script_maybe_lock = [[
 ]]
 local redis_maybe_lock_sha = nil
 
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - compressed ANN
+local redis_lua_script_save_unlock = [[
+  redis.call('INCRBY', KEYS[1] .. '_version', '1')
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('SET', KEYS[1] .. '_data', KEYS[2])
+  redis.call('DEL', KEYS[1] .. '_locked')
+  return 1
+]]
+local redis_save_unlock_sha = nil
+
 local redis_params
 redis_params = rspamd_parse_redis_server('fann_redis')
 
@@ -116,6 +128,7 @@ local max_trains = 1000
 local max_epoch = 100
 local use_settings = false
 local watch_interval = 60.0
+local mse = 0.0001
 
 local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
   if not ev_base or not redis_params or not callback or not command then
@@ -251,8 +264,7 @@ end
 
 local function create_train_fann(n, id)
   data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
-  data[id].ntrains = 0
-  data[id].epoch = 0
+  data[id].version = 0
 end
 
 local function load_or_invalidate_fann(data, id, ev_base)
@@ -361,6 +373,41 @@ local function train_fann(cfg, ev_base, elt)
     end
   end
 
+  local function redis_save_unlock_sha(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
+        fann_prefix .. elt, err)
+    end
+  end
+
+  local function ann_trained(errcode, errmsg, train_mse)
+    if errcode ~= 0 then
+      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+        fann_prefix .. elt, errmsg)
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        false, -- is write
+        redis_unlock_cb, --callback
+        'DEL', -- command
+        {fann_prefix .. elt .. '_lock'}
+      )
+    else
+      rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
+        fann_prefix .. elt, train_mse)
+      local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
+      data[elt].version = data[elt].version + 1
+      redis_make_request(ev_base,
+        rspamd_config,
+        nil,
+        true, -- is write
+        redis_save_cb, --callback
+        'EVALSHA', -- command
+        {redis_save_unlock_sha, '2', fann_prefix .. elt, ann_data}
+      )
+    end
+  end
+
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
@@ -375,8 +422,8 @@ local function train_fann(cfg, ev_base, elt)
       )
     else
       -- Decompress and convert to numbers each training vector
-      ham_elts = map(function(i, elt)
-        local str = tostring(rspamd_util.zstd_decompress(elt))
+      ham_elts = map(function(i, tok)
+        local str = tostring(rspamd_util.zstd_decompress(tok))
         return map(tonumber, rspamd_str_split(str, ';'))
       end, data)
 
@@ -384,17 +431,24 @@ local function train_fann(cfg, ev_base, elt)
       local inputs = {}
       local outputs = {}
 
-      each(function(i, elt)
-        table.insert(inputs, totable(elt))
+      each(function(i, sample)
+        table.insert(inputs, totable(sample))
         table.insert(outputs, 1.0)
       end, spam_elts)
-      each(function(i, elt)
-        table.insert(inputs, totable(elt))
+      each(function(i, sample)
+        table.insert(inputs, totable(sample))
         table.insert(outputs, -1.0)
       end, spam_elts)
 
       -- Now we can train fann
+      local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
+      if not data[elt].fann then
+        -- Create fann if it does not exist
+        create_train_fann(n, elt)
+      end
 
+      data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base,
+        {max_epochs = max_epoch, desired_mse = mse})
     end
   end
 
@@ -412,8 +466,8 @@ local function train_fann(cfg, ev_base, elt)
       )
     else
       -- Decompress and convert to numbers each training vector
-      spam_elts = map(function(i, elt)
-        local str = tostring(rspamd_util.zstd_decompress(elt))
+      spam_elts = map(function(i, tok)
+        local str = tostring(rspamd_util.zstd_decompress(tok))
         return map(tonumber, rspamd_str_split(str, ';'))
       end, data)
       redis_make_request(ev_base,
@@ -700,6 +754,22 @@ else
       {'LOAD', redis_lua_script_maybe_lock} -- arguments
     )
 
+    local function save_unlock_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err)
+      else
+        redis_save_unlock_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      save_unlock_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_save_unlock} -- arguments
+    )
+
     if worker:get_name() == 'normal' then
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,