From a3b2c0f9db42a0b6d4d68d48654367e5b17b892a Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 28 Mar 2019 15:53:17 +0000 Subject: [PATCH] [Minor] Fix stupid torch that uses `print` for logging --- contrib/lua-torch/nn/StochasticGradient.lua | 11 +++++++---- src/plugins/lua/neural.lua | 6 ++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/contrib/lua-torch/nn/StochasticGradient.lua b/contrib/lua-torch/nn/StochasticGradient.lua index a060371e8..dc80be1b1 100644 --- a/contrib/lua-torch/nn/StochasticGradient.lua +++ b/contrib/lua-torch/nn/StochasticGradient.lua @@ -8,6 +8,9 @@ function StochasticGradient:__init(module, criterion) self.module = module self.criterion = criterion self.verbose = true + self.logger = function(s) + print(s) + end end function StochasticGradient:train(dataset) @@ -23,7 +26,7 @@ function StochasticGradient:train(dataset) end end - print("# StochasticGradient: training") + self.logger("# StochasticGradient: training") while true do local currentError = 0 @@ -49,13 +52,13 @@ function StochasticGradient:train(dataset) end if self.verbose then - print("# current error = " .. currentError) + self.logger("# current error = " .. currentError) end iteration = iteration + 1 currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay) if self.maxIteration > 0 and iteration > self.maxIteration then - print("# StochasticGradient: you have reached the maximum number of iterations") - print("# training error = " .. currentError) + self.logger("# StochasticGradient: you have reached the maximum number of iterations") + self.logger("# training error = " .. currentError) break end end diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index b75adf468..30c4fee0f 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -671,11 +671,13 @@ local function train_ann(rule, _, ev_base, elt, worker) trainer.learning_rate = rule.train.learning_rate trainer.verbose = false trainer.maxIteration = rule.train.max_iterations - trainer.hookIteration = function(self, iteration, currentError) + trainer.hookIteration = function(_, iteration, currentError) rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s", iteration, currentError) end - + trainer.logger = function(s) + rspamd_logger.infox(rspamd_config, 'training: %s', s) + end trainer:train(dataset) local out = torch.MemoryFile() out:writeObject(rule.anns[elt].ann_train) -- 2.39.5