]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Further PCA fixes
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Aug 2020 16:40:02 +0000 (17:40 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Aug 2020 16:40:02 +0000 (17:40 +0100)
src/plugins/lua/neural.lua

index 225c9895bbe1cdeb5bb7586c7bf3b1f0b3d28af7..d7410b2250c404415e4d2f99342e08be86103149 100644 (file)
@@ -114,7 +114,10 @@ end
 local redis_lua_script_vectors_len = [[
   local prefix = KEYS[1]
   local locked = redis.call('HGET', prefix, 'lock')
-  if locked then return false end
+  if locked then
+    local host = redis.call('HGET', prefix, 'hostname')
+    return string.format('%s:%s', hostname, locked)
+  end
   local nspam = 0
   local nham = 0
 
@@ -547,10 +550,10 @@ local function ann_push_task_result(rule, task, verdict, score, set)
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
               rule.prefix, set.name, err)
-        elseif type(data) == 'userdata' then
+        elseif type(data) == 'string' then
           -- nil return value
-          rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning",
-              learn_type, rule.prefix, set.name, set.ann.redis_key)
+          rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
+              learn_type, rule.prefix, set.name, set.ann.redis_key, data)
         else
           rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
               'please remove this key from Redis manually if you perform upgrade from the previous version',
@@ -647,6 +650,7 @@ local function fill_scatter(inputs)
   inputs = rspamd_tensor.fromtable(inputs):transpose()
 
   local meanv = inputs:mean()
+  lua_util.debugm(N, 'means: %s', meanv)
 
   for i=1,nsamples do
     local col = rspamd_tensor.new(1, #inputs)
@@ -662,6 +666,8 @@ local function fill_scatter(inputs)
     end
   end
 
+  lua_util.debugm(N, 'scatter matrix: %s', scatter_matrix)
+
   return scatter_matrix
 end
 
@@ -1004,7 +1010,7 @@ local function do_train_ann(worker, ev_base, rule, set, ann_key)
       {
         ann_key,
         tostring(os.time()),
-        tostring(rule.watch_interval * 2),
+        tostring(math.max(10.0, rule.watch_interval * 2)),
         rspamd_util.get_hostname()
     })
 end
@@ -1062,7 +1068,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
                   {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
               )
               rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #ann_data, profile.version)
+                  rule.prefix, set.name, ann_key, #data[1], profile.version)
             else
               rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
                   rule.prefix, set.name, ann_key)
@@ -1079,6 +1085,8 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             if rule.max_inputs then
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
+              rspamd_logger.infox(rspamd_config, 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+                  rule.prefix, set.name, ann_key, #data[2], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',