diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-09 17:05:03 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-09 17:05:49 +0000 |
commit | 586560e919220191087eebb31938ef69c72a1223 (patch) | |
tree | ecc1997c1b3f2b8c77fc5827707941ad029d23bd /lualib/rspamadm | |
parent | 64d4efdb4d040112eec10691bd197ceece33923f (diff) | |
download | rspamd-586560e919220191087eebb31938ef69c72a1223.tar.gz rspamd-586560e919220191087eebb31938ef69c72a1223.zip |
[Feature] Implement l1/l2 regularization against the current weights
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index fb1428694..c8348caa3 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -210,7 +210,7 @@ end -- training function local function train(dataset, opt, model, criterion, epoch, - all_symbols, spam_threshold) + all_symbols, spam_threshold, initial_weights) -- epoch tracker epoch = epoch or 1 @@ -270,16 +270,18 @@ local function train(dataset, opt, model, criterion, epoch, -- penalties (L1 and L2): local l1 = tonumber(opt.l1) or 0 local l2 = tonumber(opt.l2) or 0 + if l1 ~= 0 or l2 ~= 0 then -- locals: local norm,sign= torch.norm,torch.sign + local diff = parameters - initial_weights -- Loss: - f = f + l1 * norm(parameters,1) - f = f + l2 * norm(parameters,2)^2/2 + f = f + l1 * norm(diff,1) + f = f + l2 * norm(diff,2)^2/2 -- Gradients: - gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) ) + gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) ) end -- update confusion @@ -492,10 +494,12 @@ return function (args, cfg) for _,lr in ipairs(learning_rates) do for _,wd in ipairs(penalty_weights) do linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) + local initial_weights = linear_module.weight[1]:clone() opts.learning_rate = lr opts.weight_decay = wd for i=1,tonumber(opts.iters) do - train(dataset, opts, perceptron, criterion, i, all_symbols, threshold) + train(dataset, opts, perceptron, criterion, i, all_symbols, threshold, + initial_weights) end local fscore = calculate_fscore_from_weights(cv_logs, |