aboutsummaryrefslogtreecommitdiffstats
path: root/lualib
diff options
context:
space:
mode:
authorPragadeesh C <cpragadeesh@gmail.com>2017-06-01 16:07:28 -0700
committerPragadeesh C <cpragadeesh@gmail.com>2017-12-07 21:52:58 +0530
commit703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4 (patch)
tree1b3d94a09c1a8fbb24eb9d40e73a510bb8ed53d0 /lualib
parent40759556db803500b4eaaaf443dff0a94320f209 (diff)
downloadrspamd-703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4.tar.gz
rspamd-703bd13d5bedc30ed9bbeb7180d3cd083fc0e1f4.zip
added corpus_test, rescore commands
Diffstat (limited to 'lualib')
-rw-r--r--lualib/rspamadm/corpus_test.lua126
-rw-r--r--lualib/rspamadm/rescore.lua298
-rw-r--r--lualib/rspamadm/rescore_utility.lua214
3 files changed, 638 insertions, 0 deletions
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644
index 000000000..b29fa5602
--- /dev/null
+++ b/lualib/rspamadm/corpus_test.lua
@@ -0,0 +1,126 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+
+local function scan_email(n_parellel, path)
+
+ local rspamc_command = string.format("rspamc -j --compact -n %s %s", n_parellel, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
+
+local function write_results(results, file)
+
+ local f = io.open(file, 'w')
+
+ for _, result in pairs(results) do
+ local log_line = string.format("%s %.2f %s", result.type, result.score, result.action)
+
+ for _, sym in pairs(result.symbols) do
+ log_line = log_line .. " " .. sym
+ end
+
+ log_line = log_line .. "\r\n"
+
+ f:write(log_line)
+ end
+
+ f:close()
+end
+
+local function encoded_json_to_log(result)
+ -- Returns table containing score, action, list of symbols
+
+ local filtered_result = {}
+ local parser = ucl.parser()
+
+ local is_good, err = parser:parse_string(result)
+
+ if not is_good then
+ print(err)
+ os.exit()
+ end
+
+ result = parser:get_object()
+
+ filtered_result.score = result.score
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
+
+ filtered_result.symbols = {}
+
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
+
+ return filtered_result
+end
+
+local function scan_results_to_logs(results, actual_email_type)
+
+ local logs = {}
+
+ results = lua_util.rspamd_str_split(results, "\n")
+
+ if results[#results] == "" then
+ results[#results] = nil
+ end
+
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ result['type'] = actual_email_type
+ table.insert(logs, result)
+ end
+
+ return logs
+end
+
+return function (_, res)
+
+ local ham_directory = res['ham_directory']
+ local spam_directory = res['spam_directory']
+ local connections = res["connections"]
+ local output = res["output_location"]
+
+ local results = {}
+
+ local start_time = os.time()
+ local no_of_ham = 0
+ local no_of_spam = 0
+
+ if ham_directory then
+ io.write("Scanning ham corpus...\n")
+ local ham_results = scan_email(connections, ham_directory)
+ ham_results = scan_results_to_logs(ham_results, HAM)
+
+ no_of_ham = #ham_results
+
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
+ end
+ end
+
+ if spam_directory then
+ io.write("Scanning spam corpus...\n")
+ local spam_results = scan_email(connections, spam_directory)
+ spam_results = scan_results_to_logs(spam_results, SPAM)
+
+ no_of_spam = #spam_results
+
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
+ end
+ end
+
+ io.write(string.format("Writing results to %s\n", output))
+ write_results(results, output)
+
+ io.write("\nStats: \n")
+ io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
+ io.write(string.format("No of ham: %d\n", no_of_ham))
+ io.write(string.format("No of spam: %d\n", no_of_spam))
+
+end \ No newline at end of file
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
new file mode 100644
index 000000000..538122f68
--- /dev/null
+++ b/lualib/rspamadm/rescore.lua
@@ -0,0 +1,298 @@
+local torch = require "torch"
+local nn = require "nn"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+local rescore_utility = require "rspamadm/rescore_utility"
+
+local function make_dataset_from_logs(logs, all_symbols)
+ -- Returns a list of {input, output} for torch SGD train
+
+ local dataset = {}
+
+ for _, log in pairs(logs) do
+ local input = torch.Tensor(#all_symbols)
+ local output = torch.Tensor(1)
+ log = lua_util.rspamd_str_split(log, " ")
+
+ if log[1] == "SPAM" then
+ output[1] = 1
+ else
+ output[1] = 0
+ end
+
+ local symbols_set = {}
+
+ for i=4,#log do
+ symbols_set[log[i]] = true
+ end
+
+ for index, symbol in pairs(all_symbols) do
+ if symbols_set[symbol] then
+ input[index] = 1
+ else
+ input[index] = 0
+ end
+ end
+
+ dataset[#dataset + 1] = {input, output}
+
+ end
+
+ function dataset:size()
+ return #dataset
+ end
+
+ return dataset
+end
+
+local function init_weights(all_symbols, original_symbol_scores)
+
+ local weights = torch.Tensor(#all_symbols)
+
+ local mean = 0
+
+ for i, symbol in pairs(all_symbols) do
+ local score = original_symbol_scores[symbol]
+ if not score then score = 0 end
+ weights[i] = score
+ mean = mean + score
+ end
+
+ return weights
+end
+
+local function shuffle(logs)
+
+ local size = #logs
+ for i = size, 1, -1 do
+ local rand = math.random(size)
+ logs[i], logs[rand] = logs[rand], logs[i]
+ end
+
+end
+
+local function split_logs(logs, split_percent)
+
+ if not split_percent then
+ split_percent = 60
+ end
+
+ local split_index = math.floor(#logs * split_percent / 100)
+
+ local test_logs = {}
+ local train_logs = {}
+
+ for i=1,split_index do
+ train_logs[#train_logs + 1] = logs[i]
+ end
+
+ for i=split_index + 1, #logs do
+ test_logs[#test_logs + 1] = logs[i]
+ end
+
+ return train_logs, test_logs
+end
+
+local function stitch_new_scores(all_symbols, new_scores)
+
+ local new_symbol_scores = {}
+
+ for idx, symbol in pairs(all_symbols) do
+ new_symbol_scores[symbol] = new_scores[idx]
+ end
+
+ return new_symbol_scores
+end
+
+
+local function update_logs(logs, symbol_scores)
+
+ for i, log in ipairs(logs) do
+
+ log = lua_util.rspamd_str_split(log, " ")
+
+ local score = 0
+
+ for j=4,#log do
+ log[j] = log[j]:gsub("%s+", "")
+ score = score + (symbol_scores[log[j ]] or 0)
+ end
+
+ log[2] = rescore_utility.round(score, 2)
+
+ logs[i] = table.concat(log, " ")
+ end
+
+ return logs
+end
+
+local function write_scores(new_symbol_scores, file_path)
+
+ local file = assert(io.open(file_path, "w"))
+
+ local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
+
+ file:write(new_scores_ucl)
+
+ file:close()
+end
+
+local function print_score_diff(new_symbol_scores, original_symbol_scores)
+
+ print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
+
+ for symbol, new_score in pairs(new_symbol_scores) do
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+
+ print "\nClass changes \n"
+ for symbol, new_score in pairs(new_symbol_scores) do
+ if original_symbol_scores[symbol] ~= nil then
+ if (original_symbol_scores[symbol] > 0 and new_score < 0) or
+ (original_symbol_scores[symbol] < 0 and new_score > 0) then
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+ end
+ end
+
+end
+
+local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
+
+ local new_symbol_scores = weights:clone()
+
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+ logs = update_logs(logs, new_symbol_scores)
+
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+ return file_stats.fscore
+end
+
+local function print_stats(logs, threshold)
+
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+
+ local file_stat_format = [[
+F-score: %.2f
+False positive rate: %.2f %%
+False negative rate: %.2f %%
+Overall accuracy: %.2f %%
+]]
+
+ io.write("\nStatistics at threshold: " .. threshold .. "\n")
+
+ io.write(string.format(file_stat_format,
+ file_stats.fscore,
+ file_stats.false_positive_rate,
+ file_stats.false_negative_rate,
+ file_stats.overall_accuracy))
+
+end
+
+return function (_, res)
+
+ local logs = rescore_utility.get_all_logs(res["logdir"])
+ local all_symbols = rescore_utility.get_all_symbols(logs)
+ local original_symbol_scores = rescore_utility.get_all_symbol_scores()
+
+ shuffle(logs)
+
+ local train_logs, validation_logs = split_logs(logs, 70)
+ local cv_logs, test_logs = split_logs(validation_logs, 50)
+
+ local dataset = make_dataset_from_logs(train_logs, all_symbols)
+
+ local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
+ local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}
+
+ -- Start of perceptron training
+
+ local input_size = #all_symbols
+ local linear_module = nn.Linear(input_size, 1)
+
+ local perceptron = nn.Sequential()
+ perceptron:add(linear_module)
+
+ local activation = nn.Sigmoid()
+
+ perceptron:add(activation)
+
+ local criterion = nn.MSECriterion()
+ criterion.sizeAverage = false
+
+ local best_fscore = -math.huge
+ local best_weights = linear_module.weight[1]:clone()
+
+ local trainer = nn.StochasticGradient(perceptron, criterion)
+ trainer.maxIteration = res["iters"]
+ trainer.verbose = false
+
+ trainer.hookIteration = function(self, iteration, error)
+
+ if iteration == trainer.maxIteration then
+
+ local fscore = calculate_fscore_from_weights(cv_logs,
+ all_symbols,
+ linear_module.weight[1],
+ linear_module.bias[1],
+ res["threshold"])
+
+ print("Cross-validation fscore: " .. fscore)
+
+ if best_fscore < fscore then
+ best_fscore = fscore
+ best_weights = linear_module.weight[1]:clone()
+ end
+ end
+ end
+
+ for _, learning_rate in pairs(learning_rates) do
+ for _, weight in pairs(penalty_weights) do
+
+ trainer.weightDecay = weight
+ print("Learning with learning_rate: " .. learning_rate
+ .. " | l2_weight: " .. weight)
+
+ linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+
+ trainer.learningRate = learning_rate
+ trainer:train(dataset)
+
+ print()
+ end
+ end
+
+ -- End perceptron training
+
+ local new_symbol_scores = best_weights
+
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+
+ if res["output"] then
+ write_scores(new_symbol_scores, res["output"])
+ end
+
+ if res["diff"] then
+ print_score_diff(new_symbol_scores, original_symbol_scores)
+ end
+
+
+ -- Pre-rescore test stats
+ print("\n\nPre-rescore test stats\n")
+ test_logs = update_logs(test_logs, original_symbol_scores)
+ print_stats(test_logs, res['threshold'])
+
+ -- Post-rescore test stats
+ test_logs = update_logs(test_logs, new_symbol_scores)
+ print("\n\nPost-rescore test stats\n")
+ print_stats(test_logs, res['threshold'])
+end \ No newline at end of file
diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
new file mode 100644
index 000000000..4c6504e76
--- /dev/null
+++ b/lualib/rspamadm/rescore_utility.lua
@@ -0,0 +1,214 @@
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+
+local utility = {}
+
+function utility.round(num, places)
+ return string.format("%." .. (places or 0) .. "f", num)
+end
+
+function utility.get_all_symbols(logs)
+ -- Returns a list of all symbols
+
+ local symbols_set = {}
+
+ for _, line in pairs(logs) do
+ line = lua_util.rspamd_str_split(line, " ")
+ for i=4,#line do
+ line[i] = line[i]:gsub("%s+", "")
+ if not symbols_set[line[i]] then
+ symbols_set[line[i]] = true
+ end
+ end
+ end
+
+ local all_symbols = {}
+
+ for symbol, _ in pairs(symbols_set) do
+ all_symbols[#all_symbols + 1] = symbol
+ end
+
+ table.sort(all_symbols)
+
+ return all_symbols
+end
+
+function utility.read_log_file(file)
+
+ local lines = {}
+
+ file = assert(io.open(file, "r"))
+
+ for line in file:lines() do
+ lines[#lines + 1] = line
+ end
+
+ io.close(file)
+
+ return lines
+end
+
+function utility.get_all_logs(dir_path)
+ -- Reads all log files in the directory and returns a list of logs.
+
+ if dir_path:sub(#dir_path, #dir_path) == "/" then
+ dir_path = dir_path:sub(1, #dir_path -1)
+ end
+
+ local files = rspamd_util.glob(dir_path .. "/*")
+ local all_logs = {}
+
+ for _, file in pairs(files) do
+ local logs = utility.read_log_file(file)
+ for _, log_line in pairs(logs) do
+ all_logs[#all_logs + 1] = log_line
+ end
+ end
+
+ return all_logs
+end
+
+function utility.get_all_symbol_scores()
+
+ local output = assert(io.popen("rspamc counters -j --compact"))
+ output = output:read("*all")
+
+ local parser = ucl.parser()
+ local result, err = parser:parse_string(output)
+
+ if not result then
+ print(err)
+ os.exit()
+ end
+
+ output = parser:get_object()
+
+ local symbol_scores = {}
+
+ for _, symbol_info in pairs(output) do
+ symbol_scores[symbol_info.symbol] = symbol_info.weight
+ end
+
+ return symbol_scores
+end
+
+function utility.generate_statistics_from_logs(logs, threshold)
+
+ -- Returns file_stats table and list of symbol_stats table.
+
+ local file_stats = {
+ no_of_emails = 0,
+ no_of_spam = 0,
+ no_of_ham = 0,
+ spam_percent = 0,
+ ham_percent = 0,
+ true_positives = 0,
+ true_negatives = 0,
+ false_negative_rate = 0,
+ false_positive_rate = 0,
+ overall_accuracy = 0,
+ fscore = 0
+ }
+
+ local all_symbols_stats = {}
+
+ local false_positives = 0
+ local false_negatives = 0
+ local true_positives = 0
+ local true_negatives = 0
+ local no_of_emails = 0
+ local no_of_spam = 0
+ local no_of_ham = 0
+
+ for _, log in pairs(logs) do
+ log = lua_util.rspamd_str_trim(log)
+ log = lua_util.rspamd_str_split(log, " ")
+
+ local is_spam = (log[1] == "SPAM")
+ local score = tonumber(log[2])
+
+ no_of_emails = no_of_emails + 1
+
+ if is_spam then
+ no_of_spam = no_of_spam + 1
+ else
+ no_of_ham = no_of_ham + 1
+ end
+
+ if is_spam and (score >= threshold) then
+ true_positives = true_positives + 1
+ elseif is_spam and (score < threshold) then
+ false_negatives = false_negatives + 1
+ elseif not is_spam and (score >= threshold) then
+ false_positives = false_positives + 1
+ else
+ true_negatives = true_negatives + 1
+ end
+
+ for i=4, #log do
+ if all_symbols_stats[log[i]] == nil then
+ all_symbols_stats[log[i]] = {
+ name = log[i],
+ no_of_hits = 0,
+ spam_hits = 0,
+ ham_hits = 0,
+ spam_overall = 0
+ }
+ end
+
+ all_symbols_stats[log[i]].no_of_hits =
+ all_symbols_stats[log[i]].no_of_hits + 1
+
+ if is_spam then
+ all_symbols_stats[log[i]].spam_hits =
+ all_symbols_stats[log[i]].spam_hits + 1
+ else
+ all_symbols_stats[log[i]].ham_hits =
+ all_symbols_stats[log[i]].ham_hits + 1
+ end
+ end
+ end
+
+ -- Calculating file stats
+
+ file_stats.no_of_ham = no_of_ham
+ file_stats.no_of_spam = no_of_spam
+ file_stats.no_of_emails = no_of_emails
+ file_stats.true_positives = true_positives
+ file_stats.true_negatives = true_negatives
+
+ if no_of_emails > 0 then
+ file_stats.spam_percent = no_of_spam * 100 / no_of_emails
+ file_stats.ham_percent = no_of_ham * 100 / no_of_emails
+ file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
+ no_of_emails
+ end
+
+ if no_of_ham > 0 then
+ file_stats.false_positive_rate = false_positives * 100 / no_of_ham
+ end
+
+ if no_of_spam > 0 then
+ file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
+ end
+
+ file_stats.fscore = 2 * true_positives / (2
+ * true_positives
+ + false_positives
+ + false_negatives)
+
+ -- Calculating symbol stats
+
+ for _, symbol_stats in pairs(all_symbols_stats) do
+ symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
+ symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
+ symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
+ symbol_stats.spam_overall = symbol_stats.spam_percent /
+ (symbol_stats.spam_percent + symbol_stats.ham_percent)
+ end
+
+ return file_stats, all_symbols_stats
+end
+
+return utility