From 586560e919220191087eebb31938ef69c72a1223 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 9 Mar 2018 17:05:03 +0000 Subject: [PATCH] [Feature] Implement l1/l2 regularization against the current weights --- lualib/rspamadm/rescore.lua | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index fb1428694..c8348caa3 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -210,7 +210,7 @@ end -- training function local function train(dataset, opt, model, criterion, epoch, - all_symbols, spam_threshold) + all_symbols, spam_threshold, initial_weights) -- epoch tracker epoch = epoch or 1 @@ -270,16 +270,18 @@ local function train(dataset, opt, model, criterion, epoch, -- penalties (L1 and L2): local l1 = tonumber(opt.l1) or 0 local l2 = tonumber(opt.l2) or 0 + if l1 ~= 0 or l2 ~= 0 then -- locals: local norm,sign= torch.norm,torch.sign + local diff = parameters - initial_weights -- Loss: - f = f + l1 * norm(parameters,1) - f = f + l2 * norm(parameters,2)^2/2 + f = f + l1 * norm(diff,1) + f = f + l2 * norm(diff,2)^2/2 -- Gradients: - gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) ) + gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) ) end -- update confusion @@ -492,10 +494,12 @@ return function (args, cfg) for _,lr in ipairs(learning_rates) do for _,wd in ipairs(penalty_weights) do linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) + local initial_weights = linear_module.weight[1]:clone() opts.learning_rate = lr opts.weight_decay = wd for i=1,tonumber(opts.iters) do - train(dataset, opts, perceptron, criterion, i, all_symbols, threshold) + train(dataset, opts, perceptron, criterion, i, all_symbols, threshold, + initial_weights) end local fscore = calculate_fscore_from_weights(cv_logs, -- 2.39.5