From fcf72e80c786d109fc16d693ef80bb518ea24831 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 1 Mar 2018 16:00:39 +0000 Subject: [PATCH] [Rework] Rework rescore utility --- lualib/rspamadm/rescore.lua | 391 +++++++++++++++------------- lualib/rspamadm/rescore_utility.lua | 334 +++++++++++------------- src/rspamadm/configwizard.c | 2 +- src/rspamadm/rescore.c | 130 ++++++--- 4 files changed, 466 insertions(+), 391 deletions(-) diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index 4f6cc5075..3129af563 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -2,297 +2,338 @@ local torch = require "torch" local nn = require "nn" local lua_util = require "lua_util" local ucl = require "ucl" +local logger = require "rspamd_logger" +local getopt = require "rspamadm/getopt" local rescore_utility = require "rspamadm/rescore_utility" +local opts + local function make_dataset_from_logs(logs, all_symbols) - -- Returns a list of {input, output} for torch SGD train + -- Returns a list of {input, output} for torch SGD train - local dataset = {} + local dataset = {} - for _, log in pairs(logs) do - local input = torch.Tensor(#all_symbols) - local output = torch.Tensor(1) - log = lua_util.rspamd_str_split(log, " ") + for _, log in pairs(logs) do + local input = torch.Tensor(#all_symbols) + local output = torch.Tensor(1) + log = lua_util.rspamd_str_split(log, " ") - if log[1] == "SPAM" then - output[1] = 1 - else - output[1] = 0 - end + if log[1] == "SPAM" then + output[1] = 1 + else + output[1] = 0 + end - local symbols_set = {} + local symbols_set = {} - for i=4,#log do - symbols_set[log[i]] = true - end + for i=4,#log do + symbols_set[log[i]] = true + end - for index, symbol in pairs(all_symbols) do - if symbols_set[symbol] then - input[index] = 1 - else - input[index] = 0 - end - end + for index, symbol in pairs(all_symbols) do + if symbols_set[symbol] then + input[index] = 1 + else + input[index] = 0 + end + end - dataset[#dataset + 1] = {input, output} + dataset[#dataset + 1] = {input, output} - end + end - function dataset:size() - return #dataset - end + function dataset:size() + return #dataset + end - return dataset + return dataset end local function init_weights(all_symbols, original_symbol_scores) - local weights = torch.Tensor(#all_symbols) + local weights = torch.Tensor(#all_symbols) - local mean = 0 + local mean = 0 - for i, symbol in pairs(all_symbols) do - local score = original_symbol_scores[symbol] - if not score then score = 0 end - weights[i] = score - mean = mean + score - end + for i, symbol in pairs(all_symbols) do + local score = original_symbol_scores[symbol] + if not score then score = 0 end + weights[i] = score + mean = mean + score + end - return weights + return weights end local function shuffle(logs) - local size = #logs - for i = size, 1, -1 do - local rand = math.random(size) - logs[i], logs[rand] = logs[rand], logs[i] - end + local size = #logs + for i = size, 1, -1 do + local rand = math.random(size) + logs[i], logs[rand] = logs[rand], logs[i] + end end local function split_logs(logs, split_percent) - if not split_percent then - split_percent = 60 - end + if not split_percent then + split_percent = 60 + end - local split_index = math.floor(#logs * split_percent / 100) + local split_index = math.floor(#logs * split_percent / 100) - local test_logs = {} - local train_logs = {} + local test_logs = {} + local train_logs = {} - for i=1,split_index do - train_logs[#train_logs + 1] = logs[i] - end + for i=1,split_index do + train_logs[#train_logs + 1] = logs[i] + end - for i=split_index + 1, #logs do - test_logs[#test_logs + 1] = logs[i] - end + for i=split_index + 1, #logs do + test_logs[#test_logs + 1] = logs[i] + end - return train_logs, test_logs + return train_logs, test_logs end local function stitch_new_scores(all_symbols, new_scores) - local new_symbol_scores = {} + local new_symbol_scores = {} - for idx, symbol in pairs(all_symbols) do - new_symbol_scores[symbol] = new_scores[idx] - end + for idx, symbol in pairs(all_symbols) do + new_symbol_scores[symbol] = new_scores[idx] + end - return new_symbol_scores + return new_symbol_scores end local function update_logs(logs, symbol_scores) - for i, log in ipairs(logs) do + for i, log in ipairs(logs) do - log = lua_util.rspamd_str_split(log, " ") + log = lua_util.rspamd_str_split(log, " ") - local score = 0 + local score = 0 - for j=4,#log do - log[j] = log[j]:gsub("%s+", "") - score = score + (symbol_scores[log[j ]] or 0) - end + for j=4,#log do + log[j] = log[j]:gsub("%s+", "") + score = score + (symbol_scores[log[j ]] or 0) + end - log[2] = rescore_utility.round(score, 2) + log[2] = lua_util.round(score, 2) - logs[i] = table.concat(log, " ") - end + logs[i] = table.concat(log, " ") + end - return logs + return logs end local function write_scores(new_symbol_scores, file_path) - local file = assert(io.open(file_path, "w")) + local file = assert(io.open(file_path, "w")) - local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl") + local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl") - file:write(new_scores_ucl) + file:write(new_scores_ucl) - file:close() + file:close() end local function print_score_diff(new_symbol_scores, original_symbol_scores) - print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE")) - - for symbol, new_score in pairs(new_symbol_scores) do - print(string.format("%-35s %-10s %-10s", - symbol, - original_symbol_scores[symbol] or 0, - rescore_utility.round(new_score, 2))) - end - - print "\nClass changes \n" - for symbol, new_score in pairs(new_symbol_scores) do - if original_symbol_scores[symbol] ~= nil then - if (original_symbol_scores[symbol] > 0 and new_score < 0) or - (original_symbol_scores[symbol] < 0 and new_score > 0) then - print(string.format("%-35s %-10s %-10s", - symbol, - original_symbol_scores[symbol] or 0, - rescore_utility.round(new_score, 2))) - end - end - end + print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE")) + + for symbol, new_score in pairs(new_symbol_scores) do + print(string.format("%-35s %-10s %-10s", + symbol, + original_symbol_scores[symbol] or 0, + rescore_utility.round(new_score, 2))) + end + + print "\nClass changes \n" + for symbol, new_score in pairs(new_symbol_scores) do + if original_symbol_scores[symbol] ~= nil then + if (original_symbol_scores[symbol] > 0 and new_score < 0) or + (original_symbol_scores[symbol] < 0 and new_score > 0) then + print(string.format("%-35s %-10s %-10s", + symbol, + original_symbol_scores[symbol] or 0, + rescore_utility.round(new_score, 2))) + end + end + end end local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold) - local new_symbol_scores = weights:clone() + local new_symbol_scores = weights:clone() - new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) + new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - logs = update_logs(logs, new_symbol_scores) + logs = update_logs(logs, new_symbol_scores) - local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) - return file_stats.fscore + return file_stats.fscore end local function print_stats(logs, threshold) - local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold) - local file_stat_format = [[ + local file_stat_format = [[ F-score: %.2f False positive rate: %.2f %% False negative rate: %.2f %% Overall accuracy: %.2f %% ]] - io.write("\nStatistics at threshold: " .. threshold .. "\n") + io.write("\nStatistics at threshold: " .. threshold .. "\n") + + io.write(string.format(file_stat_format, + file_stats.fscore, + file_stats.false_positive_rate, + file_stats.false_negative_rate, + file_stats.overall_accuracy)) - io.write(string.format(file_stat_format, - file_stats.fscore, - file_stats.false_positive_rate, - file_stats.false_negative_rate, - file_stats.overall_accuracy)) +end +local default_opts = { + verbose = true, + iters = 10, + threads = 1, +} + +local function override_defaults(def, override) + for k,v in pairs(override) do + if def[k] then + if type(v) == 'table' then + override_defaults(def[k], v) + else + def[k] = v + end + else + def[k] = v + end + end end -return function (_, res) +local function get_threshold(opts) + local actions = rspamd_config:get_all_actions() - local logs = rescore_utility.get_all_logs(res["logdir"]) - local all_symbols = rescore_utility.get_all_symbols(logs) - local original_symbol_scores = rescore_utility.get_all_symbol_scores(res["timeout"]) + if opts['spam-action'] then + return actions[opts['spam-action']] or 0 + else + return actions['add header'] or actions['rewrite subject'] or actions['reject'] + end +end - shuffle(logs) +return function (args, cfg) + opts = default_opts + override_defaults(opts, getopt.getopt(args, '')) + local threshold = get_threshold(opts) + local logs = rescore_utility.get_all_logs(cfg["logdir"]) + local all_symbols = rescore_utility.get_all_symbols(logs) + local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config) - local train_logs, validation_logs = split_logs(logs, 70) - local cv_logs, test_logs = split_logs(validation_logs, 50) + shuffle(logs) + torch.setdefaulttensortype('torch.FloatTensor') - local dataset = make_dataset_from_logs(train_logs, all_symbols) + local train_logs, validation_logs = split_logs(logs, 70) + local cv_logs, test_logs = split_logs(validation_logs, 50) - local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10} - local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100} + local dataset = make_dataset_from_logs(train_logs, all_symbols) - -- Start of perceptron training + local learning_rates = { + 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10 + } + local penalty_weights = { + 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100 + } - local input_size = #all_symbols - local linear_module = nn.Linear(input_size, 1) + -- Start of perceptron training - local perceptron = nn.Sequential() - perceptron:add(linear_module) + local input_size = #all_symbols + torch.setnumthreads(opts['threads']) + local linear_module = nn.Linear(input_size, 1) - local activation = nn.Sigmoid() + local perceptron = nn.Sequential() + perceptron:add(linear_module) - perceptron:add(activation) + local activation = nn.Sigmoid() - local criterion = nn.MSECriterion() - criterion.sizeAverage = false + perceptron:add(activation) - local best_fscore = -math.huge - local best_weights = linear_module.weight[1]:clone() + local criterion = nn.MSECriterion() + criterion.sizeAverage = false - local trainer = nn.StochasticGradient(perceptron, criterion) - trainer.maxIteration = res["iters"] - trainer.verbose = false + local best_fscore = -math.huge + local best_weights = linear_module.weight[1]:clone() - trainer.hookIteration = function(self, iteration, error) + local trainer = nn.StochasticGradient(perceptron, criterion) + trainer.maxIteration = opts["iters"] + trainer.verbose = opts['verbose'] + trainer.hookIteration = function(self, iteration, error) - if iteration == trainer.maxIteration then + if iteration == trainer.maxIteration then - local fscore = calculate_fscore_from_weights(cv_logs, - all_symbols, - linear_module.weight[1], - linear_module.bias[1], - res["threshold"]) + local fscore = calculate_fscore_from_weights(cv_logs, + all_symbols, + linear_module.weight[1], + linear_module.bias[1], + threshold) - print("Cross-validation fscore: " .. fscore) + print("Cross-validation fscore: " .. fscore) - if best_fscore < fscore then - best_fscore = fscore - best_weights = linear_module.weight[1]:clone() - end - end - end + if best_fscore < fscore then + best_fscore = fscore + best_weights = linear_module.weight[1]:clone() + end + end + end - for _, learning_rate in pairs(learning_rates) do - for _, weight in pairs(penalty_weights) do + for _, learning_rate in ipairs(learning_rates) do + for _, weight in ipairs(penalty_weights) do - trainer.weightDecay = weight - print("Learning with learning_rate: " .. learning_rate - .. " | l2_weight: " .. weight) + trainer.weightDecay = weight + print("Learning with learning_rate: " .. learning_rate + .. " | l2_weight: " .. weight) - linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) + linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) - trainer.learningRate = learning_rate - trainer:train(dataset) + trainer.learningRate = learning_rate + trainer:train(dataset) - print() - end - end + print() + end + end - -- End perceptron training + -- End perceptron training - local new_symbol_scores = best_weights + local new_symbol_scores = best_weights - new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) + new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - if res["output"] then - write_scores(new_symbol_scores, res["output"]) - end + if res["output"] then + write_scores(new_symbol_scores, res["output"]) + end - if res["diff"] then - print_score_diff(new_symbol_scores, original_symbol_scores) - end + if opts["diff"] then + print_score_diff(new_symbol_scores, original_symbol_scores) + end - -- Pre-rescore test stats - print("\n\nPre-rescore test stats\n") - test_logs = update_logs(test_logs, original_symbol_scores) - print_stats(test_logs, res['threshold']) + -- Pre-rescore test stats + io.write("\n\nPre-rescore test stats\n\n") + test_logs = update_logs(test_logs, original_symbol_scores) + print_stats(test_logs, threshold) - -- Post-rescore test stats - test_logs = update_logs(test_logs, new_symbol_scores) - print("\n\nPost-rescore test stats\n") - print_stats(test_logs, res['threshold']) + -- Post-rescore test stats + test_logs = update_logs(test_logs, new_symbol_scores) + io.write("\n\nPost-rescore test stats\n\n") + print_stats(test_logs, threshold) end \ No newline at end of file diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua index 2390fc565..db79bbf7b 100644 --- a/lualib/rspamadm/rescore_utility.lua +++ b/lualib/rspamadm/rescore_utility.lua @@ -1,214 +1,194 @@ -local ucl = require "ucl" local lua_util = require "lua_util" local rspamd_util = require "rspamd_util" +local fun = require "fun" local utility = {} -function utility.round(num, places) - return string.format("%." .. (places or 0) .. "f", num) -end - function utility.get_all_symbols(logs) - -- Returns a list of all symbols + -- Returns a list of all symbols - local symbols_set = {} + local symbols_set = {} - for _, line in pairs(logs) do - line = lua_util.rspamd_str_split(line, " ") - for i=4,#line do - line[i] = line[i]:gsub("%s+", "") - if not symbols_set[line[i]] then - symbols_set[line[i]] = true - end - end - end + for _, line in pairs(logs) do + line = lua_util.rspamd_str_split(line, " ") + for i=4,#line do + line[i] = line[i]:gsub("%s+", "") + if not symbols_set[line[i]] then + symbols_set[line[i]] = true + end + end + end - local all_symbols = {} + local all_symbols = {} - for symbol, _ in pairs(symbols_set) do - all_symbols[#all_symbols + 1] = symbol - end + for symbol, _ in pairs(symbols_set) do + all_symbols[#all_symbols + 1] = symbol + end - table.sort(all_symbols) + table.sort(all_symbols) - return all_symbols + return all_symbols end function utility.read_log_file(file) - local lines = {} + local lines = {} - file = assert(io.open(file, "r")) + file = assert(io.open(file, "r")) - for line in file:lines() do - lines[#lines + 1] = line - end + for line in file:lines() do + lines[#lines + 1] = line + end - io.close(file) + io.close(file) - return lines + return lines end function utility.get_all_logs(dir_path) - -- Reads all log files in the directory and returns a list of logs. + -- Reads all log files in the directory and returns a list of logs. - if dir_path:sub(#dir_path, #dir_path) == "/" then - dir_path = dir_path:sub(1, #dir_path -1) - end + if dir_path:sub(#dir_path, #dir_path) == "/" then + dir_path = dir_path:sub(1, #dir_path -1) + end - local files = rspamd_util.glob(dir_path .. "/*") - local all_logs = {} + local files = rspamd_util.glob(dir_path .. "/*") + local all_logs = {} - for _, file in pairs(files) do - local logs = utility.read_log_file(file) - for _, log_line in pairs(logs) do - all_logs[#all_logs + 1] = log_line - end - end + for _, file in pairs(files) do + local logs = utility.read_log_file(file) + for _, log_line in pairs(logs) do + all_logs[#all_logs + 1] = log_line + end + end - return all_logs + return all_logs end -function utility.get_all_symbol_scores(timeout) - - local output = assert(io.popen("rspamc counters -j --compact -t " .. tostring(timeout))) - output = output:read("*all") - - local parser = ucl.parser() - local result, err = parser:parse_string(output) - - if not result then - print(err) - os.exit() - end - - output = parser:get_object() - - local symbol_scores = {} - - for _, symbol_info in pairs(output) do - symbol_scores[symbol_info.symbol] = symbol_info.weight - end +function utility.get_all_symbol_scores(conf) + local counters = conf:get_symbols_counters() - return symbol_scores + return fun.tomap(fun.map(function(elt) + return elt['symbol'],elt['weight'] + end, counters)) end function utility.generate_statistics_from_logs(logs, threshold) - -- Returns file_stats table and list of symbol_stats table. - - local file_stats = { - no_of_emails = 0, - no_of_spam = 0, - no_of_ham = 0, - spam_percent = 0, - ham_percent = 0, - true_positives = 0, - true_negatives = 0, - false_negative_rate = 0, - false_positive_rate = 0, - overall_accuracy = 0, - fscore = 0 - } - - local all_symbols_stats = {} - - local false_positives = 0 - local false_negatives = 0 - local true_positives = 0 - local true_negatives = 0 - local no_of_emails = 0 - local no_of_spam = 0 - local no_of_ham = 0 - - for _, log in pairs(logs) do - log = lua_util.rspamd_str_trim(log) - log = lua_util.rspamd_str_split(log, " ") - - local is_spam = (log[1] == "SPAM") - local score = tonumber(log[2]) - - no_of_emails = no_of_emails + 1 - - if is_spam then - no_of_spam = no_of_spam + 1 - else - no_of_ham = no_of_ham + 1 - end - - if is_spam and (score >= threshold) then - true_positives = true_positives + 1 - elseif is_spam and (score < threshold) then - false_negatives = false_negatives + 1 - elseif not is_spam and (score >= threshold) then - false_positives = false_positives + 1 - else - true_negatives = true_negatives + 1 - end - - for i=4, #log do - if all_symbols_stats[log[i]] == nil then - all_symbols_stats[log[i]] = { - name = log[i], - no_of_hits = 0, - spam_hits = 0, - ham_hits = 0, - spam_overall = 0 - } - end - - all_symbols_stats[log[i]].no_of_hits = - all_symbols_stats[log[i]].no_of_hits + 1 - - if is_spam then - all_symbols_stats[log[i]].spam_hits = - all_symbols_stats[log[i]].spam_hits + 1 - else - all_symbols_stats[log[i]].ham_hits = - all_symbols_stats[log[i]].ham_hits + 1 - end - end - end - - -- Calculating file stats - - file_stats.no_of_ham = no_of_ham - file_stats.no_of_spam = no_of_spam - file_stats.no_of_emails = no_of_emails - file_stats.true_positives = true_positives - file_stats.true_negatives = true_negatives - - if no_of_emails > 0 then - file_stats.spam_percent = no_of_spam * 100 / no_of_emails - file_stats.ham_percent = no_of_ham * 100 / no_of_emails - file_stats.overall_accuracy = (true_positives + true_negatives) * 100 / - no_of_emails - end - - if no_of_ham > 0 then - file_stats.false_positive_rate = false_positives * 100 / no_of_ham - end - - if no_of_spam > 0 then - file_stats.false_negative_rate = false_negatives * 100 / no_of_spam - end - - file_stats.fscore = 2 * true_positives / (2 - * true_positives - + false_positives - + false_negatives) - - -- Calculating symbol stats - - for _, symbol_stats in pairs(all_symbols_stats) do - symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam - symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham - symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails - symbol_stats.spam_overall = symbol_stats.spam_percent / - (symbol_stats.spam_percent + symbol_stats.ham_percent) - end - - return file_stats, all_symbols_stats + -- Returns file_stats table and list of symbol_stats table. + + local file_stats = { + no_of_emails = 0, + no_of_spam = 0, + no_of_ham = 0, + spam_percent = 0, + ham_percent = 0, + true_positives = 0, + true_negatives = 0, + false_negative_rate = 0, + false_positive_rate = 0, + overall_accuracy = 0, + fscore = 0 + } + + local all_symbols_stats = {} + + local false_positives = 0 + local false_negatives = 0 + local true_positives = 0 + local true_negatives = 0 + local no_of_emails = 0 + local no_of_spam = 0 + local no_of_ham = 0 + + for _, log in pairs(logs) do + log = lua_util.rspamd_str_trim(log) + log = lua_util.rspamd_str_split(log, " ") + + local is_spam = (log[1] == "SPAM") + local score = tonumber(log[2]) + + no_of_emails = no_of_emails + 1 + + if is_spam then + no_of_spam = no_of_spam + 1 + else + no_of_ham = no_of_ham + 1 + end + + if is_spam and (score >= threshold) then + true_positives = true_positives + 1 + elseif is_spam and (score < threshold) then + false_negatives = false_negatives + 1 + elseif not is_spam and (score >= threshold) then + false_positives = false_positives + 1 + else + true_negatives = true_negatives + 1 + end + + for i=4, #log do + if all_symbols_stats[log[i]] == nil then + all_symbols_stats[log[i]] = { + name = log[i], + no_of_hits = 0, + spam_hits = 0, + ham_hits = 0, + spam_overall = 0 + } + end + + all_symbols_stats[log[i]].no_of_hits = + all_symbols_stats[log[i]].no_of_hits + 1 + + if is_spam then + all_symbols_stats[log[i]].spam_hits = + all_symbols_stats[log[i]].spam_hits + 1 + else + all_symbols_stats[log[i]].ham_hits = + all_symbols_stats[log[i]].ham_hits + 1 + end + end + end + + -- Calculating file stats + + file_stats.no_of_ham = no_of_ham + file_stats.no_of_spam = no_of_spam + file_stats.no_of_emails = no_of_emails + file_stats.true_positives = true_positives + file_stats.true_negatives = true_negatives + + if no_of_emails > 0 then + file_stats.spam_percent = no_of_spam * 100 / no_of_emails + file_stats.ham_percent = no_of_ham * 100 / no_of_emails + file_stats.overall_accuracy = (true_positives + true_negatives) * 100 / + no_of_emails + end + + if no_of_ham > 0 then + file_stats.false_positive_rate = false_positives * 100 / no_of_ham + end + + if no_of_spam > 0 then + file_stats.false_negative_rate = false_negatives * 100 / no_of_spam + end + + file_stats.fscore = 2 * true_positives / (2 + * true_positives + + false_positives + + false_negatives) + + -- Calculating symbol stats + + for _, symbol_stats in pairs(all_symbols_stats) do + symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam + symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham + symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails + symbol_stats.spam_overall = symbol_stats.spam_percent / + (symbol_stats.spam_percent + symbol_stats.ham_percent) + end + + return file_stats, all_symbols_stats end return utility diff --git a/src/rspamadm/configwizard.c b/src/rspamadm/configwizard.c index bd8ccefbd..83ba980e0 100644 --- a/src/rspamadm/configwizard.c +++ b/src/rspamadm/configwizard.c @@ -137,7 +137,7 @@ rspamadm_configwizard (gint argc, gchar **argv) /* Do post-load actions */ rspamd_lua_post_load_config (cfg); - if (!rspamd_init_filters (rspamd_main->cfg, FALSE)) { + if (!rspamd_init_filters (cfg, FALSE)) { ret = FALSE; } diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c index e0c69f13b..c88ca0350 100644 --- a/src/rspamadm/rescore.c +++ b/src/rspamadm/rescore.c @@ -26,12 +26,12 @@ static gchar *logdir = NULL; static gchar *output = "new.scores"; -static gdouble threshold = 15; /* Spam threshold */ static gboolean score_diff = false; /* Print score diff flag */ -static gint64 iters = 500; /* Perceptron max iterations */ -gdouble timeout = 60.0; - -/* TODO: think about adding the config file reading */ +static gchar *config = NULL; +extern struct rspamd_main *rspamd_main; +/* Defined in modules.c */ +extern module_t *modules[]; +extern worker_t *workers[]; static void rspamadm_rescore (gint argc, gchar **argv); @@ -51,13 +51,29 @@ static GOptionEntry entries[] = { "Scores output locaiton", NULL}, {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff, "Print score diff", NULL}, - {"iters", 'i', 0, G_OPTION_ARG_INT64, &iters, - "Max iterations for perceptron [Default: 500]", NULL}, - {"timeout", 't', 0, G_OPTION_ARG_DOUBLE, &timeout, - "Timeout for connections [Default: 60]", NULL}, + {"config", 'c', 0, G_OPTION_ARG_STRING, &config, + "Config file to use", NULL}, {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL} }; +static void +config_logger (rspamd_mempool_t *pool, gpointer ud) +{ + struct rspamd_main *rm = ud; + + rm->cfg->log_type = RSPAMD_LOG_CONSOLE; + rm->cfg->log_level = G_LOG_LEVEL_CRITICAL; + + rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger, + rm->server_pool); + + if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) == + -1) { + fprintf (stderr, "Fatal error, cannot open logfile, exiting\n"); + exit (EXIT_FAILURE); + } +} + static const char * rspamadm_rescore_help (gboolean full_help) { @@ -70,8 +86,7 @@ rspamadm_rescore_help (gboolean full_help) { "-l: path to logs directory\n" "-o: scores output file location\n" "-d: print scores diff\n" - "-i: max iterations for perceptron\n" - "-t: timeout for rspamc operations (default: 60)\n"; + "-i: max iterations for perceptron\n"; } else { help_str = "Estimate optimal symbol weights from log files"; } @@ -85,7 +100,10 @@ rspamadm_rescore (gint argc, gchar **argv) { GOptionContext *context; GError *error = NULL; lua_State *L; - ucl_object_t *obj; + struct rspamd_config *cfg = rspamd_main->cfg, **pcfg; + gboolean ret = TRUE; + worker_t **pworker; + const gchar *confdir; context = g_option_context_new ( "rescore - estimate optimal symbol weights from log files"); @@ -116,30 +134,66 @@ rspamadm_rescore (gint argc, gchar **argv) { exit (EXIT_FAILURE); } - L = rspamd_lua_init (); - rspamd_lua_set_path (L, NULL, ucl_vars); - - obj = ucl_object_typed_new (UCL_OBJECT); - - ucl_object_insert_key (obj, ucl_object_fromstring (logdir), - "logdir", 0, false); - ucl_object_insert_key (obj, ucl_object_fromstring (output), - "output", 0, false); - ucl_object_insert_key (obj, ucl_object_fromdouble (threshold), - "threshold", 0, false); - ucl_object_insert_key (obj, ucl_object_fromint (iters), - "iters", 0, false); - ucl_object_insert_key (obj, ucl_object_frombool (score_diff), - "diff", 0, false); - ucl_object_insert_key (obj, ucl_object_fromdouble (timeout), - "timeout", 0, false); - - rspamadm_execute_lua_ucl_subr (L, - argc, - argv, - obj, - "rescore"); - - lua_close (L); - ucl_object_unref (obj); + if (config == NULL) { + if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) { + confdir = RSPAMD_CONFDIR; + } + + config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR, + "rspamd.conf"); + } + + pworker = &workers[0]; + while (*pworker) { + /* Init string quarks */ + (void) g_quark_from_static_string ((*pworker)->name); + pworker++; + } + + cfg->cache = rspamd_symbols_cache_new (cfg); + cfg->compiled_modules = modules; + cfg->compiled_workers = workers; + cfg->cfg_name = config; + + if (!rspamd_config_read (cfg, cfg->cfg_name, NULL, + config_logger, rspamd_main, ucl_vars)) { + ret = FALSE; + } + else { + /* Do post-load actions */ + rspamd_lua_post_load_config (cfg); + + if (!rspamd_init_filters (cfg, FALSE)) { + ret = FALSE; + } + + if (ret) { + ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE); + rspamd_symbols_cache_validate (cfg->cache, + cfg, + FALSE); + } + } + + if (ret) { + L = cfg->lua_state; + rspamd_lua_set_path (L, cfg->rcl_obj, ucl_vars); + ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name), + "config_path", 0, false); + ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir), + "logdir", 0, false); + ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output), + "output", 0, false); + ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff), + "diff", 0, false); + pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *)); + rspamd_lua_setclass (L, "rspamd{config}", -1); + *pcfg = cfg; + lua_setglobal (L, "rspamd_config"); + rspamadm_execute_lua_ucl_subr (L, + argc, + argv, + cfg->rcl_obj, + "rescore"); + } } \ No newline at end of file -- 2.39.5