diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-08-07 12:19:35 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-08-07 12:19:35 +0100 |
commit | 416ba555dd278fa771ca862e174b578ad4bba958 (patch) | |
tree | e8ba5b1041770761aabd36556072295551794285 /lualib | |
parent | 460a82484915f5fcbf34f65194c3437fa2a4e0c7 (diff) | |
download | rspamd-416ba555dd278fa771ca862e174b578ad4bba958.tar.gz rspamd-416ba555dd278fa771ca862e174b578ad4bba958.zip |
[Minor] Remove unused utility, as it has been broken for ages
Diffstat (limited to 'lualib')
-rw-r--r-- | lualib/rescore_utility.lua | 230 | ||||
-rw-r--r-- | lualib/rspamadm/rescore.lua | 588 |
2 files changed, 0 insertions, 818 deletions
diff --git a/lualib/rescore_utility.lua b/lualib/rescore_utility.lua deleted file mode 100644 index 01e5aabcb..000000000 --- a/lualib/rescore_utility.lua +++ /dev/null @@ -1,230 +0,0 @@ -local lua_util = require "lua_util" -local rspamd_util = require "rspamd_util" -local fun = require "fun" - -local utility = {} - -function utility.get_all_symbols(logs, ignore_symbols) - -- Returns a list of all symbols - - local symbols_set = {} - - for _, line in pairs(logs) do - line = lua_util.rspamd_str_split(line, " ") - for i = 4, (#line - 1) do - line[i] = line[i]:gsub("%s+", "") - if not symbols_set[line[i]] then - symbols_set[line[i]] = true - end - end - end - - local all_symbols = {} - - for symbol, _ in pairs(symbols_set) do - if not ignore_symbols[symbol] then - all_symbols[#all_symbols + 1] = symbol - end - end - - table.sort(all_symbols) - - return all_symbols -end - -function utility.read_log_file(file) - - local lines = {} - local messages = {} - - local fd = assert(io.open(file, "r")) - local fname = string.gsub(file, "(.*/)(.*)", "%2") - - for line in fd:lines() do - local start, stop = string.find(line, fname .. ':') - - if start and stop then - table.insert(lines, string.sub(line, 1, start)) - table.insert(messages, string.sub(line, stop + 1, -1)) - end - end - - io.close(fd) - - return lines, messages -end - -function utility.get_all_logs(dirs) - -- Reads all log files in the directory and returns a list of logs. - - if type(dirs) == 'string' then - dirs = { dirs } - end - - local all_logs = {} - local all_messages = {} - - for _, dir in ipairs(dirs) do - if dir:sub(-1, -1) == "/" then - dir = dir:sub(1, -2) - local files = rspamd_util.glob(dir .. "/*.log") - for _, file in pairs(files) do - local logs, messages = utility.read_log_file(file) - for i = 1, #logs do - table.insert(all_logs, logs[i]) - table.insert(all_messages, messages[i]) - end - end - else - local logs, messages = utility.read_log_file(dir) - for i = 1, #logs do - table.insert(all_logs, logs[i]) - table.insert(all_messages, messages[i]) - end - end - end - - return all_logs, all_messages -end - -function utility.get_all_symbol_scores(conf, ignore_symbols) - local symbols = conf:get_symbols_scores() - - return fun.tomap(fun.map(function(name, elt) - return name, elt['score'] - end, fun.filter(function(name, elt) - return not ignore_symbols[name] - end, symbols))) -end - -function utility.generate_statistics_from_logs(logs, messages, threshold) - - -- Returns file_stats table and list of symbol_stats table. - - local file_stats = { - no_of_emails = 0, - no_of_spam = 0, - no_of_ham = 0, - spam_percent = 0, - ham_percent = 0, - true_positives = 0, - true_negatives = 0, - false_negative_rate = 0, - false_positive_rate = 0, - overall_accuracy = 0, - fscore = 0, - avg_scan_time = 0, - slowest_file = nil, - slowest = 0 - } - - local all_symbols_stats = {} - local all_fps = {} - local all_fns = {} - - local false_positives = 0 - local false_negatives = 0 - local true_positives = 0 - local true_negatives = 0 - local no_of_emails = 0 - local no_of_spam = 0 - local no_of_ham = 0 - - for i, log in ipairs(logs) do - log = lua_util.rspamd_str_trim(log) - log = lua_util.rspamd_str_split(log, " ") - local message = messages[i] - - local is_spam = (log[1] == "SPAM") - local score = tonumber(log[2]) - - no_of_emails = no_of_emails + 1 - - if is_spam then - no_of_spam = no_of_spam + 1 - else - no_of_ham = no_of_ham + 1 - end - - if is_spam and (score >= threshold) then - true_positives = true_positives + 1 - elseif is_spam and (score < threshold) then - false_negatives = false_negatives + 1 - table.insert(all_fns, message) - elseif not is_spam and (score >= threshold) then - false_positives = false_positives + 1 - table.insert(all_fps, message) - else - true_negatives = true_negatives + 1 - end - - for j = 4, (#log - 1) do - if all_symbols_stats[log[j]] == nil then - all_symbols_stats[log[j]] = { - name = message, - no_of_hits = 0, - spam_hits = 0, - ham_hits = 0, - spam_overall = 0 - } - end - local sym = log[j] - - all_symbols_stats[sym].no_of_hits = all_symbols_stats[sym].no_of_hits + 1 - - if is_spam then - all_symbols_stats[sym].spam_hits = all_symbols_stats[sym].spam_hits + 1 - else - all_symbols_stats[sym].ham_hits = all_symbols_stats[sym].ham_hits + 1 - end - - -- Find slowest message - if ((tonumber(log[#log]) or 0) > file_stats.slowest) then - file_stats.slowest = tonumber(log[#log]) - file_stats.slowest_file = message - end - end - end - - -- Calculating file stats - - file_stats.no_of_ham = no_of_ham - file_stats.no_of_spam = no_of_spam - file_stats.no_of_emails = no_of_emails - file_stats.true_positives = true_positives - file_stats.true_negatives = true_negatives - - if no_of_emails > 0 then - file_stats.spam_percent = no_of_spam * 100 / no_of_emails - file_stats.ham_percent = no_of_ham * 100 / no_of_emails - file_stats.overall_accuracy = (true_positives + true_negatives) * 100 / - no_of_emails - end - - if no_of_ham > 0 then - file_stats.false_positive_rate = false_positives * 100 / no_of_ham - end - - if no_of_spam > 0 then - file_stats.false_negative_rate = false_negatives * 100 / no_of_spam - end - - file_stats.fscore = 2 * true_positives / (2 - * true_positives - + false_positives - + false_negatives) - - -- Calculating symbol stats - - for _, symbol_stats in pairs(all_symbols_stats) do - symbol_stats.spam_percent = symbol_stats.spam_hits * 100 / no_of_spam - symbol_stats.ham_percent = symbol_stats.ham_hits * 100 / no_of_ham - symbol_stats.overall = symbol_stats.no_of_hits * 100 / no_of_emails - symbol_stats.spam_overall = symbol_stats.spam_percent / - (symbol_stats.spam_percent + symbol_stats.ham_percent) - end - - return file_stats, all_symbols_stats, all_fps, all_fns -end - -return utility diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua deleted file mode 100644 index 6e372a419..000000000 --- a/lualib/rspamadm/rescore.lua +++ /dev/null @@ -1,588 +0,0 @@ ---[[ -Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -]]-- - ---[[ -local lua_util = require "lua_util" -local ucl = require "ucl" -local logger = require "rspamd_logger" -local rspamd_util = require "rspamd_util" -local argparse = require "argparse" -local rescore_utility = require "rescore_utility" - - -local opts -local ignore_symbols = { - ['DATE_IN_PAST'] =true, - ['DATE_IN_FUTURE'] = true, -} - -local parser = argparse() - :name "rspamadm rescore" - :description "Estimate optimal symbol weights from log files" - :help_description_margin(37) - -parser:option "-l --log" - :description "Log file or files (from rescore)" - :argname("<log>") - :args "*" -parser:option "-c --config" - :description "Path to config file" - :argname("<file>") - :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") -parser:option "-o --output" - :description "Output file" - :argname("<file>") - :default("new.scores") -parser:flag "-d --diff" - :description "Show differences in scores" -parser:flag "-v --verbose" - :description "Verbose output" -parser:flag "-z --freq" - :description "Display hit frequencies" -parser:option "-i --iters" - :description "Learn iterations" - :argname("<n>") - :convert(tonumber) - :default(10) -parser:option "-b --batch" - :description "Batch size" - :argname("<n>") - :convert(tonumber) - :default(100) -parser:option "-d --decay" - :description "Decay rate" - :argname("<n>") - :convert(tonumber) - :default(0.001) -parser:option "-m --momentum" - :description "Learn momentum" - :argname("<n>") - :convert(tonumber) - :default(0.1) -parser:option "-t --threads" - :description "Number of threads to use" - :argname("<n>") - :convert(tonumber) - :default(1) -parser:option "-o --optim" - :description "Optimisation algorithm" - :argname("<alg>") - :convert { - LBFGS = "LBFGS", - ADAM = "ADAM", - ADAGRAD = "ADAGRAD", - SGD = "SGD", - NAG = "NAG" - } - :default "ADAM" -parser:option "--ignore-symbol" - :description "Ignore symbol from logs" - :argname("<sym>") - :args "*" -parser:option "--penalty-weight" - :description "Add new penalty weight to test" - :argname("<n>") - :convert(tonumber) - :args "*" -parser:option "--learning-rate" - :description "Add new learning rate to test" - :argname("<n>") - :convert(tonumber) - :args "*" -parser:option "--spam_action" - :description "Spam action" - :argname("<act>") - :default("reject") -parser:option "--learning_rate_decay" - :description "Learn rate decay (for some algs)" - :argname("<n>") - :convert(tonumber) - :default(0.0) -parser:option "--weight_decay" - :description "Weight decay (for some algs)" - :argname("<n>") - :convert(tonumber) - :default(0.0) -parser:option "--l1" - :description "L1 regularization penalty" - :argname("<n>") - :convert(tonumber) - :default(0.0) -parser:option "--l2" - :description "L2 regularization penalty" - :argname("<n>") - :convert(tonumber) - :default(0.0) - -local function make_dataset_from_logs(logs, all_symbols, spam_score) - - local inputs = {} - local outputs = {} - - for _, log in pairs(logs) do - - log = lua_util.rspamd_str_split(log, " ") - - if log[1] == "SPAM" then - outputs[#outputs+1] = 1 - else - outputs[#outputs+1] = 0 - end - - local symbols_set = {} - - for i=4,#log do - if not ignore_symbols[ log[i] ] then - symbols_set[log[i] ] = true - end - end - - local input_vec = {} - for index, symbol in pairs(all_symbols) do - if symbols_set[symbol] then - input_vec[index] = 1 - else - input_vec[index] = 0 - end - end - - inputs[#inputs + 1] = input_vec - end - - return inputs,outputs -end - -local function init_weights(all_symbols, original_symbol_scores) -end - -local function shuffle(logs, messages) - - local size = #logs - for i = size, 1, -1 do - local rand = math.random(size) - logs[i], logs[rand] = logs[rand], logs[i] - messages[i], messages[rand] = messages[rand], messages[i] - end - -end - -local function split_logs(logs, messages, split_percent) - - if not split_percent then - split_percent = 60 - end - - local split_index = math.floor(#logs * split_percent / 100) - - local test_logs = {} - local train_logs = {} - local test_messages = {} - local train_messages = {} - - for i=1,split_index do - table.insert(train_logs, logs[i]) - table.insert(train_messages, messages[i]) - end - - for i=split_index + 1, #logs do - table.insert(test_logs, logs[i]) - table.insert(test_messages, messages[i]) - end - - return {train_logs,train_messages}, {test_logs,test_messages} -end - -local function stitch_new_scores(all_symbols, new_scores) - - local new_symbol_scores = {} - - for idx, symbol in pairs(all_symbols) do - new_symbol_scores[symbol] = new_scores[idx] - end - - return new_symbol_scores -end - - -local function update_logs(logs, symbol_scores) - - for i, log in ipairs(logs) do - - log = lua_util.rspamd_str_split(log, " ") - - local score = 0 - - for j=4,#log do - log[j] = log[j]:gsub("%s+", "") - score = score + (symbol_scores[log[j] ] or 0) - end - - log[2] = lua_util.round(score, 2) - - logs[i] = table.concat(log, " ") - end - - return logs -end - -local function write_scores(new_symbol_scores, file_path) - - local file = assert(io.open(file_path, "w")) - - local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl") - - file:write(new_scores_ucl) - - file:close() -end - -local function print_score_diff(new_symbol_scores, original_symbol_scores) - - logger.message(string.format("%-35s %-10s %-10s", - "SYMBOL", "OLD_SCORE", "NEW_SCORE")) - - for symbol, new_score in pairs(new_symbol_scores) do - logger.message(string.format("%-35s %-10s %-10s", - symbol, - original_symbol_scores[symbol] or 0, - lua_util.round(new_score, 2))) - end - - logger.message("\nClass changes \n") - for symbol, new_score in pairs(new_symbol_scores) do - if original_symbol_scores[symbol] ~= nil then - if (original_symbol_scores[symbol] > 0 and new_score < 0) or - (original_symbol_scores[symbol] < 0 and new_score > 0) then - logger.message(string.format("%-35s %-10s %-10s", - symbol, - original_symbol_scores[symbol] or 0, - lua_util.round(new_score, 2))) - end - end - end - -end - -local function calculate_fscore_from_weights(logs, messages, - all_symbols, - weights, - threshold) - - local new_symbol_scores = weights:clone() - - new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - - logs = update_logs(logs, new_symbol_scores) - - local file_stats, _, all_fps, all_fns = - rescore_utility.generate_statistics_from_logs(logs, messages, threshold) - - return file_stats.fscore, all_fps, all_fns -end - -local function print_stats(logs, messages, threshold) - - local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, - messages, threshold) - - local file_stat_format = [=[ -F-score: %.2f -False positive rate: %.2f %% -False negative rate: %.2f %% -Overall accuracy: %.2f %% -Slowest message: %.2f (%s) -]=] - - logger.message("\nStatistics at threshold: " .. threshold) - - logger.message(string.format(file_stat_format, - file_stats.fscore, - file_stats.false_positive_rate, - file_stats.false_negative_rate, - file_stats.overall_accuracy, - file_stats.slowest, - file_stats.slowest_file)) - -end - --- training function -local function train(dataset, opt, model, criterion, epoch, - all_symbols, spam_threshold, initial_weights) -end - -local learning_rates = { - 0.01 -} -local penalty_weights = { - 0 -} - -local function get_threshold() - local actions = rspamd_config:get_all_actions() - - if opts['spam-action'] then - return (actions[opts['spam-action'] ] or 0),actions['reject'] - end - return (actions['add header'] or actions['rewrite subject'] - or actions['reject']), actions['reject'] -end - -local function handler(args) - opts = parser:parse(args) - if not opts['log'] then - parser:error('no log specified') - end - - local _r,err = rspamd_config:load_ucl(opts['config']) - - if not _r then - logger.errx('cannot parse %s: %s', opts['config'], err) - os.exit(1) - end - - _r,err = rspamd_config:parse_rcl({'logging', 'worker'}) - if not _r then - logger.errx('cannot process %s: %s', opts['config'], err) - os.exit(1) - end - - local threshold,reject_score = get_threshold() - local logs,messages = rescore_utility.get_all_logs(opts['log']) - - if opts['ignore-symbol'] then - local function add_ignore(s) - ignore_symbols[s] = true - end - if type(opts['ignore-symbol']) == 'table' then - for _,s in ipairs(opts['ignore-symbol']) do - add_ignore(s) - end - else - add_ignore(opts['ignore-symbol']) - end - end - - if opts['learning-rate'] then - learning_rates = {} - - local function add_rate(r) - if tonumber(r) then - table.insert(learning_rates, tonumber(r)) - end - end - if type(opts['learning-rate']) == 'table' then - for _,s in ipairs(opts['learning-rate']) do - add_rate(s) - end - else - add_rate(opts['learning-rate']) - end - end - - if opts['penalty-weight'] then - penalty_weights = {} - - local function add_weight(r) - if tonumber(r) then - table.insert(penalty_weights, tonumber(r)) - end - end - if type(opts['penalty-weight']) == 'table' then - for _,s in ipairs(opts['penalty-weight']) do - add_weight(s) - end - else - add_weight(opts['penalty-weight']) - end - end - - local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols) - local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config, - ignore_symbols) - - -- Display hit frequencies - if opts['freq'] then - local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, - messages, - threshold) - local t = {} - for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end - - local function compare_symbols(a, b) - if (a.spam_overall ~= b.spam_overall) then - return b.spam_overall < a.spam_overall - end - if (b.spam_hits ~= a.spam_hits) then - return b.spam_hits < a.spam_hits - end - return b.ham_hits < a.ham_hits - end - table.sort(t, compare_symbols) - - logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s", - "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%")) - for _, symbol_stats in pairs(t) do - logger.message( - string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f", - symbol_stats.name, - symbol_stats.no_of_hits, - symbol_stats.ham_hits, - lua_util.round(symbol_stats.ham_percent,2), - symbol_stats.spam_hits, - lua_util.round(symbol_stats.spam_percent,2), - lua_util.round(symbol_stats.spam_overall,2), - lua_util.round(symbol_stats.overall, 2) - ) - ) - end - - -- Print file statistics - print_stats(logs, messages, threshold) - - -- Work out how many symbols weren't seen in the corpus - local symbols_no_hits = {} - local total_symbols = 0 - for sym in pairs(original_symbol_scores) do - total_symbols = total_symbols + 1 - if (all_symbols_stats[sym] == nil) then - table.insert(symbols_no_hits, sym) - end - end - if (#symbols_no_hits > 0) then - table.sort(symbols_no_hits) - -- Calculate percentage of rules with no hits - local nhpct = lua_util.round((#symbols_no_hits/total_symbols)*100,2) - logger.message( - string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:', - #symbols_no_hits, nhpct, total_symbols - ) - ) - for _, symbol in pairs(symbols_no_hits) do - logger.messagex('%s', symbol) - end - end - - return - end - - shuffle(logs, messages) - local train_logs, validation_logs = split_logs(logs, messages,70) - local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50) - - local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score) - -- Start of perceptron training - local input_size = #all_symbols - - local linear_module = nn.Linear(input_size, 1, false) - local activation = nn.Sigmoid() - - local perceptron = nn.Sequential() - perceptron:add(linear_module) - perceptron:add(activation) - - local criterion = nn.MSECriterion() - --criterion.sizeAverage = false - - local best_fscore = -math.huge - local best_weights = linear_module.weight[1]:clone() - local best_learning_rate - local best_weight_decay - local all_fps - local all_fns - - for _,lr in ipairs(learning_rates) do - for _,wd in ipairs(penalty_weights) do - linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores) - local initial_weights = linear_module.weight[1]:clone() - opts.learning_rate = lr - opts.weight_decay = wd - for i=1,tonumber(opts.iters) do - train(dataset, opts, perceptron, criterion, i, all_symbols, threshold, - initial_weights) - end - - local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1], - cv_logs[2], - all_symbols, - linear_module.weight[1], - threshold) - - logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s", - fscore, lr, wd) - - if best_fscore < fscore then - best_learning_rate = lr - best_weight_decay = wd - best_fscore = fscore - best_weights = linear_module.weight[1]:clone() - all_fps = fps - all_fns = fns - end - end - end - - -- End perceptron training - - local new_symbol_scores = best_weights - - new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - - if opts["output"] then - write_scores(new_symbol_scores, opts["output"]) - end - - if opts["diff"] then - print_score_diff(new_symbol_scores, original_symbol_scores) - end - - -- Pre-rescore test stats - logger.message("\n\nPre-rescore test stats\n") - test_logs[1] = update_logs(test_logs[1], original_symbol_scores) - print_stats(test_logs[1], test_logs[2], threshold) - - -- Post-rescore test stats - test_logs[1] = update_logs(test_logs[1], new_symbol_scores) - logger.message("\n\nPost-rescore test stats\n") - print_stats(test_logs[1], test_logs[2], threshold) - - logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s', - best_fscore, best_learning_rate, best_weight_decay) - - -- Show all FPs/FNs, useful for corpus checking and rule creation/modification - if (all_fps and #all_fps > 0) then - logger.message("\nFalse-Positives:") - for _, fp in pairs(all_fps) do - logger.messagex('%s', fp) - end - end - - if (all_fns and #all_fns > 0) then - logger.message("\nFalse-Negatives:") - for _, fn in pairs(all_fns) do - logger.messagex('%s', fn) - end - end -end - - -return { - handler = handler, - description = parser._description, - name = 'rescore' -} ---]] - -return nil |