]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Implement loading/invalidating
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 15:27:40 +0000 (15:27 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 15:27:40 +0000 (15:27 +0000)
src/plugins/lua/fann_scores.lua

index 0a238db291a405871b7ceb04e65b478c79b10846..3c46cda2f1a1951337f65283e8db6d02fb8cacd6 100644 (file)
@@ -55,9 +55,9 @@ local redis_lua_script_can_train = [[
   if ret then nham = tonumber(ret) end
 
   if KEYS[3] == 'spam' then
-    if nham + 1 >= nspam then return tostring(nspam) end
+    if nham + 1 >= nspam then return tostring(nspam + 1) end
   else
-    if nspam + 1 >= nham then return tostring(nham) end
+    if nspam + 1 >= nham then return tostring(nham + 1) end
   end
 
   return tostring(0)
@@ -80,12 +80,28 @@ local redis_lua_script_maybe_load = [[
 
   return false
 ]]
-local redis_fann_maybe_load_sha = nil
+local redis_maybe_load_sha = nil
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_maybe_invalidate = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+  redis.call('SET', KEYS[1] .. '_locked', '1')
+  redis.call('SET', KEYS[1] .. '_version', '0')
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('DEL', KEYS[1] .. '_data')
+  redis.call('DEL', KEYS[1] .. '_locked')
+  return 1
+]]
+local redis_maybe_invalidate_sha = nil
 
 local redis_params
 redis_params = rspamd_parse_redis_server('fann_scores')
 
-local fann_prefix = 'RF'
+local fann_prefix = 'RFANN'
 local max_trains = 1000
 local max_epoch = 100
 local use_settings = false
@@ -385,29 +401,25 @@ local function gen_fann_prefix(id)
   end
 end
 
-local function is_fann_valid(id)
-  if data[id].fann then
+local function is_fann_valid(ann)
+  if ann then
     local n = rspamd_config:get_symbols_count() + count_metatokens()
 
-    if n ~= data[id].fann:get_inputs() then
+    if n ~= ann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', data[id].fann:get_inputs(), n)
-      data[id].fann = nil
+      ' is found in the cache', ann:get_inputs(), n)
+      return false
     end
-    local layers = data[id].fann:get_layers()
+    local layers = ann:get_layers()
 
     if not layers or #layers ~= 5 then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
         #layers)
-      data[id].fann = nil
+      return false
     end
-  end
 
-  if data[id].fann then
     return true
   end
-
-  return false
 end
 
 local function fann_scores_filter(task)
@@ -448,6 +460,39 @@ local function create_train_fann(n, id)
   data[id].epoch = 0
 end
 
+local function load_or_invalidate_fann(data, id, ev_base)
+  local err,ann_data = rspamd_util.zstd_decompress(data)
+  local ann
+
+  if err or not ann_data then
+    rspamd_logger.errx('cannot decompress ann: %s', err)
+  else
+    ann = rspamd_fann.load_data(ann_data)
+  end
+
+  if is_fann_valid(ann) then
+    data[id].fann = ann
+  else
+    local function redis_invalidate_cb(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err)
+      elseif type(data) == 'string' then
+        rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err)
+      end
+    end
+    -- Invalidate ANN
+    rspamd_logger.infox('invalidate ANN %s')
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      redis_invalidate_cb, --callback
+      'EVALSHA', -- command
+      {redis_maybe_invalidate_sha, 1, fann_prefix .. id}
+    )
+  end
+end
+
 local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
   local fname = gen_fann_prefix(id)
 
@@ -468,7 +513,14 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
     local k
     if learn_spam then k = 'spam' else k = 'ham' end
 
+    local function learn_vec_cb(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+      end
+    end
+
     local function can_train_cb(err, data)
+      rspamd_logger.errx('data: %s, err: %s', data, err)
       if not err and tonumber(data) > 0 then
         local learn_data = symbols_to_fann_vector(
           map(function(r) return r[1] end, results),
@@ -476,13 +528,13 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
         )
         -- Add filtered meta tokens
         each(function(e) table.insert(learn_data, e) end, extra)
-        local str = table.concat(learn_data, ';')
+        local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
 
         redis_make_request(ev_base,
           rspamd_config,
           nil,
           true, -- is write
-          learn_cb, --callback
+          learn_vec_cb, --callback
           'LPUSH', -- command
           {fname .. '_' .. k, str} -- arguments
         )
@@ -510,25 +562,11 @@ local function check_fanns(cfg, ev_base)
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
       each(function(i, elt)
-        local redis_load_cb = function(err, data)
-          if err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
-          elseif type(data) == 'string' then
-            --load_fann(data, elt)
-          end
-        end
         local redis_update_cb = function(err, data)
           if err then
             rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
-          elseif data then
-            redis_make_request(ev_base,
-              rspamd_config,
-              nil,
-              false, -- is write
-              redis_load_cb, --callback
-              'GET', -- command
-              {fann_prefix, fann_prefix .. elt .. '_data'} -- arguments
-            )
+          elseif data and type(data) == 'string' then
+            load_or_invalidate_fann(data, elt, ev_base)
           end
         end
 
@@ -545,14 +583,14 @@ local function check_fanns(cfg, ev_base)
           false, -- is write
           redis_update_cb, --callback
           'EVALSHA', -- command
-          {redis_fann_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
+          {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
         )
       end,
       data)
     end
   end
 
-  if not redis_fann_maybe_load_sha then
+  if not redis_maybe_load_sha then
     -- Plan new event early
     return 1.0
   end
@@ -663,7 +701,7 @@ else
       if err or not data or type(data) ~= 'string' then
         rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
       else
-        redis_fann_maybe_load_sha = tostring(data)
+        redis_maybe_load_sha = tostring(data)
 
         rspamd_config:add_periodic(ev_base, 0.0,
           function(cfg, ev_base)
@@ -679,5 +717,21 @@ else
       'SCRIPT', -- command
       {'LOAD', redis_lua_script_maybe_load} -- arguments
     )
+
+    local function maybe_invalidate_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
+      else
+        redis_maybe_invalidate_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      maybe_invalidate_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
+    )
   end)
 end