 src/plugins/lua/neural.lua | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 99efe720e..9df2f1c55 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -31,6 +31,9 @@ local ts = require("tableshape").types
 local lua_verdict = require "lua_verdict"
 local N = "neural"
 
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
 -- Module vars
 local default_options = {
   train = {
@@ -52,6 +55,7 @@ local default_options = {
   lock_expire = 600,
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
+  hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
   symbol_spam = 'NEURAL_SPAM',
   symbol_ham = 'NEURAL_HAM',
 }
@@ -251,8 +255,8 @@ end
 local function redis_ann_prefix(rule, settings_name)
   -- We also need to count metatokens:
   local n = meta_functions.version
-  return string.format('%s_%s_%d_%s',
-      settings.prefix, rule.prefix, n, settings_name)
+  return string.format('%s%d_%s_%d_%s',
+      settings.prefix, plugin_ver, rule.prefix, n, settings_name)
 end
 
 -- Creates and stores ANN profile in Redis
@@ -337,9 +341,9 @@ local function ann_scores_filter(task)
   end
 end
 
-local function create_ann(n, nlayers)
+local function create_ann(n, nlayers, rule)
   -- We ignore number of layers so far when using kann
-  local nhidden = math.floor((n + 1) / 2)
+  local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
   local t = rspamd_kann.layer.input(n)
   t = rspamd_kann.transform.relu(t)
   t = rspamd_kann.layer.dense(t, nhidden);
@@ -364,14 +368,18 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
         -- Apply sampling
         local skip_rate = 1.0 - nham / (nspam + 1)
         if coin < skip_rate - train_opts.classes_bias then
-          rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
-            skip_rate - train_opts.classes_bias)
+          rspamd_logger.infox(task,
+            'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+            learn_type,
+            skip_rate - train_opts.classes_bias,
+            nspam, nham)
           return false
         end
       end
       return true
     else -- Enough learns
-      rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type,
+      rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+        learn_type,
         nspam)
     end
   else
@@ -380,8 +388,11 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham)
         -- Apply sampling
         local skip_rate = 1.0 - nspam / (nham + 1)
         if coin < skip_rate - train_opts.classes_bias then
-          rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
-            skip_rate - train_opts.classes_bias)
+          rspamd_logger.infox(task,
+            'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+            learn_type,
+            skip_rate - train_opts.classes_bias,
+            nspam, nham)
           return false
         end
       end
@@ -625,7 +636,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
       meta_functions.rspamd_count_metatokens()
 
   -- Now we can train ann
-  local train_ann = create_ann(n, 3)
+  local train_ann = create_ann(n, 3, rule)
 
   if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
     -- Invalidate ANN as it is definitely invalid
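
A minimal sketch of what the patch changes in practice (illustration only, not part of the commit; the values used for settings.prefix, rule.prefix, meta_functions.version, the settings name and the input count are made up for the example): the Redis key prefix now embeds plugin_ver, so ANN profiles written under the old key layout are no longer looked up, and the hidden layer is sized from the new hidden_layer_mult option instead of roughly half the input count.

-- Hypothetical values, only to show the effect of the patch
local plugin_ver = '2'
local settings = { prefix = 'rn' }          -- assumed value of settings.prefix
local rule = { prefix = 'default', hidden_layer_mult = 1.5 }
local meta_version = 3                      -- assumed meta_functions.version
local settings_name = 'default_settings'    -- assumed settings set name

-- Old prefix would be rn_default_3_default_settings;
-- the new prefix carries the plugin version right after settings.prefix:
local new_prefix = string.format('%s%d_%s_%d_%s',
    settings.prefix, plugin_ver, rule.prefix, meta_version, settings_name)
print(new_prefix) --> rn2_default_3_default_settings

-- Hidden layer size for e.g. 150 inputs:
local n = 150
local old_nhidden = math.floor((n + 1) / 2)                               -- 75
local new_nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) -- 226
print(old_nhidden, new_nhidden)

Since the version is part of the key, profiles trained by an older plugin are simply never loaded under the new prefix (and presumably age out via ann_expire) rather than being fed to a network with an incompatible layout.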