diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index b8c3942b4..3e86007a8 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -354,6 +354,17 @@ local function fann_scores_filter(task) end end +local function create_fann(n, nlayers) + local layers = {} + local div = 1.0 + for i in 1,nlayers - 1 do + table.insert(layers, math.floor(n / div)) + div = div * 2 + end + table.insert(layers, 1) + return rspamd_fann.create(nlayers, layers) +end + local function create_train_fann(rule, n, id) id = rule.prefix .. tostring(id) local prefix = gen_fann_prefix(rule, id) @@ -368,18 +379,19 @@ local function create_train_fann(rule, n, id) 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + + fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil elseif fanns[id].version % rule.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = create_fann(n, rule.nlayers) else fanns[id].fann_train = fanns[id].fann end else - fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1) + fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].version = 0 end end |