aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-09 17:05:03 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-09 17:05:49 +0000
commit586560e919220191087eebb31938ef69c72a1223 (patch)
treeecc1997c1b3f2b8c77fc5827707941ad029d23bd
parent64d4efdb4d040112eec10691bd197ceece33923f (diff)
downloadrspamd-586560e919220191087eebb31938ef69c72a1223.tar.gz
rspamd-586560e919220191087eebb31938ef69c72a1223.zip
[Feature] Implement l1/l2 regularization against the current weights
-rw-r--r--lualib/rspamadm/rescore.lua14
1 files changed, 9 insertions, 5 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index fb1428694..c8348caa3 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -210,7 +210,7 @@ end
-- training function
local function train(dataset, opt, model, criterion, epoch,
- all_symbols, spam_threshold)
+ all_symbols, spam_threshold, initial_weights)
-- epoch tracker
epoch = epoch or 1
@@ -270,16 +270,18 @@ local function train(dataset, opt, model, criterion, epoch,
-- penalties (L1 and L2):
local l1 = tonumber(opt.l1) or 0
local l2 = tonumber(opt.l2) or 0
+
if l1 ~= 0 or l2 ~= 0 then
-- locals:
local norm,sign= torch.norm,torch.sign
+ local diff = parameters - initial_weights
-- Loss:
- f = f + l1 * norm(parameters,1)
- f = f + l2 * norm(parameters,2)^2/2
+ f = f + l1 * norm(diff,1)
+ f = f + l2 * norm(diff,2)^2/2
-- Gradients:
- gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) )
+ gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) )
end
-- update confusion
@@ -492,10 +494,12 @@ return function (args, cfg)
for _,lr in ipairs(learning_rates) do
for _,wd in ipairs(penalty_weights) do
linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+ local initial_weights = linear_module.weight[1]:clone()
opts.learning_rate = lr
opts.weight_decay = wd
for i=1,tonumber(opts.iters) do
- train(dataset, opts, perceptron, criterion, i, all_symbols, threshold)
+ train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
+ initial_weights)
end
local fscore = calculate_fscore_from_weights(cv_logs,