aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-07 13:59:27 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-03-07 13:59:27 +0000
commitecce86ebf9054de2ae14afbdd2d0f17060eca331 (patch)
tree643f04957454281633f823638bab6c32b8dca568 /lualib/rspamadm
parent60ad874a29ce62464bf3668aa6bf1379e1d65599 (diff)
downloadrspamd-ecce86ebf9054de2ae14afbdd2d0f17060eca331.tar.gz
rspamd-ecce86ebf9054de2ae14afbdd2d0f17060eca331.zip
[Fix] Further fixes to rescore tool
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r--lualib/rspamadm/rescore.lua34
1 files changed, 17 insertions, 17 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index 16b7cf8b8..ae61a58ba 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -15,7 +15,7 @@ local ignore_symbols = {
['DATE_IN_FUTURE'] = true,
}
-local function make_dataset_from_logs(logs, all_symbols)
+local function make_dataset_from_logs(logs, all_symbols, spam_score)
-- Returns a list of {input, output} for torch SGD train
local dataset = {}
@@ -125,7 +125,7 @@ local function update_logs(logs, symbol_scores)
for j=4,#log do
log[j] = log[j]:gsub("%s+", "")
- score = score + (symbol_scores[log[j ]] or 0)
+ score = score + (symbol_scores[log[j]] or 0)
end
log[2] = lua_util.round(score, 2)
@@ -174,7 +174,7 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores)
end
-local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
+local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold)
local new_symbol_scores = weights:clone()
@@ -210,7 +210,7 @@ end
-- training function
local function train(dataset, opt, model, criterion, epoch,
- all_symbols)
+ all_symbols, spam_threshold)
-- epoch tracker
epoch = epoch or 1
@@ -284,9 +284,10 @@ local function train(dataset, opt, model, criterion, epoch,
-- update confusion
for i = 1,(last - t + 1) do
- local class_predicted = 0
- if outputs[i][1] > 0.5 then class_predicted = 1 end
- confusion:add(class_predicted + 1, targets[i] + 1)
+ local class_predicted, target_class = 1, 1
+ if outputs[i][1] > 0.5 then class_predicted = 2 end
+ if targets[i] > 0.5 then target_class = 2 end
+ confusion:add(class_predicted, target_class)
end
-- return f and df/dX
@@ -395,16 +396,16 @@ local function get_threshold()
local actions = rspamd_config:get_all_actions()
if opts['spam-action'] then
- return actions[opts['spam-action']] or 0
- else
- return actions['add header'] or actions['rewrite subject'] or actions['reject']
+ return (actions[opts['spam-action']] or 0),actions['reject']
end
+ return (actions['add header'] or actions['rewrite subject']
+ or actions['reject']), actions['reject']
end
return function (args, cfg)
opts = default_opts
override_defaults(opts, getopt.getopt(args, 'i:'))
- local threshold = get_threshold()
+ local threshold,reject_score = get_threshold()
local logs = rescore_utility.get_all_logs(cfg["logdir"])
if opts['ignore-symbol'] then
@@ -466,22 +467,22 @@ return function (args, cfg)
local train_logs, validation_logs = split_logs(logs, 70)
local cv_logs, test_logs = split_logs(validation_logs, 50)
- local dataset = make_dataset_from_logs(train_logs, all_symbols)
+ local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
-- Start of perceptron training
local input_size = #all_symbols
torch.setnumthreads(opts['threads'])
- local linear_module = nn.Linear(input_size, 1)
- local activation = nn.Tanh()
+ local linear_module = nn.Linear(input_size, 1, false)
+ local activation = nn.Sigmoid()
local perceptron = nn.Sequential()
perceptron:add(linear_module)
perceptron:add(activation)
local criterion = nn.MSECriterion()
- criterion.sizeAverage = false
+ --criterion.sizeAverage = false
local best_fscore = -math.huge
local best_weights = linear_module.weight[1]:clone()
@@ -494,13 +495,12 @@ return function (args, cfg)
opts.learning_rate = lr
opts.weight_decay = wd
for i=1,tonumber(opts.iters) do
- train(dataset, opts, perceptron, criterion, i, all_symbols)
+ train(dataset, opts, perceptron, criterion, i, all_symbols, threshold)
end
local fscore = calculate_fscore_from_weights(cv_logs,
all_symbols,
linear_module.weight[1],
- linear_module.bias[1],
threshold)
logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",