local nn = require "nn"
local lua_util = require "lua_util"
local ucl = require "ucl"
+local logger = require "rspamd_logger"
+local getopt = require "rspamadm/getopt"
local rescore_utility = require "rspamadm/rescore_utility"
+local opts
+
local function make_dataset_from_logs(logs, all_symbols)
- -- Returns a list of {input, output} for torch SGD train
+ -- Returns a list of {input, output} for torch SGD train
- local dataset = {}
+ local dataset = {}
- for _, log in pairs(logs) do
- local input = torch.Tensor(#all_symbols)
- local output = torch.Tensor(1)
- log = lua_util.rspamd_str_split(log, " ")
+ for _, log in pairs(logs) do
+ local input = torch.Tensor(#all_symbols)
+ local output = torch.Tensor(1)
+ log = lua_util.rspamd_str_split(log, " ")
- if log[1] == "SPAM" then
- output[1] = 1
- else
- output[1] = 0
- end
+ if log[1] == "SPAM" then
+ output[1] = 1
+ else
+ output[1] = 0
+ end
- local symbols_set = {}
+ local symbols_set = {}
- for i=4,#log do
- symbols_set[log[i]] = true
- end
+ for i=4,#log do
+ symbols_set[log[i]] = true
+ end
- for index, symbol in pairs(all_symbols) do
- if symbols_set[symbol] then
- input[index] = 1
- else
- input[index] = 0
- end
- end
+ for index, symbol in pairs(all_symbols) do
+ if symbols_set[symbol] then
+ input[index] = 1
+ else
+ input[index] = 0
+ end
+ end
- dataset[#dataset + 1] = {input, output}
+ dataset[#dataset + 1] = {input, output}
- end
+ end
- function dataset:size()
- return #dataset
- end
+ function dataset:size()
+ return #dataset
+ end
- return dataset
+ return dataset
end
local function init_weights(all_symbols, original_symbol_scores)
- local weights = torch.Tensor(#all_symbols)
+ local weights = torch.Tensor(#all_symbols)
- local mean = 0
+ local mean = 0
- for i, symbol in pairs(all_symbols) do
- local score = original_symbol_scores[symbol]
- if not score then score = 0 end
- weights[i] = score
- mean = mean + score
- end
+ for i, symbol in pairs(all_symbols) do
+ local score = original_symbol_scores[symbol]
+ if not score then score = 0 end
+ weights[i] = score
+ mean = mean + score
+ end
- return weights
+ return weights
end
local function shuffle(logs)
- local size = #logs
- for i = size, 1, -1 do
- local rand = math.random(size)
- logs[i], logs[rand] = logs[rand], logs[i]
- end
+ local size = #logs
+ for i = size, 1, -1 do
+ local rand = math.random(size)
+ logs[i], logs[rand] = logs[rand], logs[i]
+ end
end
local function split_logs(logs, split_percent)
- if not split_percent then
- split_percent = 60
- end
+ if not split_percent then
+ split_percent = 60
+ end
- local split_index = math.floor(#logs * split_percent / 100)
+ local split_index = math.floor(#logs * split_percent / 100)
- local test_logs = {}
- local train_logs = {}
+ local test_logs = {}
+ local train_logs = {}
- for i=1,split_index do
- train_logs[#train_logs + 1] = logs[i]
- end
+ for i=1,split_index do
+ train_logs[#train_logs + 1] = logs[i]
+ end
- for i=split_index + 1, #logs do
- test_logs[#test_logs + 1] = logs[i]
- end
+ for i=split_index + 1, #logs do
+ test_logs[#test_logs + 1] = logs[i]
+ end
- return train_logs, test_logs
+ return train_logs, test_logs
end
local function stitch_new_scores(all_symbols, new_scores)
- local new_symbol_scores = {}
+ local new_symbol_scores = {}
- for idx, symbol in pairs(all_symbols) do
- new_symbol_scores[symbol] = new_scores[idx]
- end
+ for idx, symbol in pairs(all_symbols) do
+ new_symbol_scores[symbol] = new_scores[idx]
+ end
- return new_symbol_scores
+ return new_symbol_scores
end
local function update_logs(logs, symbol_scores)
- for i, log in ipairs(logs) do
+ for i, log in ipairs(logs) do
- log = lua_util.rspamd_str_split(log, " ")
+ log = lua_util.rspamd_str_split(log, " ")
- local score = 0
+ local score = 0
- for j=4,#log do
- log[j] = log[j]:gsub("%s+", "")
- score = score + (symbol_scores[log[j ]] or 0)
- end
+ for j=4,#log do
+ log[j] = log[j]:gsub("%s+", "")
+ score = score + (symbol_scores[log[j ]] or 0)
+ end
- log[2] = rescore_utility.round(score, 2)
+ log[2] = lua_util.round(score, 2)
- logs[i] = table.concat(log, " ")
- end
+ logs[i] = table.concat(log, " ")
+ end
- return logs
+ return logs
end
local function write_scores(new_symbol_scores, file_path)
- local file = assert(io.open(file_path, "w"))
+ local file = assert(io.open(file_path, "w"))
- local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
+ local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
- file:write(new_scores_ucl)
+ file:write(new_scores_ucl)
- file:close()
+ file:close()
end
local function print_score_diff(new_symbol_scores, original_symbol_scores)
- print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
-
- for symbol, new_score in pairs(new_symbol_scores) do
- print(string.format("%-35s %-10s %-10s",
- symbol,
- original_symbol_scores[symbol] or 0,
- rescore_utility.round(new_score, 2)))
- end
-
- print "\nClass changes \n"
- for symbol, new_score in pairs(new_symbol_scores) do
- if original_symbol_scores[symbol] ~= nil then
- if (original_symbol_scores[symbol] > 0 and new_score < 0) or
- (original_symbol_scores[symbol] < 0 and new_score > 0) then
- print(string.format("%-35s %-10s %-10s",
- symbol,
- original_symbol_scores[symbol] or 0,
- rescore_utility.round(new_score, 2)))
- end
- end
- end
+ print(string.format("%-35s %-10s %-10s", "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
+
+ for symbol, new_score in pairs(new_symbol_scores) do
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+
+ print "\nClass changes \n"
+ for symbol, new_score in pairs(new_symbol_scores) do
+ if original_symbol_scores[symbol] ~= nil then
+ if (original_symbol_scores[symbol] > 0 and new_score < 0) or
+ (original_symbol_scores[symbol] < 0 and new_score > 0) then
+ print(string.format("%-35s %-10s %-10s",
+ symbol,
+ original_symbol_scores[symbol] or 0,
+ rescore_utility.round(new_score, 2)))
+ end
+ end
+ end
end
local function calculate_fscore_from_weights(logs, all_symbols, weights, bias, threshold)
- local new_symbol_scores = weights:clone()
+ local new_symbol_scores = weights:clone()
- new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
- logs = update_logs(logs, new_symbol_scores)
+ logs = update_logs(logs, new_symbol_scores)
- local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
- return file_stats.fscore
+ return file_stats.fscore
end
local function print_stats(logs, threshold)
- local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
- local file_stat_format = [[
+ local file_stat_format = [[
F-score: %.2f
False positive rate: %.2f %%
False negative rate: %.2f %%
Overall accuracy: %.2f %%
]]
- io.write("\nStatistics at threshold: " .. threshold .. "\n")
+ io.write("\nStatistics at threshold: " .. threshold .. "\n")
+
+ io.write(string.format(file_stat_format,
+ file_stats.fscore,
+ file_stats.false_positive_rate,
+ file_stats.false_negative_rate,
+ file_stats.overall_accuracy))
- io.write(string.format(file_stat_format,
- file_stats.fscore,
- file_stats.false_positive_rate,
- file_stats.false_negative_rate,
- file_stats.overall_accuracy))
+end
+local default_opts = {
+ verbose = true,
+ iters = 10,
+ threads = 1,
+}
+
+local function override_defaults(def, override)
+ for k,v in pairs(override) do
+ if def[k] then
+ if type(v) == 'table' then
+ override_defaults(def[k], v)
+ else
+ def[k] = v
+ end
+ else
+ def[k] = v
+ end
+ end
end
-return function (_, res)
+local function get_threshold(opts)
+ local actions = rspamd_config:get_all_actions()
- local logs = rescore_utility.get_all_logs(res["logdir"])
- local all_symbols = rescore_utility.get_all_symbols(logs)
- local original_symbol_scores = rescore_utility.get_all_symbol_scores(res["timeout"])
+ if opts['spam-action'] then
+ return actions[opts['spam-action']] or 0
+ else
+ return actions['add header'] or actions['rewrite subject'] or actions['reject']
+ end
+end
- shuffle(logs)
+return function (args, cfg)
+ opts = default_opts
+ override_defaults(opts, getopt.getopt(args, ''))
+ local threshold = get_threshold(opts)
+ local logs = rescore_utility.get_all_logs(cfg["logdir"])
+ local all_symbols = rescore_utility.get_all_symbols(logs)
+ local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config)
- local train_logs, validation_logs = split_logs(logs, 70)
- local cv_logs, test_logs = split_logs(validation_logs, 50)
+ shuffle(logs)
+ torch.setdefaulttensortype('torch.FloatTensor')
- local dataset = make_dataset_from_logs(train_logs, all_symbols)
+ local train_logs, validation_logs = split_logs(logs, 70)
+ local cv_logs, test_logs = split_logs(validation_logs, 50)
- local learning_rates = {0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10}
- local penalty_weights = {0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100}
+ local dataset = make_dataset_from_logs(train_logs, all_symbols)
- -- Start of perceptron training
+ local learning_rates = {
+ 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+ }
+ local penalty_weights = {
+ 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+ }
- local input_size = #all_symbols
- local linear_module = nn.Linear(input_size, 1)
+ -- Start of perceptron training
- local perceptron = nn.Sequential()
- perceptron:add(linear_module)
+ local input_size = #all_symbols
+ torch.setnumthreads(opts['threads'])
+ local linear_module = nn.Linear(input_size, 1)
- local activation = nn.Sigmoid()
+ local perceptron = nn.Sequential()
+ perceptron:add(linear_module)
- perceptron:add(activation)
+ local activation = nn.Sigmoid()
- local criterion = nn.MSECriterion()
- criterion.sizeAverage = false
+ perceptron:add(activation)
- local best_fscore = -math.huge
- local best_weights = linear_module.weight[1]:clone()
+ local criterion = nn.MSECriterion()
+ criterion.sizeAverage = false
- local trainer = nn.StochasticGradient(perceptron, criterion)
- trainer.maxIteration = res["iters"]
- trainer.verbose = false
+ local best_fscore = -math.huge
+ local best_weights = linear_module.weight[1]:clone()
- trainer.hookIteration = function(self, iteration, error)
+ local trainer = nn.StochasticGradient(perceptron, criterion)
+ trainer.maxIteration = opts["iters"]
+ trainer.verbose = opts['verbose']
+ trainer.hookIteration = function(self, iteration, error)
- if iteration == trainer.maxIteration then
+ if iteration == trainer.maxIteration then
- local fscore = calculate_fscore_from_weights(cv_logs,
- all_symbols,
- linear_module.weight[1],
- linear_module.bias[1],
- res["threshold"])
+ local fscore = calculate_fscore_from_weights(cv_logs,
+ all_symbols,
+ linear_module.weight[1],
+ linear_module.bias[1],
+ threshold)
- print("Cross-validation fscore: " .. fscore)
+ print("Cross-validation fscore: " .. fscore)
- if best_fscore < fscore then
- best_fscore = fscore
- best_weights = linear_module.weight[1]:clone()
- end
- end
- end
+ if best_fscore < fscore then
+ best_fscore = fscore
+ best_weights = linear_module.weight[1]:clone()
+ end
+ end
+ end
- for _, learning_rate in pairs(learning_rates) do
- for _, weight in pairs(penalty_weights) do
+ for _, learning_rate in ipairs(learning_rates) do
+ for _, weight in ipairs(penalty_weights) do
- trainer.weightDecay = weight
- print("Learning with learning_rate: " .. learning_rate
- .. " | l2_weight: " .. weight)
+ trainer.weightDecay = weight
+ print("Learning with learning_rate: " .. learning_rate
+ .. " | l2_weight: " .. weight)
- linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+ linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
- trainer.learningRate = learning_rate
- trainer:train(dataset)
+ trainer.learningRate = learning_rate
+ trainer:train(dataset)
- print()
- end
- end
+ print()
+ end
+ end
- -- End perceptron training
+ -- End perceptron training
- local new_symbol_scores = best_weights
+ local new_symbol_scores = best_weights
- new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
+ new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
- if res["output"] then
- write_scores(new_symbol_scores, res["output"])
- end
+ if res["output"] then
+ write_scores(new_symbol_scores, res["output"])
+ end
- if res["diff"] then
- print_score_diff(new_symbol_scores, original_symbol_scores)
- end
+ if opts["diff"] then
+ print_score_diff(new_symbol_scores, original_symbol_scores)
+ end
- -- Pre-rescore test stats
- print("\n\nPre-rescore test stats\n")
- test_logs = update_logs(test_logs, original_symbol_scores)
- print_stats(test_logs, res['threshold'])
+ -- Pre-rescore test stats
+ io.write("\n\nPre-rescore test stats\n\n")
+ test_logs = update_logs(test_logs, original_symbol_scores)
+ print_stats(test_logs, threshold)
- -- Post-rescore test stats
- test_logs = update_logs(test_logs, new_symbol_scores)
- print("\n\nPost-rescore test stats\n")
- print_stats(test_logs, res['threshold'])
+ -- Post-rescore test stats
+ test_logs = update_logs(test_logs, new_symbol_scores)
+ io.write("\n\nPost-rescore test stats\n\n")
+ print_stats(test_logs, threshold)
end
\ No newline at end of file
-local ucl = require "ucl"
local lua_util = require "lua_util"
local rspamd_util = require "rspamd_util"
+local fun = require "fun"
local utility = {}
-function utility.round(num, places)
- return string.format("%." .. (places or 0) .. "f", num)
-end
-
function utility.get_all_symbols(logs)
- -- Returns a list of all symbols
+ -- Returns a list of all symbols
- local symbols_set = {}
+ local symbols_set = {}
- for _, line in pairs(logs) do
- line = lua_util.rspamd_str_split(line, " ")
- for i=4,#line do
- line[i] = line[i]:gsub("%s+", "")
- if not symbols_set[line[i]] then
- symbols_set[line[i]] = true
- end
- end
- end
+ for _, line in pairs(logs) do
+ line = lua_util.rspamd_str_split(line, " ")
+ for i=4,#line do
+ line[i] = line[i]:gsub("%s+", "")
+ if not symbols_set[line[i]] then
+ symbols_set[line[i]] = true
+ end
+ end
+ end
- local all_symbols = {}
+ local all_symbols = {}
- for symbol, _ in pairs(symbols_set) do
- all_symbols[#all_symbols + 1] = symbol
- end
+ for symbol, _ in pairs(symbols_set) do
+ all_symbols[#all_symbols + 1] = symbol
+ end
- table.sort(all_symbols)
+ table.sort(all_symbols)
- return all_symbols
+ return all_symbols
end
function utility.read_log_file(file)
- local lines = {}
+ local lines = {}
- file = assert(io.open(file, "r"))
+ file = assert(io.open(file, "r"))
- for line in file:lines() do
- lines[#lines + 1] = line
- end
+ for line in file:lines() do
+ lines[#lines + 1] = line
+ end
- io.close(file)
+ io.close(file)
- return lines
+ return lines
end
function utility.get_all_logs(dir_path)
- -- Reads all log files in the directory and returns a list of logs.
+ -- Reads all log files in the directory and returns a list of logs.
- if dir_path:sub(#dir_path, #dir_path) == "/" then
- dir_path = dir_path:sub(1, #dir_path -1)
- end
+ if dir_path:sub(#dir_path, #dir_path) == "/" then
+ dir_path = dir_path:sub(1, #dir_path -1)
+ end
- local files = rspamd_util.glob(dir_path .. "/*")
- local all_logs = {}
+ local files = rspamd_util.glob(dir_path .. "/*")
+ local all_logs = {}
- for _, file in pairs(files) do
- local logs = utility.read_log_file(file)
- for _, log_line in pairs(logs) do
- all_logs[#all_logs + 1] = log_line
- end
- end
+ for _, file in pairs(files) do
+ local logs = utility.read_log_file(file)
+ for _, log_line in pairs(logs) do
+ all_logs[#all_logs + 1] = log_line
+ end
+ end
- return all_logs
+ return all_logs
end
-function utility.get_all_symbol_scores(timeout)
-
- local output = assert(io.popen("rspamc counters -j --compact -t " .. tostring(timeout)))
- output = output:read("*all")
-
- local parser = ucl.parser()
- local result, err = parser:parse_string(output)
-
- if not result then
- print(err)
- os.exit()
- end
-
- output = parser:get_object()
-
- local symbol_scores = {}
-
- for _, symbol_info in pairs(output) do
- symbol_scores[symbol_info.symbol] = symbol_info.weight
- end
+function utility.get_all_symbol_scores(conf)
+ local counters = conf:get_symbols_counters()
- return symbol_scores
+ return fun.tomap(fun.map(function(elt)
+ return elt['symbol'],elt['weight']
+ end, counters))
end
function utility.generate_statistics_from_logs(logs, threshold)
- -- Returns file_stats table and list of symbol_stats table.
-
- local file_stats = {
- no_of_emails = 0,
- no_of_spam = 0,
- no_of_ham = 0,
- spam_percent = 0,
- ham_percent = 0,
- true_positives = 0,
- true_negatives = 0,
- false_negative_rate = 0,
- false_positive_rate = 0,
- overall_accuracy = 0,
- fscore = 0
- }
-
- local all_symbols_stats = {}
-
- local false_positives = 0
- local false_negatives = 0
- local true_positives = 0
- local true_negatives = 0
- local no_of_emails = 0
- local no_of_spam = 0
- local no_of_ham = 0
-
- for _, log in pairs(logs) do
- log = lua_util.rspamd_str_trim(log)
- log = lua_util.rspamd_str_split(log, " ")
-
- local is_spam = (log[1] == "SPAM")
- local score = tonumber(log[2])
-
- no_of_emails = no_of_emails + 1
-
- if is_spam then
- no_of_spam = no_of_spam + 1
- else
- no_of_ham = no_of_ham + 1
- end
-
- if is_spam and (score >= threshold) then
- true_positives = true_positives + 1
- elseif is_spam and (score < threshold) then
- false_negatives = false_negatives + 1
- elseif not is_spam and (score >= threshold) then
- false_positives = false_positives + 1
- else
- true_negatives = true_negatives + 1
- end
-
- for i=4, #log do
- if all_symbols_stats[log[i]] == nil then
- all_symbols_stats[log[i]] = {
- name = log[i],
- no_of_hits = 0,
- spam_hits = 0,
- ham_hits = 0,
- spam_overall = 0
- }
- end
-
- all_symbols_stats[log[i]].no_of_hits =
- all_symbols_stats[log[i]].no_of_hits + 1
-
- if is_spam then
- all_symbols_stats[log[i]].spam_hits =
- all_symbols_stats[log[i]].spam_hits + 1
- else
- all_symbols_stats[log[i]].ham_hits =
- all_symbols_stats[log[i]].ham_hits + 1
- end
- end
- end
-
- -- Calculating file stats
-
- file_stats.no_of_ham = no_of_ham
- file_stats.no_of_spam = no_of_spam
- file_stats.no_of_emails = no_of_emails
- file_stats.true_positives = true_positives
- file_stats.true_negatives = true_negatives
-
- if no_of_emails > 0 then
- file_stats.spam_percent = no_of_spam * 100 / no_of_emails
- file_stats.ham_percent = no_of_ham * 100 / no_of_emails
- file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
- no_of_emails
- end
-
- if no_of_ham > 0 then
- file_stats.false_positive_rate = false_positives * 100 / no_of_ham
- end
-
- if no_of_spam > 0 then
- file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
- end
-
- file_stats.fscore = 2 * true_positives / (2
- * true_positives
- + false_positives
- + false_negatives)
-
- -- Calculating symbol stats
-
- for _, symbol_stats in pairs(all_symbols_stats) do
- symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
- symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
- symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
- symbol_stats.spam_overall = symbol_stats.spam_percent /
- (symbol_stats.spam_percent + symbol_stats.ham_percent)
- end
-
- return file_stats, all_symbols_stats
+ -- Returns file_stats table and list of symbol_stats table.
+
+ local file_stats = {
+ no_of_emails = 0,
+ no_of_spam = 0,
+ no_of_ham = 0,
+ spam_percent = 0,
+ ham_percent = 0,
+ true_positives = 0,
+ true_negatives = 0,
+ false_negative_rate = 0,
+ false_positive_rate = 0,
+ overall_accuracy = 0,
+ fscore = 0
+ }
+
+ local all_symbols_stats = {}
+
+ local false_positives = 0
+ local false_negatives = 0
+ local true_positives = 0
+ local true_negatives = 0
+ local no_of_emails = 0
+ local no_of_spam = 0
+ local no_of_ham = 0
+
+ for _, log in pairs(logs) do
+ log = lua_util.rspamd_str_trim(log)
+ log = lua_util.rspamd_str_split(log, " ")
+
+ local is_spam = (log[1] == "SPAM")
+ local score = tonumber(log[2])
+
+ no_of_emails = no_of_emails + 1
+
+ if is_spam then
+ no_of_spam = no_of_spam + 1
+ else
+ no_of_ham = no_of_ham + 1
+ end
+
+ if is_spam and (score >= threshold) then
+ true_positives = true_positives + 1
+ elseif is_spam and (score < threshold) then
+ false_negatives = false_negatives + 1
+ elseif not is_spam and (score >= threshold) then
+ false_positives = false_positives + 1
+ else
+ true_negatives = true_negatives + 1
+ end
+
+ for i=4, #log do
+ if all_symbols_stats[log[i]] == nil then
+ all_symbols_stats[log[i]] = {
+ name = log[i],
+ no_of_hits = 0,
+ spam_hits = 0,
+ ham_hits = 0,
+ spam_overall = 0
+ }
+ end
+
+ all_symbols_stats[log[i]].no_of_hits =
+ all_symbols_stats[log[i]].no_of_hits + 1
+
+ if is_spam then
+ all_symbols_stats[log[i]].spam_hits =
+ all_symbols_stats[log[i]].spam_hits + 1
+ else
+ all_symbols_stats[log[i]].ham_hits =
+ all_symbols_stats[log[i]].ham_hits + 1
+ end
+ end
+ end
+
+ -- Calculating file stats
+
+ file_stats.no_of_ham = no_of_ham
+ file_stats.no_of_spam = no_of_spam
+ file_stats.no_of_emails = no_of_emails
+ file_stats.true_positives = true_positives
+ file_stats.true_negatives = true_negatives
+
+ if no_of_emails > 0 then
+ file_stats.spam_percent = no_of_spam * 100 / no_of_emails
+ file_stats.ham_percent = no_of_ham * 100 / no_of_emails
+ file_stats.overall_accuracy = (true_positives + true_negatives) * 100 /
+ no_of_emails
+ end
+
+ if no_of_ham > 0 then
+ file_stats.false_positive_rate = false_positives * 100 / no_of_ham
+ end
+
+ if no_of_spam > 0 then
+ file_stats.false_negative_rate = false_negatives * 100 / no_of_spam
+ end
+
+ file_stats.fscore = 2 * true_positives / (2
+ * true_positives
+ + false_positives
+ + false_negatives)
+
+ -- Calculating symbol stats
+
+ for _, symbol_stats in pairs(all_symbols_stats) do
+ symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam
+ symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham
+ symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails
+ symbol_stats.spam_overall = symbol_stats.spam_percent /
+ (symbol_stats.spam_percent + symbol_stats.ham_percent)
+ end
+
+ return file_stats, all_symbols_stats
end
return utility
static gchar *logdir = NULL;
static gchar *output = "new.scores";
-static gdouble threshold = 15; /* Spam threshold */
static gboolean score_diff = false; /* Print score diff flag */
-static gint64 iters = 500; /* Perceptron max iterations */
-gdouble timeout = 60.0;
-
-/* TODO: think about adding the config file reading */
+static gchar *config = NULL;
+extern struct rspamd_main *rspamd_main;
+/* Defined in modules.c */
+extern module_t *modules[];
+extern worker_t *workers[];
static void rspamadm_rescore (gint argc, gchar **argv);
"Scores output locaiton", NULL},
{"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff,
"Print score diff", NULL},
- {"iters", 'i', 0, G_OPTION_ARG_INT64, &iters,
- "Max iterations for perceptron [Default: 500]", NULL},
- {"timeout", 't', 0, G_OPTION_ARG_DOUBLE, &timeout,
- "Timeout for connections [Default: 60]", NULL},
+ {"config", 'c', 0, G_OPTION_ARG_STRING, &config,
+ "Config file to use", NULL},
{NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
};
+static void
+config_logger (rspamd_mempool_t *pool, gpointer ud)
+{
+ struct rspamd_main *rm = ud;
+
+ rm->cfg->log_type = RSPAMD_LOG_CONSOLE;
+ rm->cfg->log_level = G_LOG_LEVEL_CRITICAL;
+
+ rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger,
+ rm->server_pool);
+
+ if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) ==
+ -1) {
+ fprintf (stderr, "Fatal error, cannot open logfile, exiting\n");
+ exit (EXIT_FAILURE);
+ }
+}
+
static const char *
rspamadm_rescore_help (gboolean full_help) {
"-l: path to logs directory\n"
"-o: scores output file location\n"
"-d: print scores diff\n"
- "-i: max iterations for perceptron\n"
- "-t: timeout for rspamc operations (default: 60)\n";
+ "-i: max iterations for perceptron\n";
} else {
help_str = "Estimate optimal symbol weights from log files";
}
GOptionContext *context;
GError *error = NULL;
lua_State *L;
- ucl_object_t *obj;
+ struct rspamd_config *cfg = rspamd_main->cfg, **pcfg;
+ gboolean ret = TRUE;
+ worker_t **pworker;
+ const gchar *confdir;
context = g_option_context_new (
"rescore - estimate optimal symbol weights from log files");
exit (EXIT_FAILURE);
}
- L = rspamd_lua_init ();
- rspamd_lua_set_path (L, NULL, ucl_vars);
-
- obj = ucl_object_typed_new (UCL_OBJECT);
-
- ucl_object_insert_key (obj, ucl_object_fromstring (logdir),
- "logdir", 0, false);
- ucl_object_insert_key (obj, ucl_object_fromstring (output),
- "output", 0, false);
- ucl_object_insert_key (obj, ucl_object_fromdouble (threshold),
- "threshold", 0, false);
- ucl_object_insert_key (obj, ucl_object_fromint (iters),
- "iters", 0, false);
- ucl_object_insert_key (obj, ucl_object_frombool (score_diff),
- "diff", 0, false);
- ucl_object_insert_key (obj, ucl_object_fromdouble (timeout),
- "timeout", 0, false);
-
- rspamadm_execute_lua_ucl_subr (L,
- argc,
- argv,
- obj,
- "rescore");
-
- lua_close (L);
- ucl_object_unref (obj);
+ if (config == NULL) {
+ if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) {
+ confdir = RSPAMD_CONFDIR;
+ }
+
+ config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR,
+ "rspamd.conf");
+ }
+
+ pworker = &workers[0];
+ while (*pworker) {
+ /* Init string quarks */
+ (void) g_quark_from_static_string ((*pworker)->name);
+ pworker++;
+ }
+
+ cfg->cache = rspamd_symbols_cache_new (cfg);
+ cfg->compiled_modules = modules;
+ cfg->compiled_workers = workers;
+ cfg->cfg_name = config;
+
+ if (!rspamd_config_read (cfg, cfg->cfg_name, NULL,
+ config_logger, rspamd_main, ucl_vars)) {
+ ret = FALSE;
+ }
+ else {
+ /* Do post-load actions */
+ rspamd_lua_post_load_config (cfg);
+
+ if (!rspamd_init_filters (cfg, FALSE)) {
+ ret = FALSE;
+ }
+
+ if (ret) {
+ ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE);
+ rspamd_symbols_cache_validate (cfg->cache,
+ cfg,
+ FALSE);
+ }
+ }
+
+ if (ret) {
+ L = cfg->lua_state;
+ rspamd_lua_set_path (L, cfg->rcl_obj, ucl_vars);
+ ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name),
+ "config_path", 0, false);
+ ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir),
+ "logdir", 0, false);
+ ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output),
+ "output", 0, false);
+ ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff),
+ "diff", 0, false);
+ pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *));
+ rspamd_lua_setclass (L, "rspamd{config}", -1);
+ *pcfg = cfg;
+ lua_setglobal (L, "rspamd_config");
+ rspamadm_execute_lua_ucl_subr (L,
+ argc,
+ argv,
+ cfg->rcl_obj,
+ "rescore");
+ }
}
\ No newline at end of file