summaryrefslogtreecommitdiffstats
path: root/lualib
diff options
context:
space:
mode:
Diffstat (limited to 'lualib')
-rw-r--r--lualib/rspamadm/rescore.lua95
-rw-r--r--lualib/rspamadm/rescore_utility.lua23
2 files changed, 106 insertions, 12 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index c8348caa3..87e0ea2c5 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -182,9 +182,9 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
logs = update_logs(logs, new_symbol_scores)
- local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold)
- return file_stats.fscore
+ return file_stats.fscore, all_fps, all_fns
end
local function print_stats(logs, threshold)
@@ -196,6 +196,7 @@ F-score: %.2f
False positive rate: %.2f %%
False negative rate: %.2f %%
Overall accuracy: %.2f %%
+Slowest message: %.2f (%s)
]]
logger.message("\nStatistics at threshold: " .. threshold)
@@ -204,7 +205,9 @@ Overall accuracy: %.2f %%
file_stats.fscore,
file_stats.false_positive_rate,
file_stats.false_negative_rate,
- file_stats.overall_accuracy))
+ file_stats.overall_accuracy,
+ file_stats.slowest,
+ file_stats.slowest_file))
end
@@ -463,6 +466,67 @@ return function (args, cfg)
local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
ignore_symbols)
+ -- Display hit frequencies
+ if opts['z'] then
+ local file_stats, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local t = {}
+ for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
+ function compare_symbols (a, b)
+ if (a.spam_overall ~= b.spam_overall) then
+ return b.spam_overall < a.spam_overall
+ end
+ if (b.spam_hits ~= a.spam_hits) then
+ return b.spam_hits < a.spam_hits
+ end
+ return b.ham_hits < a.ham_hits
+ end
+ table.sort(t, compare_symbols)
+ logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s",
+ "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%"))
+ for _, symbol_stats in pairs(t) do
+ logger.message(
+ string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f",
+ symbol_stats.name,
+ symbol_stats.no_of_hits,
+ symbol_stats.ham_hits,
+ lua_util.round(symbol_stats.ham_percent,2),
+ symbol_stats.spam_hits,
+ lua_util.round(symbol_stats.spam_percent,2),
+ lua_util.round(symbol_stats.spam_overall,2),
+ lua_util.round(symbol_stats.overall, 2)
+ )
+ )
+ end
+
+ -- Print file statistics
+ print_stats(logs, threshold)
+
+ -- Work out how many symbols weren't seen in the corpus
+ local symbols_no_hits = {}
+ local total_symbols = 0
+ for sym in pairs(original_symbol_scores) do
+ total_symbols = total_symbols + 1
+ if (all_symbols_stats[sym] == nil) then
+ table.insert(symbols_no_hits, sym)
+ end
+ end
+ if (#symbols_no_hits > 0) then
+ table.sort(symbols_no_hits)
+ -- Calculate percentage of rules with no hits
+ local nhpct = lua_util.round((#symbols_no_hits/total_symbols)*100,2)
+ logger.message(
+ string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:',
+ #symbols_no_hits, nhpct, total_symbols
+ )
+ )
+ for _, symbol in pairs(symbols_no_hits) do
+ logger.messagex('%s', symbol)
+ end
+ end
+
+ return
+ end
+
shuffle(logs)
torch.setdefaulttensortype('torch.FloatTensor')
@@ -471,7 +535,6 @@ return function (args, cfg)
local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
-
-- Start of perceptron training
local input_size = #all_symbols
torch.setnumthreads(opts['threads'])
@@ -490,6 +553,8 @@ return function (args, cfg)
local best_weights = linear_module.weight[1]:clone()
local best_learning_rate
local best_weight_decay
+ local all_fps
+ local all_fns
for _,lr in ipairs(learning_rates) do
for _,wd in ipairs(penalty_weights) do
@@ -502,7 +567,7 @@ return function (args, cfg)
initial_weights)
end
- local fscore = calculate_fscore_from_weights(cv_logs,
+ local fscore, fps, fns = calculate_fscore_from_weights(cv_logs,
all_symbols,
linear_module.weight[1],
threshold)
@@ -515,6 +580,8 @@ return function (args, cfg)
best_weight_decay = wd
best_fscore = fscore
best_weights = linear_module.weight[1]:clone()
+ all_fps = fps
+ all_fns = fns
end
end
end
@@ -533,7 +600,6 @@ return function (args, cfg)
print_score_diff(new_symbol_scores, original_symbol_scores)
end
-
-- Pre-rescore test stats
logger.message("\n\nPre-rescore test stats\n")
test_logs = update_logs(test_logs, original_symbol_scores)
@@ -546,4 +612,19 @@ return function (args, cfg)
logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
best_fscore, best_learning_rate, best_weight_decay)
-end \ No newline at end of file
+
+ -- Show all FPs/FNs, useful for corpus checking and rule creation/modification
+ if (all_fps and #all_fps > 0) then
+ logger.message("\nFalse-Positives:")
+ for _, fp in pairs(all_fps) do
+ logger.messagex('%s', fp)
+ end
+ end
+
+ if (all_fns and #all_fns > 0) then
+ logger.message("\nFalse-Negatives:")
+ for _, fn in pairs(all_fns) do
+ logger.messagex('%s', fn)
+ end
+ end
+end
diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
index 7f3f40078..2a9372d4e 100644
--- a/lualib/rspamadm/rescore_utility.lua
+++ b/lualib/rspamadm/rescore_utility.lua
@@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols)
for _, line in pairs(logs) do
line = lua_util.rspamd_str_split(line, " ")
- for i=4,#line do
+ for i=4,(#line-2) do
line[i] = line[i]:gsub("%s+", "")
if not symbols_set[line[i]] then
symbols_set[line[i]] = true
@@ -54,7 +54,7 @@ function utility.get_all_logs(dir_path)
dir_path = dir_path:sub(1, #dir_path -1)
end
- local files = rspamd_util.glob(dir_path .. "/*")
+ local files = rspamd_util.glob(dir_path .. "/*.log")
local all_logs = {}
for _, file in pairs(files) do
@@ -92,10 +92,15 @@ function utility.generate_statistics_from_logs(logs, threshold)
false_negative_rate = 0,
false_positive_rate = 0,
overall_accuracy = 0,
- fscore = 0
+ fscore = 0,
+ avg_scan_time = 0,
+ slowest_file = nil,
+ slowest = 0
}
local all_symbols_stats = {}
+ local all_fps = {}
+ local all_fns = {}
local false_positives = 0
local false_negatives = 0
@@ -124,13 +129,15 @@ function utility.generate_statistics_from_logs(logs, threshold)
true_positives = true_positives + 1
elseif is_spam and (score < threshold) then
false_negatives = false_negatives + 1
+ table.insert(all_fns, log[#log])
elseif not is_spam and (score >= threshold) then
false_positives = false_positives + 1
+ table.insert(all_fps, log[#log])
else
true_negatives = true_negatives + 1
end
- for i=4, #log do
+ for i=4, (#log-2) do
if all_symbols_stats[log[i]] == nil then
all_symbols_stats[log[i]] = {
name = log[i],
@@ -151,6 +158,12 @@ function utility.generate_statistics_from_logs(logs, threshold)
all_symbols_stats[log[i]].ham_hits =
all_symbols_stats[log[i]].ham_hits + 1
end
+
+ -- Find slowest message
+ if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then
+ file_stats.slowest = tostring(tonumber(log[#log-1]))
+ file_stats.slowest_file = log[#log]
+ end
end
end
@@ -192,7 +205,7 @@ function utility.generate_statistics_from_logs(logs, threshold)
(symbol_stats.spam_percent + symbol_stats.ham_percent)
end
- return file_stats, all_symbols_stats
+ return file_stats, all_symbols_stats, all_fps, all_fns
end
return utility