]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Rename routines in neural plugin
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 8 Mar 2018 12:34:20 +0000 (12:34 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 8 Mar 2018 12:34:20 +0000 (12:34 +0000)
src/plugins/neural.lua

index 117881b31219be1331b91767dc783dd8fea0107d..b2c7adcfa797d2aeaa47026e229305e88ec112c2 100644 (file)
@@ -14,8 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ]]--
 
--- This plugin is a concept of FANN scores adjustment
--- NOT FOR PRODUCTION USE so far
 
 if confighelp then
   return
@@ -24,14 +22,14 @@ end
 local rspamd_logger = require "rspamd_logger"
 local rspamd_fann = require "rspamd_fann"
 local rspamd_util = require "rspamd_util"
-local rspamd_redis = require "lua_redis"
+local lua_redis = require "lua_redis"
 local lua_util = require "lua_util"
 local fun = require "fun"
 local meta_functions = require "lua_meta"
 local use_torch = false
 local torch
 local nn
-local N = "fann_redis"
+local N = "neural"
 
 if rspamd_config:has_torch() then
   use_torch = true
@@ -67,10 +65,14 @@ local settings = {
 }
 
 -- ANNs indexed by settings id
-local fanns = {
+local anns = {
 }
 
-local opts = rspamd_config:get_all_opt("fann_redis")
+local opts = rspamd_config:get_all_opt("neural")
+if not opts then
+  -- Legacy
+  opts = rspamd_config:get_all_opt("fann_redis")
+end
 
 
 -- Lua script to train a row
@@ -205,21 +207,21 @@ local redis_save_unlock_id = nil
 local redis_params
 
 local function load_scripts(params)
-  redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train,
+  redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train,
     params)
-  redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load,
+  redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load,
     params)
-  redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+  redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
     params)
-  redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate,
+  redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate,
     params)
-  redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock,
+  redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
     params)
-  redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock,
+  redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock,
     params)
 end
 
-local function gen_fann_prefix(rule, id)
+local function gen_ann_prefix(rule, id)
   local cksum = rspamd_config:get_symbols_cksum():hex()
   -- We also need to count metatokens:
   local n = meta_functions.rspamd_count_metatokens()
@@ -234,7 +236,7 @@ local function gen_fann_prefix(rule, id)
   end
 end
 
-local function is_fann_valid(rule, prefix, ann)
+local function is_ann_valid(rule, prefix, ann)
   if ann then
     local n = rspamd_config:get_symbols_count() +
         meta_functions.rspamd_count_metatokens()
@@ -260,7 +262,7 @@ local function is_fann_valid(rule, prefix, ann)
   end
 end
 
-local function fann_scores_filter(task)
+local function ann_scores_filter(task)
 
   for _,rule in pairs(settings.rules) do
     local id = '0'
@@ -275,23 +277,23 @@ local function fann_scores_filter(task)
       id = id .. r
     end
 
-    if fanns[id] and fanns[id].fann then
-      local fann_data = task:get_symbols_tokens()
+    if anns[id] and anns[id].ann then
+      local ann_data = task:get_symbols_tokens()
       local mt = meta_functions.rspamd_gen_metatokens(task)
       -- Add filtered meta tokens
-      fun.each(function(e) table.insert(fann_data, e) end, mt)
+      fun.each(function(e) table.insert(ann_data, e) end, mt)
 
       local score
       if use_torch then
-        local out = fanns[id].fann:forward(torch.Tensor(fann_data))
+        local out = anns[id].ann:forward(torch.Tensor(ann_data))
         score = out[1]
       else
-        local out = fanns[id].fann:test(fann_data)
+        local out = anns[id].ann:test(ann_data)
         score = out[1]
       end
 
       local symscore = string.format('%.3f', score)
-      rspamd_logger.infox(task, 'fann score: %s', symscore)
+      rspamd_logger.infox(task, 'ann score: %s', symscore)
 
       if score > 0 then
         local result = score
@@ -310,7 +312,7 @@ local function fann_scores_filter(task)
   end
 end
 
-local function create_fann(n, nlayers)
+local function create_ann(n, nlayers)
   if use_torch then
     -- We ignore number of layers so far when using torch
     local ann = nn.Sequential()
@@ -334,36 +336,36 @@ local function create_fann(n, nlayers)
   end
 end
 
-local function create_train_fann(rule, n, id)
-  local prefix = gen_fann_prefix(rule, id)
-  if not fanns[id] then
-    fanns[id] = {}
+local function create_train_ann(rule, n, id)
+  local prefix = gen_ann_prefix(rule, id)
+  if not anns[id] then
+    anns[id] = {}
   end
   -- Fix that for flexibe layers number
-  if fanns[id].fann then
-    if not is_fann_valid(rule, prefix, fanns[id].fann) then
-      fanns[id].fann_train = create_fann(n, rule.nlayers)
-      fanns[id].fann = nil
+  if anns[id].ann then
+    if not is_ann_valid(rule, prefix, anns[id].ann) then
+      anns[id].ann_train = create_ann(n, rule.nlayers)
+      anns[id].ann = nil
       rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
-    elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then
-      -- Forget last fann
+    elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then
+      -- Forget last ann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
-        fanns[id].version)
-      fanns[id].fann_train = create_fann(n, rule.nlayers)
+        anns[id].version)
+      anns[id].ann_train = create_ann(n, rule.nlayers)
     else
-      fanns[id].fann_train = fanns[id].fann
+      anns[id].ann_train = anns[id].ann
       rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
     end
   else
-    fanns[id].fann_train = create_fann(n, rule.nlayers)
+    anns[id].ann_train = create_ann(n, rule.nlayers)
     rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
-    fanns[id].version = 0
+    anns[id].version = 0
   end
 end
 
-local function load_or_invalidate_fann(rule, data, id, ev_base)
+local function load_or_invalidate_ann(rule, data, id, ev_base)
   local ver = data[2]
-  local prefix = gen_fann_prefix(rule, id)
+  local prefix = gen_ann_prefix(rule, id)
 
   if not ver or not tonumber(ver) then
     rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
@@ -384,33 +386,33 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
     end
   end
 
-  if is_fann_valid(rule, prefix, ann) then
-    if not fanns[id] then fanns[id] = {} end
-    fanns[id].fann = ann
+  if is_ann_valid(rule, prefix, ann) then
+    if not anns[id] then anns[id] = {} end
+    anns[id].ann = ann
     rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
       prefix, ver)
-    fanns[id].version = tonumber(ver)
+    anns[id].version = tonumber(ver)
   else
     local function redis_invalidate_cb(_err, _data)
       if _err then
         rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
       elseif type(_data) == 'string' then
         rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-        fanns[id].version = 0
+        anns[id].version = 0
       end
     end
     -- Invalidate ANN
     rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
-    rspamd_redis.exec_redis_script(redis_maybe_invalidate_id,
+    lua_redis.exec_redis_script(redis_maybe_invalidate_id,
       {ev_base = ev_base, is_write = true},
       redis_invalidate_cb,
       {prefix})
   end
 end
 
-local function fann_train_callback(rule, task, score, required_score, id)
+local function ann_train_callback(rule, task, score, required_score, id)
   local train_opts = rule['train']
-  local fname,suffix = gen_fann_prefix(rule, id)
+  local fname,suffix = gen_ann_prefix(rule, id)
 
   local learn_spam, learn_ham
 
@@ -460,16 +462,16 @@ local function fann_train_callback(rule, task, score, required_score, id)
           rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
           return
         end
-        local fann_data = task:get_symbols_tokens()
+        local ann_data = task:get_symbols_tokens()
         local mt = meta_functions.rspamd_gen_metatokens(task)
         -- Add filtered meta tokens
-        fun.each(function(e) table.insert(fann_data, e) end, mt)
+        fun.each(function(e) table.insert(ann_data, e) end, mt)
         -- Check NaNs in train data
-        if fun.all(function(e) return e == e end, fann_data) then
-          local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
+        if fun.all(function(e) return e == e end, ann_data) then
+          local str = rspamd_util.zstd_compress(table.concat(ann_data, ';'))
           vec_len = #str
 
-          rspamd_redis.redis_make_request(task,
+          lua_redis.redis_make_request(task,
             rule.redis,
             nil,
             true, -- is write
@@ -479,7 +481,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
           )
         else
           rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values",
-            fun.length(fun.filter(function(e) return e ~= e end, fann_data)))
+            fun.length(fun.filter(function(e) return e ~= e end, ann_data)))
         end
 
       else
@@ -492,18 +494,18 @@ local function fann_train_callback(rule, task, score, required_score, id)
       end
     end
 
-    rspamd_redis.exec_redis_script(redis_can_train_id,
+    lua_redis.exec_redis_script(redis_can_train_id,
       {task = task, is_write = true},
       can_train_cb,
-      {gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
+      {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
   end
 end
 
-local function train_fann(rule, _, ev_base, elt, worker)
+local function train_ann(rule, _, ev_base, elt, worker)
   local spam_elts = {}
   local ham_elts = {}
   elt = tostring(elt)
-  local prefix = gen_fann_prefix(rule, elt)
+  local prefix = gen_ann_prefix(rule, elt)
 
   local function redis_unlock_cb(err)
     if err then
@@ -516,7 +518,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s',
         prefix, err)
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -535,7 +537,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if errcode ~= 0 then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
         prefix, errmsg)
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -550,16 +552,16 @@ local function train_fann(rule, _, ev_base, elt, worker)
       local ann_data
       if use_torch then
         local f = torch.MemoryFile()
-        f:writeObject(fanns[elt].fann_train)
+        f:writeObject(anns[elt].ann_train)
         ann_data = rspamd_util.zstd_compress(f:storage():string())
       else
-        ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+        ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
       end
 
-      fanns[elt].version = fanns[elt].version + 1
-      fanns[elt].fann = fanns[elt].fann_train
-      fanns[elt].fann_train = nil
-      rspamd_redis.exec_redis_script(redis_save_unlock_id,
+      anns[elt].version = anns[elt].version + 1
+      anns[elt].ann = anns[elt].ann_train
+      anns[elt].ann_train = nil
+      lua_redis.exec_redis_script(redis_save_unlock_id,
         {ev_base = ev_base, is_write = true},
         redis_save_cb,
         {prefix, tostring(ann_data), tostring(rule.ann_expire)})
@@ -571,7 +573,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
         prefix, err)
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -586,12 +588,12 @@ local function train_fann(rule, _, ev_base, elt, worker)
       local ann_data
       local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
       ann_data = rspamd_util.zstd_compress(f:storage():string())
-      fanns[elt].fann_train = f:readObject()
+      anns[elt].ann_train = f:readObject()
 
-      fanns[elt].version = fanns[elt].version + 1
-      fanns[elt].fann = fanns[elt].fann_train
-      fanns[elt].fann_train = nil
-      rspamd_redis.exec_redis_script(redis_save_unlock_id,
+      anns[elt].version = anns[elt].version + 1
+      anns[elt].ann = anns[elt].ann_train
+      anns[elt].ann_train = nil
+      lua_redis.exec_redis_script(redis_save_unlock_id,
         {ev_base = ev_base, is_write = true},
         redis_save_cb,
         {prefix, tostring(ann_data), tostring(rule.ann_expire)})
@@ -602,7 +604,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
         prefix, err)
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -625,10 +627,10 @@ local function train_fann(rule, _, ev_base, elt, worker)
         return #elts == n
       end
 
-      -- Now we can train fann
-      if not fanns[elt] or not fanns[elt].fann_train then
-        -- Create fann if it does not exist
-        create_train_fann(rule, n, elt)
+      -- Now we can train ann
+      if not anns[elt] or not anns[elt].ann_train then
+        -- Create ann if it does not exist
+        create_train_ann(rule, n, elt)
       end
 
       if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
@@ -638,12 +640,12 @@ local function train_fann(rule, _, ev_base, elt, worker)
             rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
           elseif type(_data) == 'string' then
             rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-            fanns[elt].version = 0
+            anns[elt].version = 0
           end
         end
         -- Invalidate ANN
         rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
-        rspamd_redis.exec_redis_script(redis_locked_invalidate_id,
+        lua_redis.exec_redis_script(redis_locked_invalidate_id,
           {ev_base = ev_base, is_write = true},
           redis_invalidate_cb,
           {prefix})
@@ -665,7 +667,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
               torch.setnumthreads(rule.train.learn_threads)
             end
             local criterion = nn.MSECriterion()
-            local trainer = nn.StochasticGradient(fanns[elt].fann_train,
+            local trainer = nn.StochasticGradient(anns[elt].ann_train,
               criterion)
             trainer.learning_rate = 0.01
             trainer.verbose = false
@@ -677,7 +679,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
 
             trainer:train(dataset)
             local out = torch.MemoryFile()
-            out:writeObject(fanns[elt].fann_train)
+            out:writeObject(anns[elt].ann_train)
             local st = out:storage():string()
             return st
           end
@@ -698,7 +700,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
           end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
           rule.learning_spawned = true
           rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-          fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+          anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
             ev_base, {
               max_epochs = rule.train.max_epoch,
               desired_mse = rule.train.mse
@@ -713,7 +715,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
         prefix, err)
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -728,7 +730,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
         local _,str = rspamd_util.zstd_decompress(tok)
         return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';')))
       end, data))
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -746,7 +748,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
         prefix, err)
     elseif type(data) == 'number' then
       -- Can train ANN
-      rspamd_redis.redis_make_request_taskless(ev_base,
+      lua_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
         rule.redis,
         nil,
@@ -768,7 +770,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
             end
           end
           if rule.learning_spawned then
-            rspamd_redis.redis_make_request_taskless(ev_base,
+            lua_redis.redis_make_request_taskless(ev_base,
               rspamd_config,
               rule.redis,
               nil,
@@ -793,20 +795,20 @@ local function train_fann(rule, _, ev_base, elt, worker)
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
     return
   end
-  rspamd_redis.exec_redis_script(redis_maybe_lock_id,
+  lua_redis.exec_redis_script(redis_maybe_lock_id,
     {ev_base = ev_base, is_write = true},
     redis_lock_cb,
     {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()})
 end
 
-local function maybe_train_fanns(rule, cfg, ev_base, worker)
+local function maybe_train_anns(rule, cfg, ev_base, worker)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
       fun.each(function(elt)
         elt = tostring(elt)
-        local prefix = gen_fann_prefix(rule, elt)
+        local prefix = gen_ann_prefix(rule, elt)
         rspamd_logger.infox(cfg, "check ANN %s", prefix)
         local redis_len_cb = function(_err, _data)
           if _err then
@@ -817,7 +819,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
               rspamd_logger.infox(rspamd_config,
                 'need to learn ANN %s after %s learn vectors (%s required)',
                 prefix, tonumber(_data), rule.train.max_trains)
-              train_fann(rule, cfg, ev_base, elt, worker)
+              train_ann(rule, cfg, ev_base, elt, worker)
             else
               rspamd_logger.infox(rspamd_config,
                 'no need to learn ANN %s %s learn vectors (%s required)',
@@ -826,7 +828,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
           end
         end
 
-        rspamd_redis.redis_make_request_taskless(ev_base,
+        lua_redis.redis_make_request_taskless(ev_base,
           rspamd_config,
           rule.redis,
           nil,
@@ -840,21 +842,21 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
     end
   end
 
-  -- First we need to get all fanns stored in our Redis
-  rspamd_redis.redis_make_request_taskless(ev_base,
+  -- First we need to get all anns stored in our Redis
+  lua_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
     rule.redis,
     nil,
     false, -- is write
     members_cb, --callback
     'SMEMBERS', -- command
-    {gen_fann_prefix(rule, nil)} -- arguments
+    {gen_ann_prefix(rule, nil)} -- arguments
   )
 
   return rule.watch_interval
 end
 
-local function check_fanns(rule, _, ev_base)
+local function check_anns(rule, _, ev_base)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
@@ -867,7 +869,7 @@ local function check_fanns(rule, _, ev_base)
             rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
               elt, _err)
           elseif _data and type(_data) == 'table' then
-            load_or_invalidate_fann(rule, _data, elt, ev_base)
+            load_or_invalidate_ann(rule, _data, elt, ev_base)
           else
             if type(_data) ~= 'number' then
               rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
@@ -877,29 +879,29 @@ local function check_fanns(rule, _, ev_base)
         end
 
         local local_ver = 0
-        if fanns[elt] then
-          if fanns[elt].version then
-            local_ver = fanns[elt].version
+        if anns[elt] then
+          if anns[elt].version then
+            local_ver = anns[elt].version
           end
         end
-        rspamd_redis.exec_redis_script(redis_maybe_load_id,
+        lua_redis.exec_redis_script(redis_maybe_load_id,
           {ev_base = ev_base, is_write = false},
           redis_update_cb,
-          {gen_fann_prefix(rule, elt), tostring(local_ver)})
+          {gen_ann_prefix(rule, elt), tostring(local_ver)})
       end,
       data)
     end
   end
 
-  -- First we need to get all fanns stored in our Redis
-  rspamd_redis.redis_make_request_taskless(ev_base,
+  -- First we need to get all anns stored in our Redis
+  lua_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
     rule.redis,
     nil,
     false, -- is write
     members_cb, --callback
     'SMEMBERS', -- command
-    {gen_fann_prefix(rule, nil)} -- arguments
+    {gen_ann_prefix(rule, nil)} -- arguments
   )
 
   return rule.watch_interval
@@ -916,11 +918,15 @@ local function ann_push_vector(task)
       local r = task:get_principal_recipient()
       sid = sid .. r
     end
-    fann_train_callback(rule, task, scores[1], scores[2], sid)
+    ann_train_callback(rule, task, scores[1], scores[2], sid)
   end
 end
 
-redis_params = rspamd_parse_redis_server('fann_redis')
+redis_params = lua_redis.parse_redis_server('neural')
+
+if not redis_params then
+  redis_params = lua_redis.parse_redis_server('fann_redis')
+end
 
 -- Initialization part
 if not (opts and type(opts) == 'table') or not redis_params then
@@ -929,8 +935,8 @@ if not (opts and type(opts) == 'table') or not redis_params then
   return
 end
 
-if not rspamd_fann.is_enabled() then
-  rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
+if not rspamd_fann.is_enabled() and not use_torch then
+  rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' ..
     'module is eventually disabled')
   lua_util.disable_module(N, "fail")
   return
@@ -951,7 +957,7 @@ else
     name = 'FANN_CHECK',
     type = 'postfilter,nostat',
     priority = 6,
-    callback = fann_scores_filter
+    callback = ann_scores_filter
   })
 
   local function deepcopy(orig)
@@ -1002,7 +1008,7 @@ else
       name = def_rules.symbol_spam,
       score = 3.0,
       description = 'Neural network SPAM',
-      group = 'fann'
+      group = 'neural'
     })
     rspamd_config:register_symbol({
       name = def_rules.symbol_spam,
@@ -1014,7 +1020,7 @@ else
       name = def_rules.symbol_ham,
       score = -2.0,
       description = 'Neural network HAM',
-      group = 'fann'
+      group = 'neural'
     })
     rspamd_config:register_symbol({
       name = def_rules.symbol_ham,
@@ -1034,13 +1040,13 @@ else
   for _,rule in pairs(settings.rules) do
     load_scripts(rule.redis)
     rspamd_config:add_on_load(function(cfg, ev_base, worker)
-      check_fanns(rule, cfg, ev_base)
+      check_anns(rule, cfg, ev_base)
 
       if worker:is_primary_controller() then
         -- We also want to train neural nets when they have enough data
         rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)
-            return maybe_train_fanns(rule, cfg, ev_base, worker)
+            return maybe_train_anns(rule, cfg, ev_base, worker)
           end)
       end
     end)