aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-27 23:51:38 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-27 23:51:38 +0100
commit913ac147bbc4e706095003fb8c16d24e2187a77f (patch)
tree20d87d95ef07dcb0719e7f9abe2701b130069e37 /src/plugins
parent68573e994099396d35181186f4e9d8e0cddbdd53 (diff)
downloadrspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.tar.gz
rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.zip
[Project] Neural: Fix PCA based learning
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/neural.lua21
1 files changed, 12 insertions, 9 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5b4ff8b3b..0258fb0b0 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -640,17 +640,15 @@ end
-- This is an utility function for PCA training
local function fill_scatter(inputs)
- local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
- local row_len = #inputs[1]
+ local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
+ local nsamples = #inputs
- if type(inputs) == 'table' then
- -- Convert to a tensor
- inputs = rspamd_tensor.fromtable(inputs)
- end
+ -- Convert to a tensor where each row is an input dimension
+ inputs = rspamd_tensor.fromtable(inputs):transpose()
local meanv = inputs:mean()
- for i=1,row_len do
+ for i=1,nsamples do
local col = rspamd_tensor.new(1, #inputs)
for j=1,#inputs do
local x = inputs[j][i] - meanv[j]
@@ -679,6 +677,8 @@ local function learn_pca(inputs, max_inputs)
w[i] = scatter_matrix[#scatter_matrix - i + 1]
end
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
return w
end
@@ -856,8 +856,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
- 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
- rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ rule.prefix, set.name,
+ #data, #ann_data,
+ #(set.ann.pca or {}), #(pca_data or {}),
+ set.ann.redis_key, ann_key)
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},