aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-03-28 15:53:17 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-03-28 15:53:17 +0000
commita3b2c0f9db42a0b6d4d68d48654367e5b17b892a (patch)
treeb56dc2372523861f7615b509a3413de1c127a26b
parent781f3162dda790d0d4d00431204ef49b326cd699 (diff)
downloadrspamd-a3b2c0f9db42a0b6d4d68d48654367e5b17b892a.tar.gz
rspamd-a3b2c0f9db42a0b6d4d68d48654367e5b17b892a.zip
[Minor] Fix stupid torch that uses `print` for logging
-rw-r--r--contrib/lua-torch/nn/StochasticGradient.lua11
-rw-r--r--src/plugins/lua/neural.lua6
2 files changed, 11 insertions, 6 deletions
diff --git a/contrib/lua-torch/nn/StochasticGradient.lua b/contrib/lua-torch/nn/StochasticGradient.lua
index a060371e8..dc80be1b1 100644
--- a/contrib/lua-torch/nn/StochasticGradient.lua
+++ b/contrib/lua-torch/nn/StochasticGradient.lua
@@ -8,6 +8,9 @@ function StochasticGradient:__init(module, criterion)
self.module = module
self.criterion = criterion
self.verbose = true
+ self.logger = function(s)
+ print(s)
+ end
end
function StochasticGradient:train(dataset)
@@ -23,7 +26,7 @@ function StochasticGradient:train(dataset)
end
end
- print("# StochasticGradient: training")
+ self.logger("# StochasticGradient: training")
while true do
local currentError = 0
@@ -49,13 +52,13 @@ function StochasticGradient:train(dataset)
end
if self.verbose then
- print("# current error = " .. currentError)
+ self.logger("# current error = " .. currentError)
end
iteration = iteration + 1
currentLearningRate = self.learningRate/(1+iteration*self.learningRateDecay)
if self.maxIteration > 0 and iteration > self.maxIteration then
- print("# StochasticGradient: you have reached the maximum number of iterations")
- print("# training error = " .. currentError)
+ self.logger("# StochasticGradient: you have reached the maximum number of iterations")
+ self.logger("# training error = " .. currentError)
break
end
end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index b75adf468..30c4fee0f 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -671,11 +671,13 @@ local function train_ann(rule, _, ev_base, elt, worker)
trainer.learning_rate = rule.train.learning_rate
trainer.verbose = false
trainer.maxIteration = rule.train.max_iterations
- trainer.hookIteration = function(self, iteration, currentError)
+ trainer.hookIteration = function(_, iteration, currentError)
rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
iteration, currentError)
end
-
+ trainer.logger = function(s)
+ rspamd_logger.infox(rspamd_config, 'training: %s', s)
+ end
trainer:train(dataset)
local out = torch.MemoryFile()
out:writeObject(rule.anns[elt].ann_train)