diff options
Diffstat (limited to 'lualib')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 95 | ||||
-rw-r--r-- | lualib/rspamadm/rescore_utility.lua | 23 |
2 files changed, 106 insertions, 12 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index c8348caa3..87e0ea2c5 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -182,9 +182,9 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho logs = update_logs(logs, new_symbol_scores) - local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold) - return file_stats.fscore + return file_stats.fscore, all_fps, all_fns end local function print_stats(logs, threshold) @@ -196,6 +196,7 @@ F-score: %.2f False positive rate: %.2f %% False negative rate: %.2f %% Overall accuracy: %.2f %% +Slowest message: %.2f (%s) ]] logger.message("\nStatistics at threshold: " .. threshold) @@ -204,7 +205,9 @@ Overall accuracy: %.2f %% file_stats.fscore, file_stats.false_positive_rate, file_stats.false_negative_rate, - file_stats.overall_accuracy)) + file_stats.overall_accuracy, + file_stats.slowest, + file_stats.slowest_file)) end @@ -463,6 +466,67 @@ return function (args, cfg) local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config, ignore_symbols) + -- Display hit frequencies + if opts['z'] then + local file_stats, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold) + local t = {} + for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end + function compare_symbols (a, b) + if (a.spam_overall ~= b.spam_overall) then + return b.spam_overall < a.spam_overall + end + if (b.spam_hits ~= a.spam_hits) then + return b.spam_hits < a.spam_hits + end + return b.ham_hits < a.ham_hits + end + table.sort(t, compare_symbols) + logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s", + "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%")) + for _, symbol_stats in pairs(t) do + logger.message( + string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f", + symbol_stats.name, + symbol_stats.no_of_hits, + symbol_stats.ham_hits, + lua_util.round(symbol_stats.ham_percent,2), + symbol_stats.spam_hits, + lua_util.round(symbol_stats.spam_percent,2), + lua_util.round(symbol_stats.spam_overall,2), + lua_util.round(symbol_stats.overall, 2) + ) + ) + end + + -- Print file statistics + print_stats(logs, threshold) + + -- Work out how many symbols weren't seen in the corpus + local symbols_no_hits = {} + local total_symbols = 0 + for sym in pairs(original_symbol_scores) do + total_symbols = total_symbols + 1 + if (all_symbols_stats[sym] == nil) then + table.insert(symbols_no_hits, sym) + end + end + if (#symbols_no_hits > 0) then + table.sort(symbols_no_hits) + -- Calculate percentage of rules with no hits + local nhpct = lua_util.round((#symbols_no_hits/total_symbols)*100,2) + logger.message( + string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:', + #symbols_no_hits, nhpct, total_symbols + ) + ) + for _, symbol in pairs(symbols_no_hits) do + logger.messagex('%s', symbol) + end + end + + return + end + shuffle(logs) torch.setdefaulttensortype('torch.FloatTensor') @@ -471,7 +535,6 @@ return function (args, cfg) local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score) - -- Start of perceptron training local input_size = #all_symbols torch.setnumthreads(opts['threads']) @@ -490,6 +553,8 @@ return function (args, cfg) local best_weights = linear_module.weight[1]:clone() local best_learning_rate local best_weight_decay + local all_fps + local all_fns for _,lr in ipairs(learning_rates) do for _,wd in ipairs(penalty_weights) do @@ -502,7 +567,7 @@ return function (args, cfg) initial_weights) end - local fscore = calculate_fscore_from_weights(cv_logs, + local fscore, fps, fns = calculate_fscore_from_weights(cv_logs, all_symbols, linear_module.weight[1], threshold) @@ -515,6 +580,8 @@ return function (args, cfg) best_weight_decay = wd best_fscore = fscore best_weights = linear_module.weight[1]:clone() + all_fps = fps + all_fns = fns end end end @@ -533,7 +600,6 @@ return function (args, cfg) print_score_diff(new_symbol_scores, original_symbol_scores) end - -- Pre-rescore test stats logger.message("\n\nPre-rescore test stats\n") test_logs = update_logs(test_logs, original_symbol_scores) @@ -546,4 +612,19 @@ return function (args, cfg) logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s', best_fscore, best_learning_rate, best_weight_decay) -end
\ No newline at end of file + + -- Show all FPs/FNs, useful for corpus checking and rule creation/modification + if (all_fps and #all_fps > 0) then + logger.message("\nFalse-Positives:") + for _, fp in pairs(all_fps) do + logger.messagex('%s', fp) + end + end + + if (all_fns and #all_fns > 0) then + logger.message("\nFalse-Negatives:") + for _, fn in pairs(all_fns) do + logger.messagex('%s', fn) + end + end +end diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua index 7f3f40078..2a9372d4e 100644 --- a/lualib/rspamadm/rescore_utility.lua +++ b/lualib/rspamadm/rescore_utility.lua @@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols) for _, line in pairs(logs) do line = lua_util.rspamd_str_split(line, " ") - for i=4,#line do + for i=4,(#line-2) do line[i] = line[i]:gsub("%s+", "") if not symbols_set[line[i]] then symbols_set[line[i]] = true @@ -54,7 +54,7 @@ function utility.get_all_logs(dir_path) dir_path = dir_path:sub(1, #dir_path -1) end - local files = rspamd_util.glob(dir_path .. "/*") + local files = rspamd_util.glob(dir_path .. "/*.log") local all_logs = {} for _, file in pairs(files) do @@ -92,10 +92,15 @@ function utility.generate_statistics_from_logs(logs, threshold) false_negative_rate = 0, false_positive_rate = 0, overall_accuracy = 0, - fscore = 0 + fscore = 0, + avg_scan_time = 0, + slowest_file = nil, + slowest = 0 } local all_symbols_stats = {} + local all_fps = {} + local all_fns = {} local false_positives = 0 local false_negatives = 0 @@ -124,13 +129,15 @@ function utility.generate_statistics_from_logs(logs, threshold) true_positives = true_positives + 1 elseif is_spam and (score < threshold) then false_negatives = false_negatives + 1 + table.insert(all_fns, log[#log]) elseif not is_spam and (score >= threshold) then false_positives = false_positives + 1 + table.insert(all_fps, log[#log]) else true_negatives = true_negatives + 1 end - for i=4, #log do + for i=4, (#log-2) do if all_symbols_stats[log[i]] == nil then all_symbols_stats[log[i]] = { name = log[i], @@ -151,6 +158,12 @@ function utility.generate_statistics_from_logs(logs, threshold) all_symbols_stats[log[i]].ham_hits = all_symbols_stats[log[i]].ham_hits + 1 end + + -- Find slowest message + if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then + file_stats.slowest = tostring(tonumber(log[#log-1])) + file_stats.slowest_file = log[#log] + end end end @@ -192,7 +205,7 @@ function utility.generate_statistics_from_logs(logs, threshold) (symbol_stats.spam_percent + symbol_stats.ham_percent) end - return file_stats, all_symbols_stats + return file_stats, all_symbols_stats, all_fps, all_fns end return utility |