diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-09-16 21:11:44 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-09-16 21:11:44 +0100 |
commit | 3ac673af36fa8cb19ec535cafc29f21013ebe461 (patch) | |
tree | 9eb06f13a24dfb43077874da50736adf7461d28e /src | |
parent | edc592085458130b6e4aebeccc7ce45ea920d97a (diff) | |
download | rspamd-3ac673af36fa8cb19ec535cafc29f21013ebe461.tar.gz rspamd-3ac673af36fa8cb19ec535cafc29f21013ebe461.zip |
[Feature] Allow disabling torch and skipping training samples for ANN
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 14 |
1 file changed, 12 insertions, 2 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index e2a7eb4f5..f07a84033 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -46,6 +46,7 @@ local default_options = { max_iterations = 25, -- Torch style mse = 0.001, autotrain = true, + train_prob = 1.0, }, use_settings = false, per_user = false, @@ -431,7 +432,7 @@ local function create_train_fann(rule, n, id) fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif fanns[id].version % rule.train.max_usages == 0 then + elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) @@ -540,7 +541,7 @@ local function fann_train_callback(rule, task, score, required_score, id) local function learn_vec_cb(err) if err then - rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err) + rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err) else rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", rule['name'], k, vec_len) @@ -549,6 +550,11 @@ local function fann_train_callback(rule, task, score, required_score, id) local function can_train_cb(err, data) if not err and tonumber(data) > 0 then + local coin = math.random() + if coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return + end local fann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens @@ -1069,6 +1075,10 @@ else rules['RFANN'] = opts end + if opts.disable_torch then + use_torch = false + end + local id = rspamd_config:register_symbol({ name = 'FANN_CHECK', type = 'postfilter,nostat', |