]> source.dussan.org Git - rspamd.git/commitdiff
[CritFix] Fix multiple neural networks support
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 21 May 2018 11:57:34 +0000 (12:57 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 21 May 2018 11:57:34 +0000 (12:57 +0100)
Issue: #2252

src/plugins/lua/neural.lua

index 04b732472dea32aa43fb065f658f263a3c3504dc..c8a6f1173a806e05261192d7fe5223a5cbb4f9ec 100644 (file)
@@ -65,10 +65,6 @@ local settings = {
   rules = {}
 }
 
--- ANNs indexed by settings id
-local anns = {
-}
-
 local opts = rspamd_config:get_all_opt("neural")
 if not opts then
   -- Legacy
@@ -278,7 +274,7 @@ local function ann_scores_filter(task)
       id = id .. r
     end
 
-    if anns[id] and anns[id].ann then
+    if rule.anns[id] and rule.anns[id].ann then
       local ann_data = task:get_symbols_tokens()
       local mt = meta_functions.rspamd_gen_metatokens(task)
       -- Add filtered meta tokens
@@ -286,10 +282,10 @@ local function ann_scores_filter(task)
 
       local score
       if use_torch then
-        local out = anns[id].ann:forward(torch.Tensor(ann_data))
+        local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
         score = out[1]
       else
-        local out = anns[id].ann:test(ann_data)
+        local out = rule.anns[id].ann:test(ann_data)
         score = out[1]
       end
 
@@ -339,28 +335,29 @@ end
 
 local function create_train_ann(rule, n, id)
   local prefix = gen_ann_prefix(rule, id)
-  if not anns[id] then
-    anns[id] = {}
+  if not rule.anns[id] then
+    rule.anns[id] = {}
   end
   -- Fix that for flexibe layers number
-  if anns[id].ann then
-    if not is_ann_valid(rule, prefix, anns[id].ann) then
-      anns[id].ann_train = create_ann(n, rule.nlayers)
-      anns[id].ann = nil
+  if rule.anns[id].ann then
+    if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
+      rule.anns[id].ann_train = create_ann(n, rule.nlayers)
+      rule.anns[id].ann = nil
       rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
-    elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then
+    elseif rule.train.max_usages > 0 and
+        rule.anns[id].version % rule.train.max_usages == 0 then
       -- Forget last ann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
-        anns[id].version)
-      anns[id].ann_train = create_ann(n, rule.nlayers)
+          rule.anns[id].version)
+      rule.anns[id].ann_train = create_ann(n, rule.nlayers)
     else
-      anns[id].ann_train = anns[id].ann
+      rule.anns[id].ann_train = rule.anns[id].ann
       rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
     end
   else
-    anns[id].ann_train = create_ann(n, rule.nlayers)
+    rule.anns[id].ann_train = create_ann(n, rule.nlayers)
     rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
-    anns[id].version = 0
+    rule.anns[id].version = 0
   end
 end
 
@@ -388,18 +385,18 @@ local function load_or_invalidate_ann(rule, data, id, ev_base)
   end
 
   if is_ann_valid(rule, prefix, ann) then
-    if not anns[id] then anns[id] = {} end
-    anns[id].ann = ann
+    if not rule.anns[id] then rule.anns[id] = {} end
+    rule.anns[id].ann = ann
     rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
       prefix, ver)
-    anns[id].version = tonumber(ver)
+    rule.anns[id].version = tonumber(ver)
   else
     local function redis_invalidate_cb(_err, _data)
       if _err then
         rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
       elseif type(_data) == 'string' then
         rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-        anns[id].version = 0
+        rule.anns[id].version = 0
       end
     end
     -- Invalidate ANN
@@ -553,15 +550,15 @@ local function train_ann(rule, _, ev_base, elt, worker)
       local ann_data
       if use_torch then
         local f = torch.MemoryFile()
-        f:writeObject(anns[elt].ann_train)
+        f:writeObject(rule.anns[elt].ann_train)
         ann_data = rspamd_util.zstd_compress(f:storage():string())
       else
-        ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
+        ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
       end
 
-      anns[elt].version = anns[elt].version + 1
-      anns[elt].ann = anns[elt].ann_train
-      anns[elt].ann_train = nil
+      rule.anns[elt].version = rule.anns[elt].version + 1
+      rule.anns[elt].ann = rule.anns[elt].ann_train
+      rule.anns[elt].ann_train = nil
       lua_redis.exec_redis_script(redis_save_unlock_id,
         {ev_base = ev_base, is_write = true},
         redis_save_cb,
@@ -589,11 +586,11 @@ local function train_ann(rule, _, ev_base, elt, worker)
       local ann_data
       local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
       ann_data = rspamd_util.zstd_compress(f:storage():string())
-      anns[elt].ann_train = f:readObject()
+      rule.anns[elt].ann_train = f:readObject()
 
-      anns[elt].version = anns[elt].version + 1
-      anns[elt].ann = anns[elt].ann_train
-      anns[elt].ann_train = nil
+      rule.anns[elt].version = rule.anns[elt].version + 1
+      rule.anns[elt].ann = rule.anns[elt].ann_train
+      rule.anns[elt].ann_train = nil
       lua_redis.exec_redis_script(redis_save_unlock_id,
         {ev_base = ev_base, is_write = true},
         redis_save_cb,
@@ -629,7 +626,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
       end
 
       -- Now we can train ann
-      if not anns[elt] or not anns[elt].ann_train then
+      if not rule.anns[elt] or not rule.anns[elt].ann_train then
         -- Create ann if it does not exist
         create_train_ann(rule, n, elt)
       end
@@ -641,7 +638,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
             rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
           elseif type(_data) == 'string' then
             rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
-            anns[elt].version = 0
+            rule.anns[elt].version = 0
           end
         end
         -- Invalidate ANN
@@ -668,7 +665,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
               torch.setnumthreads(rule.train.learn_threads)
             end
             local criterion = nn.MSECriterion()
-            local trainer = nn.StochasticGradient(anns[elt].ann_train,
+            local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
               criterion)
             trainer.learning_rate = rule.train.learning_rate
             trainer.verbose = false
@@ -680,7 +677,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
 
             trainer:train(dataset)
             local out = torch.MemoryFile()
-            out:writeObject(anns[elt].ann_train)
+            out:writeObject(rule.anns[elt].ann_train)
             local st = out:storage():string()
             return st
           end
@@ -701,7 +698,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
           end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
           rule.learning_spawned = true
           rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-          anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
+          rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
             ev_base, {
               max_epochs = rule.train.max_epoch,
               desired_mse = rule.train.mse
@@ -880,9 +877,9 @@ local function check_anns(rule, _, ev_base)
         end
 
         local local_ver = 0
-        if anns[elt] then
-          if anns[elt].version then
-            local_ver = anns[elt].version
+        if rule.anns[elt] then
+          if rule.anns[elt].version then
+            local_ver = rule.anns[elt].version
           end
         end
         lua_redis.exec_redis_script(redis_maybe_load_id,
@@ -963,6 +960,7 @@ else
   for k,r in pairs(rules) do
     local def_rules = lua_util.override_defaults(default_options, r)
     def_rules['redis'] = redis_params
+    def_rules['anns'] = {} -- Store ANNs here
 
     if not def_rules.prefix then
       def_rules.prefix = k