]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Implement PCA in learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 21:00:38 +0000 (22:00 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 21:00:38 +0000 (22:00 +0100)
src/plugins/lua/neural.lua

index 91caa8e07a28fe698fe2694d177d81bdbecd3e33..5b4ff8b3be04bf8b39eddafdb4aba9051a0d33b3 100644 (file)
@@ -639,10 +639,17 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
 end
 
 -- This is an utility function for PCA training
-local function fill_scatter(inputs, meanv)
+local function fill_scatter(inputs)
   local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
   local row_len = #inputs[1]
 
+  if type(inputs) == 'table' then
+    -- Convert to a tensor
+    inputs = rspamd_tensor.fromtable(inputs)
+  end
+
+  local meanv = inputs:mean()
+
   for i=1,row_len do
     local col = rspamd_tensor.new(1, #inputs)
     for j=1,#inputs do
@@ -663,8 +670,7 @@ end
 -- This function takes all inputs, applies PCA transformation and returns the final
 -- PCA matrix as rspamd_tensor
 local function learn_pca(inputs, max_inputs)
-  local meanv = inputs:mean()
-  local scatter_matrix = fill_scatter(inputs, meanv)
+  local scatter_matrix = fill_scatter(inputs)
   local eigenvals = scatter_matrix:eigen()
   -- scatter matrix is not filled with eigenvectors
   lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
@@ -676,6 +682,19 @@ local function learn_pca(inputs, max_inputs)
   return w
 end
 
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+  if not set.ann then
+    set.ann = {
+      symbols = set.symbols,
+      distance = 0,
+      digest = set.digest,
+      redis_key = ann_key,
+      version = 0,
+    }
+  end
+end
+
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
 local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
   -- Check training data sanity
@@ -684,7 +703,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       meta_functions.rspamd_count_metatokens()
 
   -- Now we can train ann
-  local train_ann = create_ann(n, 3, rule)
+  local train_ann = create_ann(rule.max_inputs or n, 3, rule)
 
   if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
     -- Invalidate ANN as it is definitely invalid
@@ -749,12 +768,23 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         end
       end
 
-      train_ann:train1(inputs, outputs, {
-        lr = rule.train.learning_rate,
-        max_epoch = rule.train.max_iterations,
-        cb = train_cb,
-        pca = set.ann.pca
-      })
+      lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+          rule.prefix, set.name)
+
+      local ret,err = pcall(train_ann.train1, train_ann,
+          inputs, outputs, {
+            lr = rule.train.learning_rate,
+            max_epoch = rule.train.max_iterations,
+            cb = train_cb,
+            pca = set.ann.pca
+          })
+
+      if not ret then
+        rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+            rule.prefix, set.name, err)
+
+        return nil
+      end
 
       if not seen_nan then
         local out = train_ann:save()
@@ -806,14 +836,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         if set.ann.pca then
           pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
         end
-        if not set.ann then
-          set.ann = {
-            symbols = set.symbols,
-            distance = 0,
-            digest = set.digest,
-            redis_key = ann_key,
-          }
-        end
+        fill_set_ann(set, ann_key)
         -- Deserialise ANN from the child process
         ann_trained = rspamd_kann.load(data)
         local version = (set.ann.version or 0) + 1
@@ -852,6 +875,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     end
 
     if rule.max_inputs then
+      fill_set_ann(set, ann_key)
       -- Train PCA in the main process, presumably it is not that long
       set.ann.pca = learn_pca(inputs, rule.max_inputs)
     end
@@ -1045,7 +1069,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
               rule.prefix, set.name, ann_key)
         end
-        if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+        if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
           -- PCA table
           local _err,pca_data = rspamd_util.zstd_decompress(data[2])
           if pca_data then