diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 23:51:38 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 23:51:38 +0100 |
commit | 913ac147bbc4e706095003fb8c16d24e2187a77f (patch) | |
tree | 20d87d95ef07dcb0719e7f9abe2701b130069e37 /src/plugins | |
parent | 68573e994099396d35181186f4e9d8e0cddbdd53 (diff) | |
download | rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.tar.gz rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.zip |
[Project] Neural: Fix PCA based learning
Diffstat (limited to 'src/plugins')
-rw-r--r-- | src/plugins/lua/neural.lua | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 5b4ff8b3b..0258fb0b0 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -640,17 +640,15 @@ end -- This is an utility function for PCA training local function fill_scatter(inputs) - local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs) - local row_len = #inputs[1] + local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1]) + local nsamples = #inputs - if type(inputs) == 'table' then - -- Convert to a tensor - inputs = rspamd_tensor.fromtable(inputs) - end + -- Convert to a tensor where each row is an input dimension + inputs = rspamd_tensor.fromtable(inputs):transpose() local meanv = inputs:mean() - for i=1,row_len do + for i=1,nsamples do local col = rspamd_tensor.new(1, #inputs) for j=1,#inputs do local x = inputs[j][i] - meanv[j] @@ -679,6 +677,8 @@ local function learn_pca(inputs, max_inputs) w[i] = scatter_matrix[#scatter_matrix - i + 1] end + lua_util.debugm(N, 'pca matrix: %s', w) + return w end @@ -856,8 +856,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve local profile_serialized = ucl.to_format(profile, 'json-compact', true) rspamd_logger.infox(rspamd_config, - 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)', - rule.prefix, set.name, #data, set.ann.redis_key, ann_key) + 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', + rule.prefix, set.name, + #data, #ann_data, + #(set.ann.pca or {}), #(pca_data or {}), + set.ann.redis_key, ann_key) lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, |