]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Neural: Add nan check and extensive logging
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 18 Oct 2019 16:18:26 +0000 (17:18 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 18 Oct 2019 16:18:26 +0000 (17:18 +0100)
src/plugins/lua/neural.lua

index 1ff1f40d71c159885371325177333bb3d6a5172a..e6ffe41bebb11bb10bed90be8b30aed9277e1b8e 100644 (file)
@@ -564,7 +564,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     local inputs, outputs = {}, {}
 
     -- Used to show sparsed vectors in a convenient format (for debugging only)
-    --[[
     local function debug_vec(t)
       local ret = {}
       for i,v in ipairs(t) do
@@ -575,7 +574,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
 
       return ret
     end
-    ]]--
 
     -- Make training set by joining vectors
     -- KANN automatically shuffles those samples
@@ -595,22 +593,44 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
     -- Called in child process
     local function train()
       local log_thresh = rule.train.max_iterations / 10
-      train_ann:train1(inputs, outputs, {
-        lr = rule.train.learning_rate,
-        max_epoch = rule.train.max_iterations,
-        cb = function(iter, train_cost, _)
-          if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
-            rspamd_logger.infox(rspamd_config,
-                "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
+      local seen_nan = false
+
+      local function train_cb(iter, train_cost, value_cost)
+        if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+          if train_cost ~= train_cost and not seen_nan then
+            -- We have nan :( try to log lot's of stuff to dig into a problem
+            seen_nan = true
+            rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
                 rule.prefix, set.name,
-                ann_key,
-                iter, train_cost)
+                value_cost)
+            for i,e in ipairs(inputs) do
+              lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+                  debug_vec(e), outputs[i][1])
+            end
           end
+
+          rspamd_logger.infox(rspamd_config,
+              "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+              rule.prefix, set.name,
+              ann_key,
+              iter,
+              train_cost,
+              value_cost)
         end
+      end
+
+      train_ann:train1(inputs, outputs, {
+        lr = rule.train.learning_rate,
+        max_epoch = rule.train.max_iterations,
+        cb = train_cb,
       })
 
-      local out = train_ann:save()
-      return out
+      if not seen_nan then
+        local out = train_ann:save()
+        return out
+      else
+        return nil
+      end
     end
 
     set.learning_spawned = true