aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-31 15:44:23 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-31 15:44:23 +0100
commitedaf251d9e2e0a3e8ffc200c904dd6a90284442c (patch)
treed7ca749df74c14178a0ed2808b5446a4dbfb68e8 /src
parentd13c3065b8e2ea5c3e4beec04f5d4ed5c5b84515 (diff)
downloadrspamd-edaf251d9e2e0a3e8ffc200c904dd6a90284442c.tar.gz
rspamd-edaf251d9e2e0a3e8ffc200c904dd6a90284442c.zip
[Project] Neural: Use C version of scatter matrix producing
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/neural.lua32
1 files changed, 1 insertions, 31 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 10e49901f..5d2e9bfd5 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -641,40 +641,10 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
)
end
--- This is an utility function for PCA training
-local function fill_scatter(inputs)
- local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
- local nsamples = #inputs
-
- -- Convert to a tensor where each row is an input dimension
- inputs = rspamd_tensor.fromtable(inputs):transpose()
-
- local meanv = inputs:mean()
- lua_util.debugm(N, 'means: %s', meanv)
-
- for i=1,nsamples do
- local col = rspamd_tensor.new(1, #inputs)
- for j=1,#inputs do
- local x = inputs[j][i] - meanv[j]
- col[j] = x
- end
- local prod = col:mul(col, false, true)
- for ii=1,#prod do
- for jj=1,#prod[1] do
- scatter_matrix[ii][jj] = scatter_matrix[ii][jj] + prod[ii][jj]
- end
- end
- end
-
- lua_util.debugm(N, 'scatter matrix: %s', scatter_matrix)
-
- return scatter_matrix
-end
-
-- This function takes all inputs, applies PCA transformation and returns the final
-- PCA matrix as rspamd_tensor
local function learn_pca(inputs, max_inputs)
- local scatter_matrix = fill_scatter(inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
local eigenvals = scatter_matrix:eigen()
-- scatter matrix is not filled with eigenvectors
lua_util.debugm(N, 'eigenvalues: %s', eigenvals)