diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-07 13:59:27 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-07 13:59:27 +0000 |
commit | ecce86ebf9054de2ae14afbdd2d0f17060eca331 (patch) | |
tree | 643f04957454281633f823638bab6c32b8dca568 /lualib/rspamadm | |
parent | 60ad874a29ce62464bf3668aa6bf1379e1d65599 (diff) | |
download | rspamd-ecce86ebf9054de2ae14afbdd2d0f17060eca331.tar.gz rspamd-ecce86ebf9054de2ae14afbdd2d0f17060eca331.zip |
[Fix] Further fixes to rescore tool
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index 16b7cf8b8..ae61a58ba 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -15,7 +15,7 @@ local ignore_symbols = { ['DATE_IN_FUTURE'] = true, } -local function make_dataset_from_logs(logs, all_symbols) +local function make_dataset_from_logs(logs, all_symbols, spam_score) -- Returns a list of {input, output} for torch SGD train local dataset = {} @@ -125,7 +125,7 @@ local function update_logs(logs, symbol_scores) for j=4,#log do log[j] = log[j]:gsub("%s+", "") - score = score + (symbol_scores[log[j ]] or 0) + score = score + (symbol_scores[log[j]] or 0) end log[2] = lua_util.round(score, 2) @@ -174,7 +174,7 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores) end -local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold) +local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold) local new_symbol_scores = weights:clone() @@ -210,7 +210,7 @@ end -- training function local function train(dataset, opt, model, criterion, epoch, - all_symbols) + all_symbols, spam_threshold) -- epoch tracker epoch = epoch or 1 @@ -284,9 +284,10 @@ local function train(dataset, opt, model, criterion, epoch, -- update confusion for i = 1,(last - t + 1) do - local class_predicted = 0 - if outputs[i][1] > 0.5 then class_predicted = 1 end - confusion:add(class_predicted + 1, targets[i] + 1) + local class_predicted, target_class = 1, 1 + if outputs[i][1] > 0.5 then class_predicted = 2 end + if targets[i] > 0.5 then target_class = 2 end + confusion:add(class_predicted, target_class) end -- return f and df/dX @@ -395,16 +396,16 @@ local function get_threshold() local actions = rspamd_config:get_all_actions() if opts['spam-action'] then - return actions[opts['spam-action']] or 0 - else - return actions['add header'] or actions['rewrite subject'] or actions['reject'] + return (actions[opts['spam-action']] or 0),actions['reject'] end + return (actions['add header'] or actions['rewrite subject'] + or actions['reject']), actions['reject'] end return function (args, cfg) opts = default_opts override_defaults(opts, getopt.getopt(args, 'i:')) - local threshold = get_threshold() + local threshold,reject_score = get_threshold() local logs = rescore_utility.get_all_logs(cfg["logdir"]) if opts['ignore-symbol'] then @@ -466,22 +467,22 @@ return function (args, cfg) local train_logs, validation_logs = split_logs(logs, 70) local cv_logs, test_logs = split_logs(validation_logs, 50) - local dataset = make_dataset_from_logs(train_logs, all_symbols) + local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score) -- Start of perceptron training local input_size = #all_symbols torch.setnumthreads(opts['threads']) - local linear_module = nn.Linear(input_size, 1) - local activation = nn.Tanh() + local linear_module = nn.Linear(input_size, 1, false) + local activation = nn.Sigmoid() local perceptron = nn.Sequential() perceptron:add(linear_module) perceptron:add(activation) local criterion = nn.MSECriterion() - criterion.sizeAverage = false + --criterion.sizeAverage = false local best_fscore = -math.huge local best_weights = linear_module.weight[1]:clone() @@ -494,13 +495,12 @@ return function (args, cfg) opts.learning_rate = lr opts.weight_decay = wd for i=1,tonumber(opts.iters) do - train(dataset, opts, perceptron, criterion, i, all_symbols) + train(dataset, opts, perceptron, criterion, i, all_symbols, threshold) end local fscore = calculate_fscore_from_weights(cv_logs, all_symbols, linear_module.weight[1], - linear_module.bias[1], threshold) logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s", |