author     Vsevolod Stakhov <vsevolod@highsecure.ru>    2018-03-06 17:03:49 +0000
committer  Vsevolod Stakhov <vsevolod@highsecure.ru>    2018-03-06 17:03:49 +0000
commit     5e9608de5af030ff5ea0ff2303b2f6263c0749e2 (patch)
tree       e2ac7ecb87d5cb9f168d55b0514db86ad77a51c3 /lualib
parent     cf527ffe3e4e98e9f4c03ae70f384b292fa5d656 (diff)
download   rspamd-5e9608de5af030ff5ea0ff2303b2f6263c0749e2.tar.gz
           rspamd-5e9608de5af030ff5ea0ff2303b2f6263c0749e2.zip
[Rework] Rewrite model and learning logic for rescore
- Add more optimization methods
- Implement l1/l2 regularization
- Improve usability
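
The l1/l2 regularization mentioned above is applied inside the training closure (the feval function in the diff below). As an illustrative sketch only, assuming the Torch 'torch' package and using a hypothetical helper name apply_penalties, the penalty terms and their gradients look like this:

-- illustrative sketch, not part of the commit: how l1/l2 penalties fold into
-- the loss value f and the gradient vector, assuming the Torch 'torch' package
local torch = require "torch"

local function apply_penalties(f, parameters, gradParameters, l1, l2)
  if l1 ~= 0 or l2 ~= 0 then
    -- loss terms: l1 * ||w||_1 and l2 * ||w||_2^2 / 2
    f = f + l1 * torch.norm(parameters, 1)
    f = f + l2 * torch.norm(parameters, 2) ^ 2 / 2
    -- gradient terms: l1 * sign(w) + l2 * w
    gradParameters:add(torch.sign(parameters):mul(l1) + parameters:clone():mul(l2))
  end
  return f
end

-- example: penalise a random 10-dimensional weight vector
local w = torch.randn(10)
local g = torch.zeros(10)
print(apply_penalties(0.0, w, g, 0.01, 0.1))

The same two updates (loss and gradient) appear in the feval closure added by this commit, right before the optimizer step.
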
Diffstat (limited to 'lualib')
-rw-r--r--  lualib/rspamadm/rescore.lua  210
1 file changed, 181 insertions(+), 29 deletions(-)
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index 5ef1c4267..448c36257 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -4,6 +4,8 @@ local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
local getopt = require "rspamadm/getopt"
+local optim = require "optim"
+local rspamd_util = require "rspamd_util"
local rescore_utility = require "rspamadm/rescore_utility"
@@ -60,13 +62,10 @@ local function init_weights(all_symbols, original_symbol_scores)
local weights = torch.Tensor(#all_symbols)
- local mean = 0
-
for i, symbol in pairs(all_symbols) do
local score = original_symbol_scores[symbol]
if not score then score = 0 end
weights[i] = score
- mean = mean + score
end
return weights
@@ -209,17 +208,175 @@ Overall accuracy: %.2f %%
end
+-- training function
+local function train(dataset, opt, model, criterion, epoch,
+ all_symbols)
+ -- epoch tracker
+ epoch = epoch or 1
+
+ -- local vars
+ local time = rspamd_util.get_ticks()
+ local confusion = optim.ConfusionMatrix({'ham', 'spam'})
+
+ -- do one epoch
+
+ local lbfgsState
+ local sgdState
+
+ local batch_size = opt.batch_size
+
+ logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
+
+ for t = 1,dataset:size(),batch_size do
+ -- create mini batch
+ local k = 1
+ local last = math.min(t + batch_size - 1, dataset:size())
+ local inputs = torch.Tensor(last - t + 1, #all_symbols)
+ local targets = torch.Tensor(last - t + 1)
+ for i = t,last do
+ -- load new sample
+ local sample = dataset[i]
+ local input = sample[1]:clone()
+ local target = sample[2]:clone()
+ --target = target:squeeze()
+ inputs[k] = input
+ targets[k] = target
+ k = k + 1
+ end
+
+ local parameters,gradParameters = model:getParameters()
+
+ -- create closure to evaluate f(X) and df/dX
+ local feval = function(x)
+ -- just in case:
+ collectgarbage()
+
+ -- get new parameters
+ if x ~= parameters then
+ parameters:copy(x)
+ end
+
+ -- reset gradients
+ gradParameters:zero()
+
+ -- evaluate function for complete mini batch
+ local outputs = model:forward(inputs)
+ local f = criterion:forward(outputs, targets)
+
+ -- estimate df/dW
+ local df_do = criterion:backward(outputs, targets)
+ model:backward(inputs, df_do)
+
+ -- penalties (L1 and L2):
+ local l1 = tonumber(opt.l1) or 0
+ local l2 = tonumber(opt.l2) or 0
+ if l1 ~= 0 or l2 ~= 0 then
+ -- locals:
+ local norm,sign= torch.norm,torch.sign
+
+ -- Loss:
+ f = f + l1 * norm(parameters,1)
+ f = f + l2 * norm(parameters,2)^2/2
+
+ -- Gradients:
+ gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) )
+ end
+
+ -- update confusion
+ for i = 1,(last - t + 1) do
+ local class_predicted = 0
+ if outputs[i][1] > 0.5 then class_predicted = 1 end
+ confusion:add(class_predicted + 1, targets[i] + 1)
+ end
+
+ -- return f and df/dX
+ return f,gradParameters
+ end
+
+ -- optimize on current mini-batch
+ if opt.optimization == 'LBFGS' then
+
+ -- Perform LBFGS step:
+ lbfgsState = lbfgsState or {
+ maxIter = opt.iters,
+ lineSearch = optim.lswolfe
+ }
+ optim.lbfgs(feval, parameters, lbfgsState)
+
+ -- disp report:
+ logger.messagex('LBFGS step')
+ logger.messagex(' - progress in batch: ' .. t .. '/' .. dataset:size())
+ logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
+ logger.messagex(' - nb of function evaluations: ' .. lbfgsState.funcEval)
+
+ elseif opt.optimization == 'ADAM' then
+ sgdState = sgdState or {
+ learningRate = tonumber(opt.learning_rate),
+ momentum = tonumber(opt.momentum),
+ learningRateDecay = tonumber(opt.learning_rate_decay),
+ weightDecay = tonumber(opt.weight_decay),
+ }
+ optim.adam(feval, parameters, sgdState)
+ elseif opt.optimization == 'ADAGRAD' then
+ sgdState = sgdState or {
+ learningRate = tonumber(opt.learning_rate),
+ momentum = tonumber(opt.momentum),
+ learningRateDecay = tonumber(opt.learning_rate_decay),
+ weightDecay = tonumber(opt.weight_decay),
+ }
+ optim.adagrad(feval, parameters, sgdState)
+ elseif opt.optimization == 'SGD' then
+ sgdState = sgdState or {
+ learningRate = tonumber(opt.learning_rate),
+ momentum = tonumber(opt.momentum),
+ learningRateDecay = tonumber(opt.learning_rate_decay),
+ weightDecay = tonumber(opt.weight_decay),
+ }
+ optim.sgd(feval, parameters, sgdState)
+ elseif opt.optimization == 'NAG' then
+ sgdState = sgdState or {
+ learningRate = tonumber(opt.learning_rate),
+ momentum = tonumber(opt.momentum),
+ learningRateDecay = tonumber(opt.learning_rate_decay),
+ weightDecay = tonumber(opt.weight_decay),
+ }
+ optim.nag(feval, parameters, sgdState)
+ else
+ error('unknown optimization method')
+ end
+ end
+
+ -- time taken
+ time = rspamd_util.get_ticks() - time
+ time = time / dataset:size()
+ logger.messagex("time to learn 1 sample = " .. (time*1000) .. 'ms')
+
+ -- log the confusion matrix
+ logger.messagex('confusion: %s', tostring(confusion))
+ logger.messagex('%s mean class accuracy (train set)', confusion.totalValid * 100)
+ confusion:zero()
+
+ epoch = epoch + 1
+end
+
+
local default_opts = {
verbose = true,
iters = 10,
threads = 1,
+ batch_size = 1000,
+ optimization = 'ADAM',
+ learning_rate_decay = 0.001,
+ momentum = 0.0,
+ l1 = 0.0,
+ l2 = 0.0,
}
local learning_rates = {
- 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+ 0.01, 0.05, 0.1
}
local penalty_weights = {
- 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+ 0, 0.001, 0.01, 0.1, 0.5
}
local function override_defaults(def, override)
@@ -317,13 +474,12 @@ return function (args, cfg)
-- Start of perceptron training
local input_size = #all_symbols
torch.setnumthreads(opts['threads'])
+
local linear_module = nn.Linear(input_size, 1)
+ local activation = nn.Tanh()
local perceptron = nn.Sequential()
perceptron:add(linear_module)
-
- local activation = nn.Sigmoid()
-
perceptron:add(activation)
local criterion = nn.MSECriterion()
@@ -331,13 +487,17 @@ return function (args, cfg)
local best_fscore = -math.huge
local best_weights = linear_module.weight[1]:clone()
+ local best_learning_rate
+ local best_weight_decay
- local trainer = nn.StochasticGradient(perceptron, criterion)
- trainer.maxIteration = tonumber(opts["iters"])
- trainer.verbose = opts['verbose']
- trainer.hookIteration = function(self, iteration, error)
-
- if iteration == trainer.maxIteration then
+ for _,lr in ipairs(learning_rates) do
+ for _,wd in ipairs(penalty_weights) do
+ linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+ opts.learning_rate = lr
+ opts.weight_decay = wd
+ for i=1,tonumber(opts.iters) do
+ train(dataset, opts, perceptron, criterion, i, all_symbols)
+ end
local fscore = calculate_fscore_from_weights(cv_logs,
all_symbols,
@@ -345,29 +505,18 @@ return function (args, cfg)
linear_module.bias[1],
threshold)
- logger.messagex("Cross-validation fscore: %s", fscore)
+ logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",
+ fscore, lr, wd)
if best_fscore < fscore then
+ best_learning_rate = lr
+ best_weight_decay = wd
best_fscore = fscore
best_weights = linear_module.weight[1]:clone()
end
end
end
- for _, learning_rate in ipairs(learning_rates) do
- for _, weight in ipairs(penalty_weights) do
-
- trainer.weightDecay = weight
- logger.messagex("Learning with learning_rate: %s, l2_weight: %s",
- learning_rate, weight)
-
- linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
-
- trainer.learningRate = learning_rate
- trainer:train(dataset)
- end
- end
-
-- End perceptron training
local new_symbol_scores = best_weights
@@ -392,4 +541,7 @@ return function (args, cfg)
test_logs = update_logs(test_logs, new_symbol_scores)
logger.message("\n\nPost-rescore test stats\n")
print_stats(test_logs, threshold)
+
+ logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
+ best_fscore, best_learning_rate, best_weight_decay)
end
\ No newline at end of file
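
For context only, a self-contained sketch of the hyper-parameter sweep the rewritten trainer performs above: it iterates over the learning_rates and penalty_weights grids, trains for opts.iters epochs per pair, and keeps the weights with the best cross-validation F-score. The evaluate() stub below is hypothetical and stands in for the real train() plus calculate_fscore_from_weights() calls.

-- hedged sketch, not part of the commit: shape of the grid search above,
-- with a placeholder evaluate() instead of training and F-score computation
local learning_rates = { 0.01, 0.05, 0.1 }
local penalty_weights = { 0, 0.001, 0.01, 0.1, 0.5 }

local function evaluate(lr, wd)
  -- placeholder score; the real code trains the perceptron and computes
  -- the cross-validation F-score for this (lr, wd) pair
  return 1.0 / (1.0 + math.abs(lr - 0.05) + wd)
end

local best_fscore, best_lr, best_wd = -math.huge, nil, nil
for _, lr in ipairs(learning_rates) do
  for _, wd in ipairs(penalty_weights) do
    local fscore = evaluate(lr, wd)
    if fscore > best_fscore then
      best_fscore, best_lr, best_wd = fscore, lr, wd
    end
  end
end
print(string.format("best fscore=%.3f, learning rate=%s, weight decay=%s",
  best_fscore, best_lr, best_wd))
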