diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-06-20 21:02:18 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-20 21:02:18 +0600 |
commit | b44099c96e279a0d60b0688e84e0ef6293194c59 (patch) | |
tree | 52d9dde0b767a0b6c92bde68b798cdf7a909608f | |
parent | cd92d8bb1d412abc4aaaea57efb55145f9f49b8b (diff) | |
parent | 4cc5bf76576e1b11d451e75cc37216ac86c86eb7 (diff) | |
download | rspamd-b44099c96e279a0d60b0688e84e0ef6293194c59.tar.gz rspamd-b44099c96e279a0d60b0688e84e0ef6293194c59.zip |
Merge pull request #5011 from rspamd/vstakhov-bayes-experiments
Improve bayes performance by setting the default window size to 2
-rw-r--r-- | lualib/lua_util.lua | 4 | ||||
-rw-r--r-- | lualib/rspamadm/classifier_test.lua | 227 | ||||
-rw-r--r-- | src/client/rspamc.cxx | 49 | ||||
-rw-r--r-- | src/libstat/tokenizers/osb.c | 2 |
4 files changed, 271 insertions, 11 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua index ac755182b..650ad5db1 100644 --- a/lualib/lua_util.lua +++ b/lualib/lua_util.lua @@ -389,8 +389,8 @@ end --]] local unpack_function = table.unpack or unpack -exports.unpack = function(t) - return unpack_function(t) +exports.unpack = function(...) + return unpack_function(...) end --[[[ diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua new file mode 100644 index 000000000..21af14fc1 --- /dev/null +++ b/lualib/rspamadm/classifier_test.lua @@ -0,0 +1,227 @@ +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local argparse = require "argparse" +local ucl = require "ucl" +local rspamd_logger = require "rspamd_logger" + +local parser = argparse() + :name "rspamadm classifier_test" + :description "Learn bayes classifier and evaluate its performance" + :help_description_margin(32) + +parser:option "-H --ham" + :description("Ham directory") + :argname("<dir>") +parser:option "-S --spam" + :description("Spam directory") + :argname("<dir>") +parser:flag "-n --no-learning" + :description("Do not learn classifier") +parser:option "--nconns" + :description("Number of parallel connections") + :argname("<N>") + :convert(tonumber) + :default(10) +parser:option "-t --timeout" + :description("Timeout for client connections") + :argname("<sec>") + :convert(tonumber) + :default(60) +parser:option "-c --connect" + :description("Connect to specific host") + :argname("<host>") + :default('localhost:11334') +parser:option "-r --rspamc" + :description("Use specific rspamc path") + :argname("<path>") + :default('rspamc') +parser:option "-c --cv-fraction" + :description("Use specific fraction for cross-validation") + :argname("<fraction>") + :convert(tonumber) + :default('0.7') + +local opts + +-- Utility function to split a table into two parts randomly +local function split_table(t, fraction) + local shuffled = {} + for _, v in ipairs(t) do + local pos = math.random(1, #shuffled + 1) + table.insert(shuffled, pos, v) + end + local split_point = math.floor(#shuffled * tonumber(fraction)) + local part1 = { lua_util.unpack(shuffled, 1, split_point) } + local part2 = { lua_util.unpack(shuffled, split_point + 1) } + return part1, part2 +end + +-- Utility function to get all files in a directory +local function get_files(dir) + return rspamd_util.glob(dir .. '/*') +end + +local function list_to_file(list, fname) + local out = assert(io.open(fname, "w")) + for _, v in ipairs(list) do + out:write(v) + out:write("\n") + end + out:close() +end + +-- Function to train the classifier with given files +local function train_classifier(files, command) + local fname = os.tmpname() + list_to_file(files, fname) + local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s", + opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + os.remove(fname) +end + +-- Function to classify files and return results +local function classify_files(files) + local fname = os.tmpname() + list_to_file(files, fname) + + local settings_header = '--header Settings=\"{symbols_enabled=[BAYES_SPAM, BAYES_HAM]}\"' + local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s", + opts.rspamc, + settings_header, + opts.connect, + opts.nconns, + opts.timeout, fname) + local result = assert(io.popen(rspamc_command)) + local results = {} + for line in result:lines() do + local ucl_parser = ucl.parser() + local is_good, err = ucl_parser:parse_string(line) + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + os.remove(fname) + return nil + end + local obj = ucl_parser:get_object() + local file = obj.filename + local symbols = obj.symbols or {} + + if symbols["BAYES_SPAM"] then + table.insert(results, { result = "spam", file = file }) + elseif symbols["BAYES_HAM"] then + table.insert(results, { result = "ham", file = file }) + end + end + + os.remove(fname) + + return results +end + +-- Function to evaluate classifier performance +local function evaluate_results(results, spam_label, ham_label, + known_spam_files, known_ham_files, total_cv_files, elapsed) + local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0 + for _, res in ipairs(results) do + if res.result == spam_label then + if known_spam_files[res.file] then + true_positives = true_positives + 1 + elseif known_ham_files[res.file] then + false_positives = false_positives + 1 + end + total = total + 1 + elseif res.result == ham_label then + if known_spam_files[res.file] then + false_negatives = false_negatives + 1 + elseif known_ham_files[res.file] then + true_negatives = true_negatives + 1 + end + total = total + 1 + end + end + + local accuracy = (true_positives + true_negatives) / total + local precision = true_positives / (true_positives + false_positives) + local recall = true_positives / (true_positives + false_negatives) + local f1_score = 2 * (precision * recall) / (precision + recall) + + print(string.format("%-20s %-10s", "Metric", "Value")) + print(string.rep("-", 30)) + print(string.format("%-20s %-10d", "True Positives", true_positives)) + print(string.format("%-20s %-10d", "False Positives", false_positives)) + print(string.format("%-20s %-10d", "True Negatives", true_negatives)) + print(string.format("%-20s %-10d", "False Negatives", false_negatives)) + print(string.format("%-20s %-10.2f", "Accuracy", accuracy)) + print(string.format("%-20s %-10.2f", "Precision", precision)) + print(string.format("%-20s %-10.2f", "Recall", recall)) + print(string.format("%-20s %-10.2f", "F1 Score", f1_score)) + print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100)) + print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed)) +end + +local function handler(args) + opts = parser:parse(args) + local ham_directory = opts['ham'] + local spam_directory = opts['spam'] + -- Get all files + local spam_files = get_files(spam_directory) + local known_spam_files = lua_util.list_to_hash(spam_files) + local ham_files = get_files(ham_directory) + local known_ham_files = lua_util.list_to_hash(ham_files) + + -- Split files into training and cross-validation sets + + local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction) + local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction) + + print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files", + #train_spam, #cv_spam, #train_ham, #cv_ham)) + if not opts.no_learning then + -- Train classifier + local t, train_spam_time, train_ham_time + print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns)) + t = rspamd_util.get_time() + train_classifier(train_spam, "learn_spam") + train_spam_time = rspamd_util.get_time() - t + print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns)) + t = rspamd_util.get_time() + train_classifier(train_ham, "learn_ham") + train_ham_time = rspamd_util.get_time() - t + print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds", + #train_spam, train_spam_time, #train_ham, train_ham_time)) + end + + -- Classify cross-validation files + local cv_files = {} + for _, file in ipairs(cv_spam) do + table.insert(cv_files, file) + end + for _, file in ipairs(cv_ham) do + table.insert(cv_files, file) + end + + -- Shuffle cross-validation files + cv_files = split_table(cv_files, 1) + + print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns)) + -- Get classification results + local t = rspamd_util.get_time() + local results = classify_files(cv_files) + local elapsed = rspamd_util.get_time() - t + + -- Evaluate results + evaluate_results(results, "spam", "ham", + known_spam_files, + known_ham_files, + #cv_files, + elapsed) + +end + +return { + name = 'classifiertest', + aliases = { 'classifier_test' }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx index af3276ede..5a9a654e4 100644 --- a/src/client/rspamc.cxx +++ b/src/client/rspamc.cxx @@ -26,6 +26,8 @@ #include <optional> #include <algorithm> #include <functional> +#include <iostream> +#include <fstream> #include <cstdint> #include <cstdio> #include <cmath> @@ -86,6 +88,7 @@ static gboolean skip_images = FALSE; static gboolean skip_attachments = FALSE; static const char *pubkey = nullptr; static const char *user_agent = "rspamc"; +static const char *files_list = nullptr; std::vector<GPid> children; static GPatternSpec **exclude_compiled = nullptr; @@ -176,6 +179,8 @@ static GOptionEntry entries[] = "Skip attachments when learning/unlearning fuzzy", nullptr}, {"user-agent", 'U', 0, G_OPTION_ARG_STRING, &user_agent, "Use specific User-Agent instead of \"rspamc\"", nullptr}, + {"files-list", '\0', 0, G_OPTION_ARG_FILENAME, &files_list, + "Read one or more newline separated filenames to scan from file", nullptr}, {nullptr, 0, 0, G_OPTION_ARG_NONE, nullptr, nullptr, nullptr}}; static void rspamc_symbols_output(FILE *out, ucl_object_t *obj); @@ -2290,7 +2295,7 @@ int main(int argc, char **argv, char **env) add_options(kwattrs); auto cmd = maybe_cmd.value(); - if (start_argc == argc) { + if (start_argc == argc && files_list == nullptr) { /* Do command without input or with stdin */ if (empty_input) { rspamc_process_input(event_loop, cmd, nullptr, "empty", kwattrs); @@ -2302,29 +2307,57 @@ int main(int argc, char **argv, char **env) else { auto cur_req = 0; + /* Process files from arguments and `files_list` */ + std::vector<std::string> files; + files.reserve(argc - start_argc); + for (auto i = start_argc; i < argc; i++) { + files.emplace_back(argv[i]); + } + + /* If we have list of files, read it and enrich our list */ + if (files_list) { + std::ifstream in_files(files_list); + if (!in_files.is_open()) { + rspamc_print(stderr, "cannot open file {}\n", files_list); + exit(EXIT_FAILURE); + } + std::string line; + while (std::getline(in_files, line)) { + /* Trim spaces before inserting */ + line.erase(0, line.find_first_not_of(" \n\r\t")); + line.erase(line.find_last_not_of(" \n\r\t") + 1); + + /* Ignore empty lines */ + if (!line.empty()) { + files.emplace_back(line); + } + } + } + + for (const auto &file: files) { if (cmd.cmd == RSPAMC_COMMAND_FUZZY_DELHASH) { - add_client_header(kwattrs, "Hash", argv[i]); + add_client_header(kwattrs, "Hash", file.c_str()); } else { struct stat st; - if (stat(argv[i], &st) == -1) { - rspamc_print(stderr, "cannot stat file {}\n", argv[i]); + if (stat(file.c_str(), &st) == -1) { + rspamc_print(stderr, "cannot stat file {}\n", file); exit(EXIT_FAILURE); } if (S_ISDIR(st.st_mode)) { /* Directories are processed with a separate limit */ - rspamc_process_dir(event_loop, cmd, argv[i], kwattrs); + rspamc_process_dir(event_loop, cmd, file.c_str(), kwattrs); cur_req = 0; } else { - in = fopen(argv[i], "r"); + in = fopen(file.c_str(), "r"); if (in == nullptr) { - rspamc_print(stderr, "cannot open file {}\n", argv[i]); + rspamc_print(stderr, "cannot open file {}\n", file); exit(EXIT_FAILURE); } - rspamc_process_input(event_loop, cmd, in, argv[i], kwattrs); + rspamc_process_input(event_loop, cmd, in, file.c_str(), kwattrs); cur_req++; fclose(in); } diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c index 039ead231..0bc3414a5 100644 --- a/src/libstat/tokenizers/osb.c +++ b/src/libstat/tokenizers/osb.c @@ -23,7 +23,7 @@ #include "libmime/lang_detection.h" /* Size for features pipe */ -#define DEFAULT_FEATURE_WINDOW_SIZE 5 +#define DEFAULT_FEATURE_WINDOW_SIZE 2 #define DEFAULT_OSB_VERSION 2 static const int primes[] = { |