source.dussan.org Git - rspamd.git/commitdiff
[Minor] Fix stupid torch that uses `print` for logging
author: Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 28 Mar 2019 15:53:17 +0000 (15:53 +0000)
committer: Vsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 28 Mar 2019 15:53:17 +0000 (15:53 +0000)
contrib/lua-torch/nn/StochasticGradient.lua
src/plugins/lua/neural.lua

index a060371e85e7fb7b5dd21a2d95ca0272587ced43..dc80be1b1566d06ba95accd8edf5df5ed4f978fe 100644 (file)
@@ -8,6 +8,9 @@ function StochasticGradient:__init(module, criterion)
    self.module = module
    self.criterion = criterion
    self.verbose = true
+   self.logger = function(s)
+      print(s)
+   end
 end
 
 function StochasticGradient:train(dataset)
@@ -23,7 +26,7 @@ function StochasticGradient:train(dataset)
       end
    end
 
-   print("# StochasticGradient: training")
+   self.logger("# StochasticGradient: training")
 
    while true do
       local currentError = 0
@@ -49,13 +52,13 @@ function StochasticGradient:train(dataset)
       end
 
       if self.verbose then
-         print("# current error = " .. currentError)
+         self.logger("# current error = " .. currentError)
       end
       iteration = iteration + 1
       currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay)
       if self.maxIteration > 0 and iteration > self.maxIteration then
-         print("# StochasticGradient: you have reached the maximum number of iterations")
-         print("# training error = " .. currentError)
+         self.logger("# StochasticGradient: you have reached the maximum number of iterations")
+         self.logger("# training error = " .. currentError)
          break
       end
    end
index b75adf4689d5ff6255feb337b8bdda4871252792..30c4fee0f8898bda98f8eaab5432728e81afb67d 100644 (file)
@@ -671,11 +671,13 @@ local function train_ann(rule, _, ev_base, elt, worker)
             trainer.learning_rate = rule.train.learning_rate
             trainer.verbose = false
             trainer.maxIteration = rule.train.max_iterations
-            trainer.hookIteration = function(self, iteration, currentError)
+            trainer.hookIteration = function(_, iteration, currentError)
               rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
                   iteration, currentError)
             end
-
+            trainer.logger = function(s)
+              rspamd_logger.infox(rspamd_config, 'training: %s', s)
+            end
             trainer:train(dataset)
             local out = torch.MemoryFile()
             out:writeObject(rule.anns[elt].ann_train)