aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lualib/rspamadm/rescore.lua14
1 files changed, 9 insertions, 5 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index fb1428694..c8348caa3 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -210,7 +210,7 @@ end
-- training function
local function train(dataset, opt, model, criterion, epoch,
- all_symbols, spam_threshold)
+ all_symbols, spam_threshold, initial_weights)
-- epoch tracker
epoch = epoch or 1
@@ -270,16 +270,18 @@ local function train(dataset, opt, model, criterion, epoch,
-- penalties (L1 and L2):
local l1 = tonumber(opt.l1) or 0
local l2 = tonumber(opt.l2) or 0
+
if l1 ~= 0 or l2 ~= 0 then
-- locals:
local norm,sign= torch.norm,torch.sign
+ local diff = parameters - initial_weights
-- Loss:
- f = f + l1 * norm(parameters,1)
- f = f + l2 * norm(parameters,2)^2/2
+ f = f + l1 * norm(diff,1)
+ f = f + l2 * norm(diff,2)^2/2
-- Gradients:
- gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) )
+ gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) )
end
-- update confusion
@@ -492,10 +494,12 @@ return function (args, cfg)
for _,lr in ipairs(learning_rates) do
for _,wd in ipairs(penalty_weights) do
linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+ local initial_weights = linear_module.weight[1]:clone()
opts.learning_rate = lr
opts.weight_decay = wd
for i=1,tonumber(opts.iters) do
- train(dataset, opts, perceptron, criterion, i, all_symbols, threshold)
+ train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
+ initial_weights)
end
local fscore = calculate_fscore_from_weights(cv_logs,