aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-30 14:54:41 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-30 14:54:41 +0100
commit5d0d46d81568bd1f84b488a230360f8cc0a6e01a (patch)
treebfb118fcea9bbb487a25468dbe18a4bc16eb45ce /lualib/rspamadm
parent7c15db236e7511e957c540968b0e205c2b1d2b95 (diff)
downloadrspamd-5d0d46d81568bd1f84b488a230360f8cc0a6e01a.tar.gz
rspamd-5d0d46d81568bd1f84b488a230360f8cc0a6e01a.zip
[Minor] Further fixes to rescore tool
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r--lualib/rspamadm/rescore.lua54
1 files changed, 33 insertions, 21 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index 80b9630f4..cc331c6e8 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -188,17 +188,18 @@ local function init_weights(all_symbols, original_symbol_scores)
return weights
end
-local function shuffle(logs)
+local function shuffle(logs, messages)
local size = #logs
for i = size, 1, -1 do
local rand = math.random(size)
logs[i], logs[rand] = logs[rand], logs[i]
+ messages[i], messages[rand] = messages[rand], messages[i]
end
end
-local function split_logs(logs, split_percent)
+local function split_logs(logs, messages, split_percent)
if not split_percent then
split_percent = 60
@@ -208,16 +209,20 @@ local function split_logs(logs, split_percent)
local test_logs = {}
local train_logs = {}
+ local test_messages = {}
+ local train_messages = {}
for i=1,split_index do
- train_logs[#train_logs + 1] = logs[i]
+ table.insert(train_logs, logs[i])
+ table.insert(train_messages, messages[i])
end
for i=split_index + 1, #logs do
- test_logs[#test_logs + 1] = logs[i]
+ table.insert(test_logs, logs[i])
+ table.insert(test_messages, messages[i])
end
- return train_logs, test_logs
+ return {train_logs,train_messages}, {test_logs,test_messages}
end
local function stitch_new_scores(all_symbols, new_scores)
@@ -291,7 +296,10 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores)
end
-local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold)
+local function calculate_fscore_from_weights(logs, messages,
+ all_symbols,
+ weights,
+ threshold)
local new_symbol_scores = weights:clone()
@@ -300,14 +308,15 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
logs = update_logs(logs, new_symbol_scores)
local file_stats, _, all_fps, all_fns =
- rescore_utility.generate_statistics_from_logs(logs, threshold)
+ rescore_utility.generate_statistics_from_logs(logs, messages, threshold)
return file_stats.fscore, all_fps, all_fns
end
-local function print_stats(logs, threshold)
+local function print_stats(logs, messages, threshold)
- local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs,
+ messages, threshold)
local file_stat_format = [[
F-score: %.2f
@@ -519,7 +528,7 @@ local function handler(args)
end
local threshold,reject_score = get_threshold()
- local logs = rescore_utility.get_all_logs(opts['log'])
+ local logs,messages = rescore_utility.get_all_logs(opts['log'])
if opts['ignore-symbol'] then
local function add_ignore(s)
@@ -574,7 +583,9 @@ local function handler(args)
-- Display hit frequencies
if opts['freq'] then
- local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
+ messages,
+ threshold)
local t = {}
for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
@@ -607,7 +618,7 @@ local function handler(args)
end
-- Print file statistics
- print_stats(logs, threshold)
+ print_stats(logs, messages, threshold)
-- Work out how many symbols weren't seen in the corpus
local symbols_no_hits = {}
@@ -635,13 +646,13 @@ local function handler(args)
return
end
- shuffle(logs)
+ shuffle(logs, messages)
torch.setdefaulttensortype('torch.FloatTensor')
- local train_logs, validation_logs = split_logs(logs, 70)
- local cv_logs, test_logs = split_logs(validation_logs, 50)
+ local train_logs, validation_logs = split_logs(logs, messages,70)
+ local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
- local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
+ local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)
-- Start of perceptron training
local input_size = #all_symbols
@@ -675,7 +686,8 @@ local function handler(args)
initial_weights)
end
- local fscore, fps, fns = calculate_fscore_from_weights(cv_logs,
+ local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
+ cv_logs[2],
all_symbols,
linear_module.weight[1],
threshold)
@@ -710,13 +722,13 @@ local function handler(args)
-- Pre-rescore test stats
logger.message("\n\nPre-rescore test stats\n")
- test_logs = update_logs(test_logs, original_symbol_scores)
- print_stats(test_logs, threshold)
+ test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
+ print_stats(test_logs[1], test_logs[2], threshold)
-- Post-rescore test stats
- test_logs = update_logs(test_logs, new_symbol_scores)
+ test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
logger.message("\n\nPost-rescore test stats\n")
- print_stats(test_logs, threshold)
+ print_stats(test_logs[1], test_logs[2], threshold)
logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
best_fscore, best_learning_rate, best_weight_decay)