diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-31 15:44:23 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-31 15:44:23 +0100 |
commit | edaf251d9e2e0a3e8ffc200c904dd6a90284442c (patch) | |
tree | d7ca749df74c14178a0ed2808b5446a4dbfb68e8 /src | |
parent | d13c3065b8e2ea5c3e4beec04f5d4ed5c5b84515 (diff) | |
download | rspamd-edaf251d9e2e0a3e8ffc200c904dd6a90284442c.tar.gz rspamd-edaf251d9e2e0a3e8ffc200c904dd6a90284442c.zip |
[Project] Neural: Use C version of scatter matrix producing
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/neural.lua | 32 |
1 file changed, 1 insertion, 31 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 10e49901f..5d2e9bfd5 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -641,40 +641,10 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
   )
 end
 
--- This is an utility function for PCA training
-local function fill_scatter(inputs)
-  local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
-  local nsamples = #inputs
-
-  -- Convert to a tensor where each row is an input dimension
-  inputs = rspamd_tensor.fromtable(inputs):transpose()
-
-  local meanv = inputs:mean()
-  lua_util.debugm(N, 'means: %s', meanv)
-
-  for i=1,nsamples do
-    local col = rspamd_tensor.new(1, #inputs)
-    for j=1,#inputs do
-      local x = inputs[j][i] - meanv[j]
-      col[j] = x
-    end
-    local prod = col:mul(col, false, true)
-    for ii=1,#prod do
-      for jj=1,#prod[1] do
-        scatter_matrix[ii][jj] = scatter_matrix[ii][jj] + prod[ii][jj]
-      end
-    end
-  end
-
-  lua_util.debugm(N, 'scatter matrix: %s', scatter_matrix)
-
-  return scatter_matrix
-end
-
 -- This function takes all inputs, applies PCA transformation and returns the final
 -- PCA matrix as rspamd_tensor
 local function learn_pca(inputs, max_inputs)
-  local scatter_matrix = fill_scatter(inputs)
+  local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
   local eigenvals = scatter_matrix:eigen()
   -- scatter matrix is not filled with eigenvectors
   lua_util.debugm(N, 'eigenvalues: %s', eigenvals)