]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Various fixes
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 7 Jul 2019 11:03:22 +0000 (12:03 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 7 Jul 2019 11:03:22 +0000 (12:03 +0100)
src/plugins/lua/neural.lua

index 750e5a359c94f7609838ca86c6285aab7b655c9a..d2c0191e7a8a40159a41dedc266606340510c93b 100644 (file)
@@ -76,10 +76,10 @@ local settings = {
   max_profiles = 3, -- Maximum number of NN profiles stored
 }
 
-local opts = rspamd_config:get_all_opt("neural")
-if not opts then
+local module_config = rspamd_config:get_all_opt("neural")
+if not module_config then
   -- Legacy
-  opts = rspamd_config:get_all_opt("fann_redis")
+  module_config = rspamd_config:get_all_opt("fann_redis")
 end
 
 
@@ -941,7 +941,7 @@ local function load_ann_profile(element)
 end
 
 -- Function to check or load ANNs from Redis
-local function check_anns(worker, rule, cfg, ev_base, process_callback)
+local function check_anns(worker, cfg, ev_base, rule, process_callback)
   for _,set in pairs(rule.settings) do
     local function members_cb(err, data)
       if err then
@@ -1039,10 +1039,10 @@ local function process_rules_settings()
             rule.prefix, selt.name), {persistent = true})
   end
 
-  for _,rule in pairs(opts.rules) do
+  for _,rule in pairs(settings.rules) do
     if not rule.allowed_settings then
       -- Extract all settings ids
-      rule.allowed_settings = lua_util.keys(lua_settings.all_settings)
+      rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
     end
 
     -- Convert to a map <setting_id> -> true
@@ -1057,7 +1057,7 @@ local function process_rules_settings()
 
     if rule.default then
       local default_settings = {
-        symbols = lua_util.keys(lua_settings.default_symbols),
+        symbols = lua_util.keys(lua_settings.default_symbols()),
         name = 'default'
       }
 
@@ -1080,6 +1080,8 @@ local function process_rules_settings()
       for id,ex in pairs(rule.settings) do
         if lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
           -- Equal symbols, add reference
+          lua_util.debugm(N, rspamd_config, 'added reference from settings id %s to %s; same symbols',
+              nelt.name, ex.name)
           rule.settings[s] = id
           nelt = nil
         end
@@ -1087,6 +1089,8 @@ local function process_rules_settings()
 
       if nelt then
         rule.settings[s] = nelt
+        lua_util.debugm(N, rspamd_config, 'added new settings id %s to %s; same symbols',
+            nelt.name, rule.prefix)
       end
     end
   end
@@ -1099,18 +1103,18 @@ if not redis_params then
 end
 
 -- Initialization part
-if not (opts and type(opts) == 'table') or not redis_params then
+if not (module_config and type(module_config) == 'table') or not redis_params then
   rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
   lua_util.disable_module(N, "redis")
   return
 end
 
-local rules = opts['rules']
+local rules = module_config['rules']
 
 if not rules then
   -- Use legacy configuration
   rules = {}
-  rules['default'] = opts
+  rules['default'] = module_config
 end
 
 local id = rspamd_config:register_symbol({
@@ -1121,42 +1125,44 @@ local id = rspamd_config:register_symbol({
 })
 
 for k,r in pairs(rules) do
-  local def_rules = lua_util.override_defaults(default_options, r)
-  def_rules['redis'] = redis_params
-  def_rules['anns'] = {} -- Store ANNs here
+  local rule_elt = lua_util.override_defaults(default_options, r)
+  rule_elt['redis'] = redis_params
+  rule_elt['anns'] = {} -- Store ANNs here
 
-  if not def_rules.prefix then
-    def_rules.prefix = k
+  if not rule_elt.prefix then
+    rule_elt.prefix = k
   end
-  if not def_rules.name then
-    def_rules.name = k
+  if not rule_elt.name then
+    rule_elt.name = k
   end
-  if def_rules.train.max_train then
-    def_rules.train.max_trains = def_rules.train.max_train
+  if rule_elt.train.max_train then
+    rule_elt.train.max_trains = rule_elt.train.max_train
   end
 
+  if not rule_elt.profile then rule_elt.profile = {} end
+
   rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
-  settings.rules[k] = def_rules
+  settings.rules[k] = rule_elt
   rspamd_config:set_metric_symbol({
-    name = def_rules.symbol_spam,
+    name = rule_elt.symbol_spam,
     score = 0.0,
     description = 'Neural network SPAM',
     group = 'neural'
   })
   rspamd_config:register_symbol({
-    name = def_rules.symbol_spam,
+    name = rule_elt.symbol_spam,
     type = 'virtual,nostat',
     parent = id
   })
 
   rspamd_config:set_metric_symbol({
-    name = def_rules.symbol_ham,
+    name = rule_elt.symbol_ham,
     score = -0.0,
     description = 'Neural network HAM',
     group = 'neural'
   })
   rspamd_config:register_symbol({
-    name = def_rules.symbol_ham,
+    name = rule_elt.symbol_ham,
     type = 'virtual,nostat',
     parent = id
   })