]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Implement torch based ANN learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 3 Sep 2017 11:20:36 +0000 (12:20 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 3 Sep 2017 11:20:36 +0000 (12:20 +0100)
src/plugins/lua/fann_redis.lua

index b0cbdefab42c1ef7aa2d4161c582c60bf9e7b66a..5c444e74c8de595a923e235105f358e0741640cc 100644 (file)
@@ -579,7 +579,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
   end
 end
 
-local function train_fann(rule, _, ev_base, elt)
+local function train_fann(rule, _, ev_base, elt, worker)
   local spam_elts = {}
   local ham_elts = {}
   elt = tostring(elt)
@@ -652,6 +652,43 @@ local function train_fann(rule, _, ev_base, elt)
     end
   end
 
+  local function ann_trained_torch(err, data)
+    rule.learning_spawned = false
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+        prefix, err)
+      rspamd_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        redis_unlock_cb, --callback
+        'DEL', -- command
+        {prefix .. '_locked'}
+      )
+    else
+      rspamd_logger.infox(rspamd_config, 'trained ANN %s',
+        prefix)
+      local ann_data
+      local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
+      ann_data = rspamd_util.zstd_compress(f:storage():string())
+      fanns[elt].fann_train = f:readObject()
+
+      fanns[elt].version = fanns[elt].version + 1
+      fanns[elt].fann = fanns[elt].fann_train
+      fanns[elt].fann_train = nil
+      rspamd_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        redis_save_cb, --callback
+        'EVALSHA', -- command
+        {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+      )
+    end
+  end
+
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
@@ -673,29 +710,19 @@ local function train_fann(rule, _, ev_base, elt)
       end, data))
 
       -- Now we need to join inputs and create the appropriate test vectors
-      local inputs = {}
-      local outputs = {}
-
       local n = rspamd_config:get_symbols_count() +
           meta_functions.rspamd_count_metatokens()
       local filt = function(elts)
         return #elts == n
       end
 
-      fun.each(function(spam_sample, ham_sample)
-        table.insert(inputs, spam_sample)
-        table.insert(outputs, {1.0})
-        table.insert(inputs, ham_sample)
-        table.insert(outputs, {-1.0})
-      end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
-
       -- Now we can train fann
       if not fanns[elt] or not fanns[elt].fann_train then
         -- Create fann if it does not exist
         create_train_fann(rule, n, elt)
       end
 
-      if #inputs < rule.max_trains / 2 then
+      if #spam_elts + #ham_elts < rule.max_trains / 2 then
         -- Invalidate ANN as it is definitely invalid
         local function redis_invalidate_cb(_err, _data)
           if _err then
@@ -717,13 +744,58 @@ local function train_fann(rule, _, ev_base, elt)
           {redis_locked_invalidate_sha, 1, prefix}
         )
       else
-        rule.learning_spawned = true
-        rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-        fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
-          ev_base, {
-            max_epochs = rule.train.max_epoch,
-            desired_mse = rule.train.mse
-          })
+        if torch then
+          -- For torch we do not need to mix samples as they would be flushed
+          local dataset = {}
+          fun.each(function(s)
+            table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
+          end, spam_elts)
+          fun.each(function(s)
+            table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
+          end, ham_elts)
+          -- Needed for torch
+          dataset.size = function(tbl) return #tbl end
+
+          local function train_torch()
+            local criterion = nn.MSECriterion()
+            local trainer = nn.StochasticGradient(fanns[elt].fann_train,
+              criterion)
+            trainer.learning_rate = 0.01
+            trainer.hookIteration = function(self, iteration, currentError)
+              rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
+                  iteration, currentError)
+            end
+
+            trainer:train(dataset)
+            local out = torch.MemoryFile()
+            out:writeObject(fanns[elt].fann_train)
+            local st = out:storage():string()
+            return out
+          end
+
+          worker:spawn_process{
+            func = train_torch,
+            on_complete = ann_trained_torch,
+          }
+        else
+          local inputs = {}
+          local outputs = {}
+
+          fun.each(function(spam_sample, ham_sample)
+            table.insert(inputs, spam_sample)
+            table.insert(outputs, {1.0})
+            table.insert(inputs, ham_sample)
+            table.insert(outputs, {-1.0})
+          end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
+          rule.learning_spawned = true
+          rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
+          fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+            ev_base, {
+              max_epochs = rule.train.max_epoch,
+              desired_mse = rule.train.mse
+            })
+        end
+
       end
     end
   end
@@ -827,7 +899,7 @@ local function train_fann(rule, _, ev_base, elt)
   )
 end
 
-local function maybe_train_fanns(rule, cfg, ev_base)
+local function maybe_train_fanns(rule, cfg, ev_base, worker)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
@@ -844,7 +916,7 @@ local function maybe_train_fanns(rule, cfg, ev_base)
               rspamd_logger.infox(rspamd_config,
                 'need to learn ANN %s after %s learn vectors (%s required)',
                 prefix, tonumber(_data), rule.max_trains)
-              train_fann(rule, cfg, ev_base, elt)
+              train_fann(rule, cfg, ev_base, elt, worker)
             end
           end
         end
@@ -1032,7 +1104,7 @@ else
         -- We also want to train neural nets when they have enough data
         rspamd_config:add_periodic(ev_base, 0.0,
           function(_, _)
-            return maybe_train_fanns(rule, cfg, ev_base)
+            return maybe_train_fanns(rule, cfg, ev_base, worker)
           end)
       end
     end)