aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-07-30 15:34:44 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-07-30 15:34:44 +0100
commita920318bfc347709aa1453ecb1867f0e1e2163ca (patch)
treee368df697a3627eb0f19c55309afc64033095ae0 /src/plugins
parentcf8e44673ee28480e5090fe7004da10fe1ab0ce3 (diff)
downloadrspamd-a920318bfc347709aa1453ecb1867f0e1e2163ca.tar.gz
rspamd-a920318bfc347709aa1453ecb1867f0e1e2163ca.zip
[Fix] Allow to set any layers number for fann rules
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/fann_redis.lua18
1 files changed, 15 insertions, 3 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index b8c3942b4..3e86007a8 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -354,6 +354,17 @@ local function fann_scores_filter(task)
end
end
+local function create_fann(n, nlayers)
+ local layers = {}
+ local div = 1.0
+ for i in 1,nlayers - 1 do
+ table.insert(layers, math.floor(n / div))
+ div = div * 2
+ end
+ table.insert(layers, 1)
+ return rspamd_fann.create(nlayers, layers)
+end
+
local function create_train_fann(rule, n, id)
id = rule.prefix .. tostring(id)
local prefix = gen_fann_prefix(rule, id)
@@ -368,18 +379,19 @@ local function create_train_fann(rule, n, id)
'recreate ANN %s as it has a wrong number of inputs, version %s',
prefix,
fanns[id].version)
- fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+
+ fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].fann = nil
elseif fanns[id].version % rule.max_usages == 0 then
-- Forget last fann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
fanns[id].version)
- fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann_train = create_fann(n, rule.nlayers)
else
fanns[id].fann_train = fanns[id].fann
end
else
- fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+ fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].version = 0
end
end