diff options
author | Pragadeesh C <cpragadeesh@gmail.com> | 2017-06-01 16:07:28 -0700 |
---|---|---|
committer | Pragadeesh C <cpragadeesh@gmail.com> | 2017-12-07 21:52:58 +0530 |
commit | 703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4 (patch) | |
tree | 1b3d94a09c1a8fbb24eb9d40e73a510bb8ed53d0 /lualib | |
parent | 40759556db803500b4eaaaf443dff0a94320f209 (diff) | |
download | rspamd-703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4.tar.gz rspamd-703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4.zip |
added corpus_test, rescore commands
Diffstat (limited to 'lualib')
-rw-r--r-- | lualib/rspamadm/corpus_test.lua | 126 | ||||
-rw-r--r-- | lualib/rspamadm/rescore.lua | 298 | ||||
-rw-r--r-- | lualib/rspamadm/rescore_utility.lua | 214 |
3 files changed, 638 insertions, 0 deletions
-- diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
-- new file mode 100644, index 000000000..b29fa5602
--
-- rspamadm "corpus_test" command.
--
-- Scans a ham and/or spam corpus by shelling out to `rspamc`, converts every
-- JSON reply into a flat log record and writes all records to one log file.
-- Each output line has the form:
--   <HAM|SPAM> <score> <action> <symbol> <symbol> ...

local ucl = require "ucl"
local lua_util = require "lua_util"

local HAM = "HAM"
local SPAM = "SPAM"

--- Quote a value so it is passed to the shell as exactly one argument.
-- Without quoting, corpus paths containing spaces or shell metacharacters
-- would be split or interpreted by the shell started by io.popen.
-- @param value anything convertible with tostring
-- @return single-quoted string
local function shell_quote(value)
  local escaped = tostring(value):gsub("'", "'\\''")
  return "'" .. escaped .. "'"
end

--- Run rspamc over a path and return its complete standard output.
-- @param n_parellel value passed to rspamc's `-n` option
-- @param path file or directory handed to rspamc
-- @return string with everything rspamc printed
local function scan_email(n_parellel, path)
  local rspamc_command = string.format("rspamc -j --compact -n %s %s",
      shell_quote(n_parellel), shell_quote(path))
  local handle = assert(io.popen(rspamc_command))
  local output = handle:read("*all")
  -- Close the pipe explicitly instead of waiting for the garbage collector.
  handle:close()
  return output
end

--- Write the collected records to `file`, one record per line.
-- @param results list of records with `type`, `score`, `action`, `symbols`
-- @param file path of the log file to (over)write
local function write_results(results, file)
  -- assert gives a readable message (path + OS error) when opening fails.
  local f = assert(io.open(file, 'w'))

  for _, result in ipairs(results) do
    local fields = {
      string.format("%s %.2f %s", result.type, result.score, result.action)
    }

    for _, sym in ipairs(result.symbols) do
      fields[#fields + 1] = sym
    end

    f:write(table.concat(fields, " "), "\r\n")
  end

  f:close()
end

--- Convert one JSON reply of rspamc into a log record.
-- Terminates the process with a non-zero status when the reply cannot be
-- parsed.
-- @param result JSON text of a single scan result
-- @return table with `score`, `action` (whitespace replaced by "_") and
--         `symbols` (sorted list of symbol names)
local function encoded_json_to_log(result)
  local parser = ucl.parser()
  local is_good, err = parser:parse_string(result)

  if not is_good then
    io.stderr:write(tostring(err), "\n")
    os.exit(1)
  end

  local parsed = parser:get_object()

  local symbols = {}
  for sym in pairs(parsed.symbols or {}) do
    symbols[#symbols + 1] = sym
  end
  -- pairs() order is unspecified; sort so that output is reproducible.
  table.sort(symbols)

  -- gsub returns two values; keep only the rewritten string.
  local action = parsed.action:gsub("%s+", "_")

  return {
    score = parsed.score,
    action = action,
    symbols = symbols,
  }
end

--- Turn the raw rspamc output (one JSON object per line) into log records.
-- @param results raw output returned by scan_email
-- @param actual_email_type HAM or SPAM label stored in each record
-- @return list of log records
local function scan_results_to_logs(results, actual_email_type)
  local logs = {}

  for _, line in ipairs(lua_util.rspamd_str_split(results, "\n")) do
    -- Skip empty pieces such as the one after a trailing newline.
    if line ~= "" then
      local record = encoded_json_to_log(line)
      record['type'] = actual_email_type
      logs[#logs + 1] = record
    end
  end

  return logs
end

--- Command entry point.
-- Reads `ham_directory`, `spam_directory`, `connections` and
-- `output_location` from `res`, scans whichever corpora are given and
-- prints short statistics.
return function (_, res)

  local ham_directory = res['ham_directory']
  local spam_directory = res['spam_directory']
  local connections = res["connections"]
  local output = res["output_location"]

  local results = {}

  local start_time = os.time()
  local no_of_ham = 0
  local no_of_spam = 0

  if ham_directory then
    io.write("Scanning ham corpus...\n")
    local ham_results = scan_email(connections, ham_directory)
    ham_results = scan_results_to_logs(ham_results, HAM)

    no_of_ham = #ham_results

    for _, result in ipairs(ham_results) do
      results[#results + 1] = result
    end
  end

  if spam_directory then
    io.write("Scanning spam corpus...\n")
    local spam_results = scan_email(connections, spam_directory)
    spam_results = scan_results_to_logs(spam_results, SPAM)

    no_of_spam = #spam_results

    for _, result in ipairs(spam_results) do
      results[#results + 1] = result
    end
  end

  io.write(string.format("Writing results to %s\n", output))
  write_results(results, output)

  io.write("\nStats: \n")
  io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
  io.write(string.format("No of ham: %d\n", no_of_ham))
  io.write(string.format("No of spam: %d\n", no_of_spam))

end
-- \ No newline at end of file   (diff marker closing the previous file)
-- diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
-- new file mode 100644, index 000000000..538122f68
--
-- rspamadm "rescore" command.
--
-- Reads corpus logs (lines of "<HAM|SPAM> <score> <action> <symbols...>"),
-- trains a single-layer perceptron over symbol presence and reports the
-- symbol scores that gave the best cross-validation F-score.

local torch = require "torch"
local nn = require "nn"
local lua_util = require "lua_util"
local ucl = require "ucl"

local rescore_utility = require "rspamadm/rescore_utility"

--- Build the training set expected by nn.StochasticGradient.
-- @param logs list of log lines
-- @param all_symbols ordered list of symbol names (defines input columns)
-- @return list of {input, output} pairs with a `size` method
local function make_dataset_from_logs(logs, all_symbols)

  local dataset = {}

  for _, log in ipairs(logs) do
    local input = torch.Tensor(#all_symbols)
    local output = torch.Tensor(1)
    local fields = lua_util.rspamd_str_split(log, " ")

    if fields[1] == "SPAM" then
      output[1] = 1
    else
      output[1] = 0
    end

    local symbols_set = {}

    for i = 4, #fields do
      -- Strip stray whitespace (e.g. the "\r" left by CRLF line endings) so
      -- the token matches the names in all_symbols, which are stripped too.
      local symbol = fields[i]:gsub("%s+", "")
      symbols_set[symbol] = true
    end

    for index, symbol in ipairs(all_symbols) do
      if symbols_set[symbol] then
        input[index] = 1
      else
        input[index] = 0
      end
    end

    dataset[#dataset + 1] = {input, output}

  end

  function dataset:size()
    return #dataset
  end

  return dataset
end

--- Create the initial weight vector from the currently configured scores.
-- Symbols without a known score start at 0.
-- @param all_symbols ordered list of symbol names
-- @param original_symbol_scores map symbol name -> score
-- @return torch.Tensor with one weight per symbol
local function init_weights(all_symbols, original_symbol_scores)

  local weights = torch.Tensor(#all_symbols)

  for i, symbol in ipairs(all_symbols) do
    weights[i] = original_symbol_scores[symbol] or 0
  end

  return weights
end

--- Shuffle `logs` in place with the Fisher-Yates algorithm.
-- The swap partner is drawn from 1..i (not 1..size): drawing from the whole
-- range at every step does not produce all permutations with equal
-- probability.
-- @param logs list to shuffle
local function shuffle(logs)

  for i = #logs, 2, -1 do
    local j = math.random(i)
    logs[i], logs[j] = logs[j], logs[i]
  end

end

--- Split a list into two consecutive parts.
-- @param logs list to split
-- @param split_percent percentage that goes to the first part (default 60)
-- @return first part, second part
local function split_logs(logs, split_percent)

  if not split_percent then
    split_percent = 60
  end

  local split_index = math.floor(#logs * split_percent / 100)

  local test_logs = {}
  local train_logs = {}

  for i = 1, split_index do
    train_logs[#train_logs + 1] = logs[i]
  end

  for i = split_index + 1, #logs do
    test_logs[#test_logs + 1] = logs[i]
  end

  return train_logs, test_logs
end

--- Pair every symbol name with the score stored at the same index.
-- @param all_symbols ordered list of symbol names
-- @param new_scores indexable container of scores (tensor or list)
-- @return map symbol name -> score
local function stitch_new_scores(all_symbols, new_scores)

  local new_symbol_scores = {}

  for idx, symbol in ipairs(all_symbols) do
    new_symbol_scores[symbol] = new_scores[idx]
  end

  return new_symbol_scores
end

--- Recompute the score field of every log line from `symbol_scores`.
-- The list is modified in place and also returned.
-- @param logs list of log lines
-- @param symbol_scores map symbol name -> score (missing symbols count 0)
-- @return the same `logs` list
local function update_logs(logs, symbol_scores)

  for i, log in ipairs(logs) do

    local fields = lua_util.rspamd_str_split(log, " ")

    local score = 0

    for j = 4, #fields do
      fields[j] = fields[j]:gsub("%s+", "")
      score = score + (symbol_scores[fields[j]] or 0)
    end

    fields[2] = rescore_utility.round(score, 2)

    logs[i] = table.concat(fields, " ")
  end

  return logs
end

--- Serialise the new scores as UCL and write them to `file_path`.
-- @param new_symbol_scores map symbol name -> score
-- @param file_path destination file
local function write_scores(new_symbol_scores, file_path)

  local file = assert(io.open(file_path, "w"))

  local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")

  file:write(new_scores_ucl)

  file:close()
end

--- Print old and new scores, then the symbols whose score changed sign.
-- @param new_symbol_scores map symbol name -> new score
-- @param original_symbol_scores map symbol name -> old score
local function print_score_diff(new_symbol_scores, original_symbol_scores)

  print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))

  for symbol, new_score in pairs(new_symbol_scores) do
    print(string.format("%-35s %-10s %-10s",
        symbol,
        original_symbol_scores[symbol] or 0,
        rescore_utility.round(new_score, 2)))
  end

  print "\nClass changes \n"
  for symbol, new_score in pairs(new_symbol_scores) do
    local old_score = original_symbol_scores[symbol]
    if old_score ~= nil then
      if (old_score > 0 and new_score < 0) or
          (old_score < 0 and new_score > 0) then
        print(string.format("%-35s %-10s %-10s",
            symbol,
            old_score,
            rescore_utility.round(new_score, 2)))
      end
    end
  end

end

--- Rescore `logs` with the given weights and return the resulting F-score.
-- Note that `logs` is rewritten in place by update_logs.
-- @param logs list of log lines
-- @param all_symbols ordered list of symbol names
-- @param weights tensor of per-symbol weights
-- @param _bias bias of the linear layer (accepted but not used)
-- @param threshold score at or above which a message counts as spam
-- @return F-score
local function calculate_fscore_from_weights(logs, all_symbols, weights, _bias, threshold)

  local new_symbol_scores = stitch_new_scores(all_symbols, weights:clone())

  logs = update_logs(logs, new_symbol_scores)

  local file_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)

  return file_stats.fscore
end

--- Print F-score, error rates and accuracy of `logs` at `threshold`.
-- @param logs list of log lines
-- @param threshold score at or above which a message counts as spam
local function print_stats(logs, threshold)

  local file_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)

  local file_stat_format = [[
F-score: %.2f
False positive rate: %.2f %%
False negative rate: %.2f %%
Overall accuracy: %.2f %%
]]

  io.write("\nStatistics at threshold: " .. threshold .. "\n")

  io.write(string.format(file_stat_format,
      file_stats.fscore,
      file_stats.false_positive_rate,
      file_stats.false_negative_rate,
      file_stats.overall_accuracy))

end

--- Command entry point.
-- Uses `logdir`, `iters`, `threshold`, `output` and `diff` from `res`.
return function (_, res)

  local logs = rescore_utility.get_all_logs(res["logdir"])
  local all_symbols = rescore_utility.get_all_symbols(logs)
  local original_symbol_scores = rescore_utility.get_all_symbol_scores()

  shuffle(logs)

  -- 70% training, remaining 30% split evenly into cross-validation and test.
  local train_logs, validation_logs = split_logs(logs, 70)
  local cv_logs, test_logs = split_logs(validation_logs, 50)

  local dataset = make_dataset_from_logs(train_logs, all_symbols)

  local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
  local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}

  -- Start of perceptron training

  local input_size = #all_symbols
  local linear_module = nn.Linear(input_size, 1)

  local perceptron = nn.Sequential()
  perceptron:add(linear_module)

  local activation = nn.Sigmoid()

  perceptron:add(activation)

  local criterion = nn.MSECriterion()
  criterion.sizeAverage = false

  local best_fscore = -math.huge
  local best_weights = linear_module.weight[1]:clone()

  local trainer = nn.StochasticGradient(perceptron, criterion)
  trainer.maxIteration = res["iters"]
  trainer.verbose = false

  -- Evaluate on the cross-validation set once per training run (after the
  -- last iteration) and remember the best weights seen so far.
  trainer.hookIteration = function(_, iteration)

    if iteration == trainer.maxIteration then

      local fscore = calculate_fscore_from_weights(cv_logs,
          all_symbols,
          linear_module.weight[1],
          linear_module.bias[1],
          res["threshold"])

      print("Cross-validation fscore: " .. fscore)

      if best_fscore < fscore then
        best_fscore = fscore
        best_weights = linear_module.weight[1]:clone()
      end
    end
  end

  -- Grid search over learning rate and L2 penalty; every run restarts from
  -- the original symbol scores.
  for _, learning_rate in ipairs(learning_rates) do
    for _, weight in ipairs(penalty_weights) do

      trainer.weightDecay = weight
      print("Learning with learning_rate: " .. learning_rate
          .. " | l2_weight: " .. weight)

      linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)

      trainer.learningRate = learning_rate
      trainer:train(dataset)

      print()
    end
  end

  -- End perceptron training

  local new_symbol_scores = stitch_new_scores(all_symbols, best_weights)

  if res["output"] then
    write_scores(new_symbol_scores, res["output"])
  end

  if res["diff"] then
    print_score_diff(new_symbol_scores, original_symbol_scores)
  end


  -- Pre-rescore test stats
  print("\n\nPre-rescore test stats\n")
  test_logs = update_logs(test_logs, original_symbol_scores)
  print_stats(test_logs, res['threshold'])

  -- Post-rescore test stats
  test_logs = update_logs(test_logs, new_symbol_scores)
  print("\n\nPost-rescore test stats\n")
  print_stats(test_logs, res['threshold'])
end
-- \ No newline at end of file   (diff marker closing the previous file)
-- diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
-- new file mode 100644, index 000000000..4c6504e76
--
-- Helpers shared by the rspamadm "rescore" command: reading corpus logs,
-- collecting symbol names/scores and computing classification statistics.
-- A log line has the form "<HAM|SPAM> <score> <action> <symbols...>".

local ucl = require "ucl"
local lua_util = require "lua_util"
local rspamd_util = require "rspamd_util"

local utility = {}

--- Format a number with a fixed count of decimal places.
-- @param num number to format
-- @param places decimal places (default 0)
-- @return formatted number as a string
function utility.round(num, places)
  return string.format("%." .. (places or 0) .. "f", num)
end

--- Collect the distinct symbol names found in `logs`.
-- @param logs list of log lines
-- @return alphabetically sorted list of symbol names
function utility.get_all_symbols(logs)

  local symbols_set = {}

  for _, line in ipairs(logs) do
    local fields = lua_util.rspamd_str_split(line, " ")
    for i = 4, #fields do
      -- gsub returns two values; keep only the stripped string.
      local symbol = fields[i]:gsub("%s+", "")
      symbols_set[symbol] = true
    end
  end

  local all_symbols = {}

  for symbol in pairs(symbols_set) do
    all_symbols[#all_symbols + 1] = symbol
  end

  table.sort(all_symbols)

  return all_symbols
end

--- Read a file and return its lines.
-- Raises an error when the file cannot be opened.
-- @param file path of the file
-- @return list of lines without the trailing newline
function utility.read_log_file(file)

  local lines = {}

  local handle = assert(io.open(file, "r"))

  for line in handle:lines() do
    lines[#lines + 1] = line
  end

  handle:close()

  return lines
end

--- Read every file matched by "<dir_path>/*" and return all their lines.
-- @param dir_path directory holding the log files (trailing "/" allowed)
-- @return list of log lines
function utility.get_all_logs(dir_path)

  if dir_path:sub(#dir_path, #dir_path) == "/" then
    dir_path = dir_path:sub(1, #dir_path - 1)
  end

  local files = rspamd_util.glob(dir_path .. "/*")
  local all_logs = {}

  for _, file in pairs(files) do
    for _, log_line in ipairs(utility.read_log_file(file)) do
      all_logs[#all_logs + 1] = log_line
    end
  end

  return all_logs
end

--- Fetch the current symbol weights by running `rspamc counters`.
-- Terminates the process with a non-zero status when the output cannot be
-- parsed.
-- @return map symbol name -> weight
function utility.get_all_symbol_scores()

  local handle = assert(io.popen("rspamc counters -j --compact"))
  local output = handle:read("*all")
  -- Close the pipe explicitly instead of waiting for the garbage collector.
  handle:close()

  local parser = ucl.parser()
  local result, err = parser:parse_string(output)

  if not result then
    io.stderr:write(tostring(err), "\n")
    os.exit(1)
  end

  local counters = parser:get_object()

  local symbol_scores = {}

  for _, symbol_info in pairs(counters) do
    symbol_scores[symbol_info.symbol] = symbol_info.weight
  end

  return symbol_scores
end

--- part * 100 / whole, or 0 when `whole` is 0 (avoids inf/NaN results).
local function percent(part, whole)
  if whole > 0 then
    return part * 100 / whole
  end
  return 0
end

--- Compute classification statistics for `logs` at a score threshold.
-- A message is predicted spam when its score is >= threshold. Lines whose
-- score field is missing or not numeric (for example blank lines) are
-- ignored. All ratios default to 0 when their denominator is 0.
-- @param logs list of log lines
-- @param threshold spam threshold
-- @return file_stats table, and a map symbol name -> per-symbol stats
function utility.generate_statistics_from_logs(logs, threshold)

  local file_stats = {
    no_of_emails = 0,
    no_of_spam = 0,
    no_of_ham = 0,
    spam_percent = 0,
    ham_percent = 0,
    true_positives = 0,
    true_negatives = 0,
    false_negative_rate = 0,
    false_positive_rate = 0,
    overall_accuracy = 0,
    fscore = 0
  }

  local all_symbols_stats = {}

  local false_positives = 0
  local false_negatives = 0
  local true_positives = 0
  local true_negatives = 0
  local no_of_emails = 0
  local no_of_spam = 0
  local no_of_ham = 0

  for _, log in ipairs(logs) do
    local fields = lua_util.rspamd_str_split(lua_util.rspamd_str_trim(log), " ")

    local is_spam = (fields[1] == "SPAM")
    local score = tonumber(fields[2])

    -- Comparing nil with the threshold would raise; skip malformed lines.
    if score then
      no_of_emails = no_of_emails + 1

      if is_spam then
        no_of_spam = no_of_spam + 1
      else
        no_of_ham = no_of_ham + 1
      end

      local predicted_spam = score >= threshold

      if is_spam and predicted_spam then
        true_positives = true_positives + 1
      elseif is_spam then
        false_negatives = false_negatives + 1
      elseif predicted_spam then
        false_positives = false_positives + 1
      else
        true_negatives = true_negatives + 1
      end

      for i = 4, #fields do
        local stats = all_symbols_stats[fields[i]]

        if stats == nil then
          stats = {
            name = fields[i],
            no_of_hits = 0,
            spam_hits = 0,
            ham_hits = 0,
            spam_overall = 0
          }
          all_symbols_stats[fields[i]] = stats
        end

        stats.no_of_hits = stats.no_of_hits + 1

        if is_spam then
          stats.spam_hits = stats.spam_hits + 1
        else
          stats.ham_hits = stats.ham_hits + 1
        end
      end
    end
  end

  -- Calculating file stats

  file_stats.no_of_ham = no_of_ham
  file_stats.no_of_spam = no_of_spam
  file_stats.no_of_emails = no_of_emails
  file_stats.true_positives = true_positives
  file_stats.true_negatives = true_negatives

  file_stats.spam_percent = percent(no_of_spam, no_of_emails)
  file_stats.ham_percent = percent(no_of_ham, no_of_emails)
  file_stats.overall_accuracy = percent(true_positives + true_negatives,
      no_of_emails)
  file_stats.false_positive_rate = percent(false_positives, no_of_ham)
  file_stats.false_negative_rate = percent(false_negatives, no_of_spam)

  -- With no true positives, false positives or false negatives the
  -- denominator is 0 and the division would yield NaN, which compares false
  -- against everything; keep the default of 0 in that case.
  local fscore_denominator = 2 * true_positives
      + false_positives
      + false_negatives

  if fscore_denominator > 0 then
    file_stats.fscore = 2 * true_positives / fscore_denominator
  end

  -- Calculating symbol stats

  for _, symbol_stats in pairs(all_symbols_stats) do
    symbol_stats.spam_percent = percent(symbol_stats.spam_hits, no_of_spam)
    symbol_stats.ham_percent = percent(symbol_stats.ham_hits, no_of_ham)
    symbol_stats.overall = percent(symbol_stats.no_of_hits, no_of_emails)

    local percent_sum = symbol_stats.spam_percent + symbol_stats.ham_percent
    if percent_sum > 0 then
      symbol_stats.spam_overall = symbol_stats.spam_percent / percent_sum
    else
      symbol_stats.spam_overall = 0
    end
  end

  return file_stats, all_symbols_stats
end

return utility