]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Allow to set any layers number for fann rules
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 30 Jul 2017 14:34:44 +0000 (15:34 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 30 Jul 2017 14:34:44 +0000 (15:34 +0100)
src/plugins/lua/fann_redis.lua

index b8c3942b480380c3b026e0475ab511e39673783f..3e86007a80915bd6c1d5969c425262fc0ac920d5 100644 (file)
@@ -354,6 +354,17 @@ local function fann_scores_filter(task)
   end
 end
 
+local function create_fann(n, nlayers)
+  local layers = {}
+  local div = 1.0
+  for i in 1,nlayers - 1 do
+    table.insert(layers, math.floor(n / div))
+    div = div * 2
+  end
+  table.insert(layers, 1)
+  return rspamd_fann.create(nlayers, layers)
+end
+
 local function create_train_fann(rule, n, id)
   id = rule.prefix .. tostring(id)
   local prefix = gen_fann_prefix(rule, id)
@@ -368,18 +379,19 @@ local function create_train_fann(rule, n, id)
         'recreate ANN %s as it has a wrong number of inputs, version %s',
         prefix,
         fanns[id].version)
-      fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+
+      fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
     elseif fanns[id].version % rule.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
         fanns[id].version)
-      fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+      fanns[id].fann_train = create_fann(n, rule.nlayers)
     else
       fanns[id].fann_train = fanns[id].fann
     end
   else
-    fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
+    fanns[id].fann_train = create_fann(n, rule.nlayers)
     fanns[id].version = 0
   end
 end