]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Forget old ANN when max_usages is reached to avoid overtrain
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 16 Nov 2016 11:16:42 +0000 (11:16 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 16 Nov 2016 11:16:42 +0000 (11:16 +0000)
src/plugins/lua/fann_redis.lua

index cff8981ded2fd79544a54973107d339b5b316319..324b7ba299a45be1ddee9361af65e7c17fcfa738 100644 (file)
@@ -79,7 +79,9 @@ local redis_lua_script_maybe_load = [[
   local ver = 0
   local ret = redis.call('GET', KEYS[1] .. '_version')
   if ret then ver = tonumber(ret) end
-  if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_data') end
+  if ver > tonumber(KEYS[2]) then
+    return {redis.call('GET', KEYS[1] .. '_data'), ret}
+  end
 
   return false
 ]]
@@ -137,7 +139,8 @@ redis_params = rspamd_parse_redis_server('fann_redis')
 
 local fann_prefix = 'RFANN'
 local max_trains = 1000
-local max_epoch = 100
+local max_epoch = 1000
+local max_usages = 10
 local use_settings = false
 local watch_interval = 60.0
 local mse = 0.0001
@@ -280,8 +283,16 @@ local function create_train_fann(n, id)
   end
 
   if fanns[id].fann then
-    fanns[id].fann_train = fanns[id].fann
-    fanns[id].fann = nil
+    if fanns[id].version % max_usages == 0 then
+      -- Forget last fann
+      rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', id,
+        fanns[id].version)
+      fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+      fanns[id].fann = nil
+    else
+      fanns[id].fann_train = fanns[id].fann
+      fanns[id].fann = nil
+    end
   else
     fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
     fanns[id].version = 0
@@ -289,24 +300,34 @@ local function create_train_fann(n, id)
 end
 
 local function load_or_invalidate_fann(data, id, ev_base)
-  local err,ann_data = rspamd_util.zstd_decompress(data)
+  local ver = data[2]
+  if not ver or not tonumber(ver) then
+    rspamd_logger.errx(rspamd_config, 'cannot get version for ann: %s', id)
+    return
+  end
+
+  local err,ann_data = rspamd_util.zstd_decompress(data[1])
   local ann
 
   if err or not ann_data then
-    rspamd_logger.errx(rspamd_config, 'cannot decompress ann: %s', err)
+    rspamd_logger.errx(rspamd_config, 'cannot decompress ann %s: %s', id, err)
+    return
   else
     ann = rspamd_fann.load_data(ann_data)
   end
 
   if is_fann_valid(ann) then
     fanns[id].fann = ann
-    rspamd_logger.infox(rspamd_config, 'loaded ann %s from redis', id)
+    rspamd_logger.infox(rspamd_config, 'loaded ann %s version %s from redis',
+      id, ver)
+    fanns[id].version = tonumber(ver)
   else
     local function redis_invalidate_cb(_err, _data)
       if _err then
         rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, _err)
       elseif type(_data) == 'string' then
         rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, _err)
+        fanns[id].version = 0
       end
     end
     -- Invalidate ANN
@@ -701,6 +722,12 @@ else
       if opts['train']['max_epoch'] then
         max_epoch = opts['train']['max_epoch']
       end
+      if opts['train']['max_usages'] then
+        max_usages = opts['train']['max_usages']
+      end
+      if opts['train']['mse'] then
+        mse = opts['train']['mse']
+      end
       local ret = cfg:register_worker_script("log_helper",
         function(score, req_score, results, cf, _id, extra, ev_base)
           -- fun.map (snd x) (fun.filter (fst x == module_id) extra)