summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-16 21:11:44 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-16 21:11:44 +0100
commit3ac673af36fa8cb19ec535cafc29f21013ebe461 (patch)
tree9eb06f13a24dfb43077874da50736adf7461d28e /src
parentedc592085458130b6e4aebeccc7ce45ea920d97a (diff)
downloadrspamd-3ac673af36fa8cb19ec535cafc29f21013ebe461.tar.gz
rspamd-3ac673af36fa8cb19ec535cafc29f21013ebe461.zip
[Feature] Allow to disable torch and skip train samples for ANN
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/fann_redis.lua14
1 files changed, 12 insertions, 2 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index e2a7eb4f5..f07a84033 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -46,6 +46,7 @@ local default_options = {
max_iterations = 25, -- Torch style
mse = 0.001,
autotrain = true,
+ train_prob = 1.0,
},
use_settings = false,
per_user = false,
@@ -431,7 +432,7 @@ local function create_train_fann(rule, n, id)
fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].fann = nil
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
- elseif fanns[id].version % rule.train.max_usages == 0 then
+ elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then
-- Forget last fann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
fanns[id].version)
@@ -540,7 +541,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
local function learn_vec_cb(err)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
+ rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
else
rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
rule['name'], k, vec_len)
@@ -549,6 +550,11 @@ local function fann_train_callback(rule, task, score, required_score, id)
local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
+ local coin = math.random()
+ if coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return
+ end
local fann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
@@ -1069,6 +1075,10 @@ else
rules['RFANN'] = opts
end
+ if opts.disable_torch then
+ use_torch = false
+ end
+
local id = rspamd_config:register_symbol({
name = 'FANN_CHECK',
type = 'postfilter,nostat',