aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/lua/neural.lua31
1 files changed, 21 insertions, 10 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 99efe720e..9df2f1c55 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -31,6 +31,9 @@ local ts = require("tableshape").types
local lua_verdict = require "lua_verdict"
local N = "neural"
+-- Included in the ANN prefix (see redis_ann_prefix) to avoid loading a wrong ANN
+local plugin_ver = '2'
+
-- Module vars
local default_options = {
train = {
@@ -52,6 +55,7 @@ local default_options = {
lock_expire = 600,
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- multiplier applied to the number of inputs to size the hidden layer
symbol_spam = 'NEURAL_SPAM',
symbol_ham = 'NEURAL_HAM',
}
@@ -251,8 +255,8 @@ end
local function redis_ann_prefix(rule, settings_name)
-- We also need to count metatokens:
local n = meta_functions.version
- return string.format('%s_%s_%d_%s',
- settings.prefix, rule.prefix, n, settings_name)
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
end
-- Creates and stores ANN profile in Redis
@@ -337,9 +341,9 @@ local function ann_scores_filter(task)
end
end
-local function create_ann(n, nlayers)
+local function create_ann(n, nlayers, rule)
-- We ignore number of layers so far when using kann
- local nhidden = math.floor((n + 1) / 2)
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
local t = rspamd_kann.layer.input(n)
t = rspamd_kann.transform.relu(t)
t = rspamd_kann.layer.dense(t, nhidden);
@@ -364,14 +368,18 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
-- Apply sampling
local skip_rate = 1.0 - nham / (nspam + 1)
if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
- skip_rate - train_opts.classes_bias)
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
return false
end
end
return true
else -- Enough learns
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type,
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
nspam)
end
else
@@ -380,8 +388,11 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
-- Apply sampling
local skip_rate = 1.0 - nspam / (nham + 1)
if coin < skip_rate - train_opts.classes_bias then
- rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
- skip_rate - train_opts.classes_bias)
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
return false
end
end
@@ -625,7 +636,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
meta_functions.rspamd_count_metatokens()
-- Now we can train ann
- local train_ann = create_ann(n, 3)
+ local train_ann = create_ann(n, 3, rule)
if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid