diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-30 15:34:44 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-30 15:34:44 +0100 |
commit | a920318bfc347709aa1453ecb1867f0e1e2163ca (patch) | |
tree | e368df697a3627eb0f19c55309afc64033095ae0 /src | |
parent | cf8e44673ee28480e5090fe7004da10fe1ab0ce3 (diff) | |
download | rspamd-a920318bfc347709aa1453ecb1867f0e1e2163ca.tar.gz rspamd-a920318bfc347709aa1453ecb1867f0e1e2163ca.zip |
[Fix] Allow to set any layers number for fann rules
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index b8c3942b4..3e86007a8 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -354,6 +354,17 @@ local function fann_scores_filter(task) end end +local function create_fann(n, nlayers) + local layers = {} + local div = 1.0 + for i in 1,nlayers - 1 do + table.insert(layers, math.floor(n / div)) + div = div * 2 + end + table.insert(layers, 1) + return rspamd_fann.create(nlayers, layers) +end + local function create_train_fann(rule, n, id) id = rule.prefix .. tostring(id) local prefix = gen_fann_prefix(rule, id) @@ -368,18 +379,19 @@ local function create_train_fann(rule, n, id) 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + + fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil elseif fanns[id].version % rule.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = create_fann(n, rule.nlayers) else fanns[id].fann_train = fanns[id].fann end else - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].version = 0 end end |