]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Implement PCA learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 14:46:51 +0000 (15:46 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 14:46:51 +0000 (15:46 +0100)
src/plugins/lua/neural.lua

index 32a75198786788b6ca567f0f3972e2aed6b57922..91caa8e07a28fe698fe2694d177d81bdbecd3e33 100644 (file)
@@ -638,6 +638,44 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
   )
 end
 
+-- This is an utility function for PCA training
+local function fill_scatter(inputs, meanv)
+  local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
+  local row_len = #inputs[1]
+
+  for i=1,row_len do
+    local col = rspamd_tensor.new(1, #inputs)
+    for j=1,#inputs do
+      local x = inputs[j][i] - meanv[j]
+      col[j] = x
+    end
+    local prod = col:mul(col, false, true)
+    for ii=1,#prod do
+      for jj=1,#prod[1] do
+        scatter_matrix[ii][jj] = scatter_matrix[ii][jj] + prod[ii][jj]
+      end
+    end
+  end
+
+  return scatter_matrix
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+  local meanv = inputs:mean()
+  local scatter_matrix = fill_scatter(inputs, meanv)
+  local eigenvals = scatter_matrix:eigen()
+  -- scatter matrix is not filled with eigenvectors
+  lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+  local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+  for i=1,max_inputs do
+    w[i] = scatter_matrix[#scatter_matrix - i + 1]
+  end
+
+  return w
+end
+
 -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
 local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
   -- Check training data sanity
@@ -715,6 +753,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         lr = rule.train.learning_rate,
         max_epoch = rule.train.max_iterations,
         cb = train_cb,
+        pca = set.ann.pca
       })
 
       if not seen_nan then
@@ -812,6 +851,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       end
     end
 
+    if rule.max_inputs then
+      -- Train PCA in the main process, presumably it is not that long
+      set.ann.pca = learn_pca(inputs, rule.max_inputs)
+    end
+
     worker:spawn_process{
       func = train,
       on_complete = ann_trained,