diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-06 17:03:49 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-06 17:03:49 +0000 |
commit | 5e9608de5af030ff5ea0ff2303b2f6263c0749e2 (patch) | |
tree | e2ac7ecb87d5cb9f168d55b0514db86ad77a51c3 /lualib/rspamadm | |
parent | cf527ffe3e4e98e9f4c03ae70f384b292fa5d656 (diff) | |
download | rspamd-5e9608de5af030ff5ea0ff2303b2f6263c0749e2.tar.gz rspamd-5e9608de5af030ff5ea0ff2303b2f6263c0749e2.zip |
[Rework] Rewrite model and learning logic for rescore
- Add more optimization methods
- Implement l1/l2 regulation
- Improve usability
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 210 |
1 files changed, 181 insertions, 29 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index 5ef1c4267..448c36257 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -4,6 +4,8 @@ local lua_util = require "lua_util" local ucl = require "ucl" local logger = require "rspamd_logger" local getopt = require "rspamadm/getopt" +local optim = require "optim" +local rspamd_util = require "rspamd_util" local rescore_utility = require "rspamadm/rescore_utility" @@ -60,13 +62,10 @@ local function init_weights(all_symbols, original_symbol_scores) local weights = torch.Tensor(#all_symbols) - local mean = 0 - for i, symbol in pairs(all_symbols) do local score = original_symbol_scores[symbol] if not score then score = 0 end weights[i] = score - mean = mean + score end return weights @@ -209,17 +208,175 @@ Overall accuracy: %.2f %% end +-- training function +local function train(dataset, opt, model, criterion, epoch, + all_symbols) + -- epoch tracker + epoch = epoch or 1 + + -- local vars + local time = rspamd_util.get_ticks() + local confusion = optim.ConfusionMatrix({'ham', 'spam'}) + + -- do one epoch + + local lbfgsState + local sgdState + + local batch_size = opt.batch_size + + logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size) + + for t = 1,dataset:size(),batch_size do + -- create mini batch + local k = 1 + local last = math.min(t + batch_size - 1, dataset:size()) + local inputs = torch.Tensor(last - t + 1, #all_symbols) + local targets = torch.Tensor(last - t + 1) + for i = t,last do + -- load new sample + local sample = dataset[i] + local input = sample[1]:clone() + local target = sample[2]:clone() + --target = target:squeeze() + inputs[k] = input + targets[k] = target + k = k + 1 + end + + local parameters,gradParameters = model:getParameters() + + -- create closure to evaluate f(X) and df/dX + local feval = function(x) + -- just in case: + collectgarbage() + + -- get new parameters + if x ~= parameters then + parameters:copy(x) + end + + -- reset gradients + gradParameters:zero() + + -- evaluate function for complete mini batch + local outputs = model:forward(inputs) + local f = criterion:forward(outputs, targets) + + -- estimate df/dW + local df_do = criterion:backward(outputs, targets) + model:backward(inputs, df_do) + + -- penalties (L1 and L2): + local l1 = tonumber(opt.l1) or 0 + local l2 = tonumber(opt.l1) or 0 + if l1 ~= 0 or l2 ~= 0 then + -- locals: + local norm,sign= torch.norm,torch.sign + + -- Loss: + f = f + l1 * norm(parameters,1) + f = f + l2 * norm(parameters,2)^2/2 + + -- Gradients: + gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) ) + end + + -- update confusion + for i = 1,(last - t + 1) do + local class_predicted = 0 + if outputs[i][1] > 0.5 then class_predicted = 1 end + confusion:add(class_predicted + 1, targets[i] + 1) + end + + -- return f and df/dX + return f,gradParameters + end + + -- optimize on current mini-batch + if opt.optimization == 'LBFGS' then + + -- Perform LBFGS step: + lbfgsState = lbfgsState or { + maxIter = opt.iters, + lineSearch = optim.lswolfe + } + optim.lbfgs(feval, parameters, lbfgsState) + + -- disp report: + logger.messagex('LBFGS step') + logger.messagex(' - progress in batch: ' .. t .. '/' .. dataset:size()) + logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter) + logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval) + + elseif opt.optimization == 'ADAM' then + sgdState = sgdState or { + learningRate = tonumber(opts.learning_rate),-- opt.learningRate, + momentum = tonumber(opts.momentum), -- opt.momentum, + learningRateDecay = tonumber(opts.learning_rate_decay), + weightDecay = tonumber(opts.weight_decay), + } + optim.adam(feval, parameters, sgdState) + elseif opt.optimization == 'ADAGRAD' then + sgdState = sgdState or { + learningRate = tonumber(opts.learning_rate),-- opt.learningRate, + momentum = tonumber(opts.momentum), -- opt.momentum, + learningRateDecay = tonumber(opts.learning_rate_decay), + weightDecay = tonumber(opts.weight_decay), + } + optim.adagrad(feval, parameters, sgdState) + elseif opt.optimization == 'SGD' then + sgdState = sgdState or { + learningRate = tonumber(opts.learning_rate),-- opt.learningRate, + momentum = tonumber(opts.momentum), -- opt.momentum, + learningRateDecay = tonumber(opts.learning_rate_decay), + weightDecay = tonumber(opts.weight_decay), + } + optim.sgd(feval, parameters, sgdState) + elseif opt.optimization == 'NAG' then + sgdState = sgdState or { + learningRate = tonumber(opts.learning_rate),-- opt.learningRate, + momentum = tonumber(opts.momentum), -- opt.momentum, + learningRateDecay = tonumber(opts.learning_rate_decay), + weightDecay = tonumber(opts.weight_decay), + } + optim.nag(feval, parameters, sgdState) + else + error('unknown optimization method') + end + end + + -- time taken + time = rspamd_util.get_ticks() - time + time = time / dataset:size() + logger.messagex("time to learn 1 sample = " .. (time*1000) .. 'ms') + + -- logger.messagex confusion matrix + logger.messagex('confusion: %s', tostring(confusion)) + logger.messagex('%s mean class accuracy (train set)', confusion.totalValid * 100) + confusion:zero() + + epoch = epoch + 1 +end + + local default_opts = { verbose = true, iters = 10, threads = 1, + batch_size = 1000, + optimization = 'ADAM', + learning_rate_decay = 0.001, + momentum = 0.0, + l1 = 0.0, + l2 = 0.0, } local learning_rates = { - 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10 + 0.01, 0.05, 0.1 } local penalty_weights = { - 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100 + 0, 0.001, 0.01, 0.1, 0.5 } local function override_defaults(def, override) @@ -317,13 +474,12 @@ return function (args, cfg) -- Start of perceptron training local input_size = #all_symbols torch.setnumthreads(opts['threads']) + local linear_module = nn.Linear(input_size, 1) + local activation = nn.Tanh() local perceptron = nn.Sequential() perceptron:add(linear_module) - - local activation = nn.Sigmoid() - perceptron:add(activation) local criterion = nn.MSECriterion() @@ -331,13 +487,17 @@ return function (args, cfg) local best_fscore = -math.huge local best_weights = linear_module.weight[1]:clone() + local best_learning_rate + local best_weight_decay - local trainer = nn.StochasticGradient(perceptron, criterion) - trainer.maxIteration = tonumber(opts["iters"]) - trainer.verbose = opts['verbose'] - trainer.hookIteration = function(self, iteration, error) - - if iteration == trainer.maxIteration then + for _,lr in ipairs(learning_rates) do + for _,wd in ipairs(penalty_weights) do + linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) + opts.learning_rate = lr + opts.weight_decay = wd + for i=1,tonumber(opts.iters) do + train(dataset, opts, perceptron, criterion, i, all_symbols) + end local fscore = calculate_fscore_from_weights(cv_logs, all_symbols, @@ -345,29 +505,18 @@ return function (args, cfg) linear_module.bias[1], threshold) - logger.messagex("Cross-validation fscore: %s", fscore) + logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s", + fscore, lr, wd) if best_fscore < fscore then + best_learning_rate = lr + best_weight_decay = wd best_fscore = fscore best_weights = linear_module.weight[1]:clone() end end end - for _, learning_rate in ipairs(learning_rates) do - for _, weight in ipairs(penalty_weights) do - - trainer.weightDecay = weight - logger.messagex("Learning with learning_rate: %s, l2_weight: %s", - learning_rate, weight) - - linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) - - trainer.learningRate = learning_rate - trainer:train(dataset) - end - end - -- End perceptron training local new_symbol_scores = best_weights @@ -392,4 +541,7 @@ return function (args, cfg) test_logs = update_logs(test_logs, new_symbol_scores) logger.message("\n\nPost-rescore test stats\n") print_stats(test_logs, threshold) + + logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s', + best_fscore, best_learning_rate, best_weight_decay) end
\ No newline at end of file |