]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix ANN checks
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 14:56:28 +0000 (15:56 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 14:56:28 +0000 (15:56 +0100)
src/plugins/lua/fann_redis.lua

index 2473fb290a92c8137fb1611a47c7d8fdd414584f..f7ec65d30ce0a6e297a9c389528f424d2ac2433f 100644 (file)
@@ -43,6 +43,7 @@ local default_options = {
     max_trains = 1000,
     max_epoch = 1000,
     max_usages = 10,
+    max_iterations = 25, -- Torch style
     mse = 0.001,
     autotrain = true,
   },
@@ -331,19 +332,7 @@ local function is_fann_valid(rule, prefix, ann)
         meta_functions.rspamd_count_metatokens()
 
     if torch then
-      local nlayers = #ann
-      if nlayers ~= rule.nlayers then
-        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
-          prefix, nlayers)
-        return false
-      end
-
-      local inp = ann:get(1):nElement()
-      if n ~= inp then
-        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
-            ' is found in the cache', prefix, inp, n)
-        return false
-      end
+      return true
     else
       if n ~= ann:get_inputs() then
         rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
@@ -364,12 +353,13 @@ local function is_fann_valid(rule, prefix, ann)
 end
 
 local function fann_scores_filter(task)
-  for _,rule in ipairs(settings.rules) do
-    local id = rule.prefix .. '0'
+
+  for _,rule in pairs(settings.rules) do
+    local id = '0'
     if rule.use_settings then
      local sid = task:get_settings_id()
      if sid then
-      id = rule.prefix .. tostring(sid)
+      id = tostring(sid)
      end
     end
     if rule.per_user then
@@ -481,6 +471,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
   end
 
   if is_fann_valid(rule, prefix, ann) then
+    if not fanns[id] then fanns[id] = {} end
     fanns[id].fann = ann
     rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
       prefix, ver)
@@ -627,6 +618,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
       if string.match(err, 'NOSCRIPT') then
         load_scripts(rspamd_config, ev_base, nil)
       end
+    else
+      rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
     end
   end
 
@@ -666,7 +659,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
         true, -- is write
         redis_save_cb, --callback
         'EVALSHA', -- command
-        {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+        {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
       )
     end
   end
@@ -686,8 +679,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
         {prefix .. '_locked'}
       )
     else
-      rspamd_logger.infox(rspamd_config, 'trained ANN %s',
-        prefix)
+      rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+        prefix, #data)
       local ann_data
       local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
       ann_data = rspamd_util.zstd_compress(f:storage():string())
@@ -703,7 +696,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
         true, -- is write
         redis_save_cb, --callback
         'EVALSHA', -- command
-        {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+        {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
       )
     end
   end
@@ -780,6 +773,8 @@ local function train_fann(rule, _, ev_base, elt, worker)
             local trainer = nn.StochasticGradient(fanns[elt].fann_train,
               criterion)
             trainer.learning_rate = 0.01
+            trainer.verbose = false
+            trainer.maxIteration = rule.train.max_iterations
             trainer.hookIteration = function(self, iteration, currentError)
               rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
                   iteration, currentError)
@@ -980,18 +975,23 @@ end
 local function check_fanns(rule, _, ev_base)
   local function members_cb(err, data)
     if err then
-      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
+        err)
     elseif type(data) == 'table' then
       fun.each(function(elt)
         elt = tostring(elt)
         local redis_update_cb = function(_err, _data)
           if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err)
+            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
+              elt, _err)
             if string.match(_err, 'NOSCRIPT') then
               load_scripts(rspamd_config, ev_base, nil)
             end
           elseif _data and type(_data) == 'table' then
             load_or_invalidate_fann(rule, _data, elt, ev_base)
+          else
+            rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
+              type(_data), elt)
           end
         end