diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-30 14:54:41 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-30 14:54:41 +0100 |
commit | 5d0d46d81568bd1f84b488a230360f8cc0a6e01a (patch) | |
tree | bfb118fcea9bbb487a25468dbe18a4bc16eb45ce | |
parent | 7c15db236e7511e957c540968b0e205c2b1d2b95 (diff) | |
download | rspamd-5d0d46d81568bd1f84b488a230360f8cc0a6e01a.tar.gz rspamd-5d0d46d81568bd1f84b488a230360f8cc0a6e01a.zip |
[Minor] Further fixes to rescore tool
-rw-r--r-- | lualib/rescore_utility.lua | 71 | ||||
-rw-r--r-- | lualib/rspamadm/rescore.lua | 54 |
2 files changed, 73 insertions, 52 deletions
diff --git a/lualib/rescore_utility.lua b/lualib/rescore_utility.lua index 39ad63365..195ae4364 100644 --- a/lualib/rescore_utility.lua +++ b/lualib/rescore_utility.lua @@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols) for _, line in pairs(logs) do line = lua_util.rspamd_str_split(line, " ") - for i=4,(#line-2) do + for i=4,(#line-1) do line[i] = line[i]:gsub("%s+", "") if not symbols_set[line[i]] then symbols_set[line[i]] = true @@ -35,16 +35,23 @@ end function utility.read_log_file(file) local lines = {} + local messages = {} - file = assert(io.open(file, "r")) + local fd = assert(io.open(file, "r")) + local fname = string.gsub(file, "(.*/)(.*)", "%2") - for line in file:lines() do - lines[#lines + 1] = line - end + for line in fd:lines() do + local start,stop = string.find(line, fname .. ':') + + if start and stop then + table.insert(lines, string.sub(line, 1, start)) + table.insert(messages, string.sub(line, stop + 1, -1)) + end +end - io.close(file) + io.close(fd) - return lines + return lines,messages end function utility.get_all_logs(dirs) @@ -55,26 +62,29 @@ function utility.get_all_logs(dirs) end local all_logs = {} + local all_messages = {} for _,dir in ipairs(dirs) do if dir:sub(-1, -1) == "/" then dir = dir:sub(1, -2) local files = rspamd_util.glob(dir .. "/*.log") for _, file in pairs(files) do - local logs = utility.read_log_file(file) - for _, log_line in pairs(logs) do - table.insert(all_logs, log_line) + local logs,messages = utility.read_log_file(file) + for i=1,#logs do + table.insert(all_logs, logs[i]) + table.insert(all_messages, messages[i]) end end else - local logs = utility.read_log_file(dir) - for _, log_line in pairs(logs) do - table.insert(all_logs, log_line) + local logs,messages = utility.read_log_file(dir) + for i=1,#logs do + table.insert(all_logs, logs[i]) + table.insert(all_messages, messages[i]) end end end - return all_logs + return all_logs,all_messages end function utility.get_all_symbol_scores(conf, ignore_symbols) @@ -87,7 +97,7 @@ function utility.get_all_symbol_scores(conf, ignore_symbols) end, symbols))) end -function utility.generate_statistics_from_logs(logs, threshold) +function utility.generate_statistics_from_logs(logs, messages, threshold) -- Returns file_stats table and list of symbol_stats table. @@ -120,9 +130,10 @@ function utility.generate_statistics_from_logs(logs, threshold) local no_of_spam = 0 local no_of_ham = 0 - for _, log in pairs(logs) do + for i, log in ipairs(logs) do log = lua_util.rspamd_str_trim(log) log = lua_util.rspamd_str_split(log, " ") + local message = messages[i] local is_spam = (log[1] == "SPAM") local score = tonumber(log[2]) @@ -139,40 +150,38 @@ function utility.generate_statistics_from_logs(logs, threshold) true_positives = true_positives + 1 elseif is_spam and (score < threshold) then false_negatives = false_negatives + 1 - table.insert(all_fns, log[#log]) + table.insert(all_fns, message) elseif not is_spam and (score >= threshold) then false_positives = false_positives + 1 - table.insert(all_fps, log[#log]) + table.insert(all_fps, message) else true_negatives = true_negatives + 1 end - for i=4, (#log-2) do - if all_symbols_stats[log[i]] == nil then - all_symbols_stats[log[i]] = { - name = log[i], + for j=4, (#log-1) do + if all_symbols_stats[log[j]] == nil then + all_symbols_stats[log[j]] = { + name = message, no_of_hits = 0, spam_hits = 0, ham_hits = 0, spam_overall = 0 } end + local sym = log[j] - all_symbols_stats[log[i]].no_of_hits = - all_symbols_stats[log[i]].no_of_hits + 1 + all_symbols_stats[sym].no_of_hits = all_symbols_stats[sym].no_of_hits + 1 if is_spam then - all_symbols_stats[log[i]].spam_hits = - all_symbols_stats[log[i]].spam_hits + 1 + all_symbols_stats[sym].spam_hits = all_symbols_stats[sym].spam_hits + 1 else - all_symbols_stats[log[i]].ham_hits = - all_symbols_stats[log[i]].ham_hits + 1 + all_symbols_stats[sym].ham_hits = all_symbols_stats[sym].ham_hits + 1 end -- Find slowest message - if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then - file_stats.slowest = tonumber(log[#log-1]) - file_stats.slowest_file = log[#log] + if ((tonumber(log[#log]) or 0) > file_stats.slowest) then + file_stats.slowest = tonumber(log[#log]) + file_stats.slowest_file = message end end end diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index 80b9630f4..cc331c6e8 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -188,17 +188,18 @@ local function init_weights(all_symbols, original_symbol_scores) return weights end -local function shuffle(logs) +local function shuffle(logs, messages) local size = #logs for i = size, 1, -1 do local rand = math.random(size) logs[i], logs[rand] = logs[rand], logs[i] + messages[i], messages[rand] = messages[rand], messages[i] end end -local function split_logs(logs, split_percent) +local function split_logs(logs, messages, split_percent) if not split_percent then split_percent = 60 @@ -208,16 +209,20 @@ local function split_logs(logs, split_percent) local test_logs = {} local train_logs = {} + local test_messages = {} + local train_messages = {} for i=1,split_index do - train_logs[#train_logs + 1] = logs[i] + table.insert(train_logs, logs[i]) + table.insert(train_messages, messages[i]) end for i=split_index + 1, #logs do - test_logs[#test_logs + 1] = logs[i] + table.insert(test_logs, logs[i]) + table.insert(test_messages, messages[i]) end - return train_logs, test_logs + return {train_logs,train_messages}, {test_logs,test_messages} end local function stitch_new_scores(all_symbols, new_scores) @@ -291,7 +296,10 @@ local function print_score_diff(new_symbol_scores, original_symbol_scores) end -local function calculate_fscore_from_weights(logs, all_symbols, weights, threshold) +local function calculate_fscore_from_weights(logs, messages, + all_symbols, + weights, + threshold) local new_symbol_scores = weights:clone() @@ -300,14 +308,15 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho logs = update_logs(logs, new_symbol_scores) local file_stats, _, all_fps, all_fns = - rescore_utility.generate_statistics_from_logs(logs, threshold) + rescore_utility.generate_statistics_from_logs(logs, messages, threshold) return file_stats.fscore, all_fps, all_fns end -local function print_stats(logs, threshold) +local function print_stats(logs, messages, threshold) - local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, + messages, threshold) local file_stat_format = [[ F-score: %.2f @@ -519,7 +528,7 @@ local function handler(args) end local threshold,reject_score = get_threshold() - local logs = rescore_utility.get_all_logs(opts['log']) + local logs,messages = rescore_utility.get_all_logs(opts['log']) if opts['ignore-symbol'] then local function add_ignore(s) @@ -574,7 +583,9 @@ local function handler(args) -- Display hit frequencies if opts['freq'] then - local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold) + local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, + messages, + threshold) local t = {} for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end @@ -607,7 +618,7 @@ local function handler(args) end -- Print file statistics - print_stats(logs, threshold) + print_stats(logs, messages, threshold) -- Work out how many symbols weren't seen in the corpus local symbols_no_hits = {} @@ -635,13 +646,13 @@ local function handler(args) return end - shuffle(logs) + shuffle(logs, messages) torch.setdefaulttensortype('torch.FloatTensor') - local train_logs, validation_logs = split_logs(logs, 70) - local cv_logs, test_logs = split_logs(validation_logs, 50) + local train_logs, validation_logs = split_logs(logs, messages,70) + local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50) - local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score) + local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score) -- Start of perceptron training local input_size = #all_symbols @@ -675,7 +686,8 @@ local function handler(args) initial_weights) end - local fscore, fps, fns = calculate_fscore_from_weights(cv_logs, + local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1], + cv_logs[2], all_symbols, linear_module.weight[1], threshold) @@ -710,13 +722,13 @@ local function handler(args) -- Pre-rescore test stats logger.message("\n\nPre-rescore test stats\n") - test_logs = update_logs(test_logs, original_symbol_scores) - print_stats(test_logs, threshold) + test_logs[1] = update_logs(test_logs[1], original_symbol_scores) + print_stats(test_logs[1], test_logs[2], threshold) -- Post-rescore test stats - test_logs = update_logs(test_logs, new_symbol_scores) + test_logs[1] = update_logs(test_logs[1], new_symbol_scores) logger.message("\n\nPost-rescore test stats\n") - print_stats(test_logs, threshold) + print_stats(test_logs[1], test_logs[2], threshold) logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s', best_fscore, best_learning_rate, best_weight_decay) |