From 362dc834f1be24b107a0f3f593e743ce2ae66a04 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 27 Aug 2020 15:46:51 +0100 Subject: [PATCH] [Project] Neural: Implement PCA learning --- src/plugins/lua/neural.lua | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 32a751987..91caa8e07 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -638,6 +638,44 @@ local function register_lock_extender(rule, set, ev_base, ann_key) ) end +-- This is an utility function for PCA training +local function fill_scatter(inputs, meanv) + local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs) + local row_len = #inputs[1] + + for i=1,row_len do + local col = rspamd_tensor.new(1, #inputs) + for j=1,#inputs do + local x = inputs[j][i] - meanv[j] + col[j] = x + end + local prod = col:mul(col, false, true) + for ii=1,#prod do + for jj=1,#prod[1] do + scatter_matrix[ii][jj] = scatter_matrix[ii][jj] + prod[ii][jj] + end + end + end + + return scatter_matrix +end + +-- This function takes all inputs, applies PCA transformation and returns the final +-- PCA matrix as rspamd_tensor +local function learn_pca(inputs, max_inputs) + local meanv = inputs:mean() + local scatter_matrix = fill_scatter(inputs, meanv) + local eigenvals = scatter_matrix:eigen() + -- scatter matrix is not filled with eigenvectors + lua_util.debugm(N, 'eigenvalues: %s', eigenvals) + local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) + for i=1,max_inputs do + w[i] = scatter_matrix[#scatter_matrix - i + 1] + end + + return w +end + -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) -- Check training data sanity @@ -715,6 +753,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve lr = rule.train.learning_rate, max_epoch = rule.train.max_iterations, cb = train_cb, + pca = set.ann.pca }) if not seen_nan then @@ -812,6 +851,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve end end + if rule.max_inputs then + -- Train PCA in the main process, presumably it is not that long + set.ann.pca = learn_pca(inputs, rule.max_inputs) + end + worker:spawn_process{ func = train, on_complete = ann_trained, -- 2.39.5