aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-06-20 21:02:18 +0600
committerGitHub <noreply@github.com>2024-06-20 21:02:18 +0600
commitb44099c96e279a0d60b0688e84e0ef6293194c59 (patch)
tree52d9dde0b767a0b6c92bde68b798cdf7a909608f
parentcd92d8bb1d412abc4aaaea57efb55145f9f49b8b (diff)
parent4cc5bf76576e1b11d451e75cc37216ac86c86eb7 (diff)
downloadrspamd-b44099c96e279a0d60b0688e84e0ef6293194c59.tar.gz
rspamd-b44099c96e279a0d60b0688e84e0ef6293194c59.zip
Merge pull request #5011 from rspamd/vstakhov-bayes-experiments
Improve bayes performance by setting the default window size to 2
-rw-r--r--lualib/lua_util.lua4
-rw-r--r--lualib/rspamadm/classifier_test.lua227
-rw-r--r--src/client/rspamc.cxx49
-rw-r--r--src/libstat/tokenizers/osb.c2
4 files changed, 271 insertions, 11 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua
index ac755182b..650ad5db1 100644
--- a/lualib/lua_util.lua
+++ b/lualib/lua_util.lua
@@ -389,8 +389,8 @@ end
--]]
local unpack_function = table.unpack or unpack
-exports.unpack = function(t)
- return unpack_function(t)
+exports.unpack = function(...)
+ return unpack_function(...)
end
--[[[
diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua
new file mode 100644
index 000000000..21af14fc1
--- /dev/null
+++ b/lualib/rspamadm/classifier_test.lua
@@ -0,0 +1,227 @@
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local argparse = require "argparse"
+local ucl = require "ucl"
+local rspamd_logger = require "rspamd_logger"
+
+local parser = argparse()
+ :name "rspamadm classifier_test"
+ :description "Learn bayes classifier and evaluate its performance"
+ :help_description_margin(32)
+
+parser:option "-H --ham"
+ :description("Ham directory")
+ :argname("<dir>")
+parser:option "-S --spam"
+ :description("Spam directory")
+ :argname("<dir>")
+parser:flag "-n --no-learning"
+ :description("Do not learn classifier")
+parser:option "--nconns"
+ :description("Number of parallel connections")
+ :argname("<N>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-t --timeout"
+ :description("Timeout for client connections")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(60)
+parser:option "-c --connect"
+ :description("Connect to specific host")
+ :argname("<host>")
+ :default('localhost:11334')
+parser:option "-r --rspamc"
+ :description("Use specific rspamc path")
+ :argname("<path>")
+ :default('rspamc')
+parser:option "-c --cv-fraction"
+ :description("Use specific fraction for cross-validation")
+ :argname("<fraction>")
+ :convert(tonumber)
+ :default('0.7')
+
+local opts
+
+-- Utility function to split a table into two parts randomly
+local function split_table(t, fraction)
+ local shuffled = {}
+ for _, v in ipairs(t) do
+ local pos = math.random(1, #shuffled + 1)
+ table.insert(shuffled, pos, v)
+ end
+ local split_point = math.floor(#shuffled * tonumber(fraction))
+ local part1 = { lua_util.unpack(shuffled, 1, split_point) }
+ local part2 = { lua_util.unpack(shuffled, split_point + 1) }
+ return part1, part2
+end
+
+-- Utility function to get all files in a directory
+local function get_files(dir)
+ return rspamd_util.glob(dir .. '/*')
+end
+
+local function list_to_file(list, fname)
+ local out = assert(io.open(fname, "w"))
+ for _, v in ipairs(list) do
+ out:write(v)
+ out:write("\n")
+ end
+ out:close()
+end
+
+-- Function to train the classifier with given files
+local function train_classifier(files, command)
+ local fname = os.tmpname()
+ list_to_file(files, fname)
+ local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s",
+ opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ os.remove(fname)
+end
+
+-- Function to classify files and return results
+local function classify_files(files)
+ local fname = os.tmpname()
+ list_to_file(files, fname)
+
+ local settings_header = '--header Settings=\"{symbols_enabled=[BAYES_SPAM, BAYES_HAM]}\"'
+ local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s",
+ opts.rspamc,
+ settings_header,
+ opts.connect,
+ opts.nconns,
+ opts.timeout, fname)
+ local result = assert(io.popen(rspamc_command))
+ local results = {}
+ for line in result:lines() do
+ local ucl_parser = ucl.parser()
+ local is_good, err = ucl_parser:parse_string(line)
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ os.remove(fname)
+ return nil
+ end
+ local obj = ucl_parser:get_object()
+ local file = obj.filename
+ local symbols = obj.symbols or {}
+
+ if symbols["BAYES_SPAM"] then
+ table.insert(results, { result = "spam", file = file })
+ elseif symbols["BAYES_HAM"] then
+ table.insert(results, { result = "ham", file = file })
+ end
+ end
+
+ os.remove(fname)
+
+ return results
+end
+
+-- Function to evaluate classifier performance
+local function evaluate_results(results, spam_label, ham_label,
+ known_spam_files, known_ham_files, total_cv_files, elapsed)
+ local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
+ for _, res in ipairs(results) do
+ if res.result == spam_label then
+ if known_spam_files[res.file] then
+ true_positives = true_positives + 1
+ elseif known_ham_files[res.file] then
+ false_positives = false_positives + 1
+ end
+ total = total + 1
+ elseif res.result == ham_label then
+ if known_spam_files[res.file] then
+ false_negatives = false_negatives + 1
+ elseif known_ham_files[res.file] then
+ true_negatives = true_negatives + 1
+ end
+ total = total + 1
+ end
+ end
+
+ local accuracy = (true_positives + true_negatives) / total
+ local precision = true_positives / (true_positives + false_positives)
+ local recall = true_positives / (true_positives + false_negatives)
+ local f1_score = 2 * (precision * recall) / (precision + recall)
+
+ print(string.format("%-20s %-10s", "Metric", "Value"))
+ print(string.rep("-", 30))
+ print(string.format("%-20s %-10d", "True Positives", true_positives))
+ print(string.format("%-20s %-10d", "False Positives", false_positives))
+ print(string.format("%-20s %-10d", "True Negatives", true_negatives))
+ print(string.format("%-20s %-10d", "False Negatives", false_negatives))
+ print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
+ print(string.format("%-20s %-10.2f", "Precision", precision))
+ print(string.format("%-20s %-10.2f", "Recall", recall))
+ print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
+ print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100))
+ print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
+end
+
+local function handler(args)
+ opts = parser:parse(args)
+ local ham_directory = opts['ham']
+ local spam_directory = opts['spam']
+ -- Get all files
+ local spam_files = get_files(spam_directory)
+ local known_spam_files = lua_util.list_to_hash(spam_files)
+ local ham_files = get_files(ham_directory)
+ local known_ham_files = lua_util.list_to_hash(ham_files)
+
+ -- Split files into training and cross-validation sets
+
+ local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction)
+ local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction)
+
+ print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files",
+ #train_spam, #cv_spam, #train_ham, #cv_ham))
+ if not opts.no_learning then
+ -- Train classifier
+ local t, train_spam_time, train_ham_time
+ print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+ t = rspamd_util.get_time()
+ train_classifier(train_spam, "learn_spam")
+ train_spam_time = rspamd_util.get_time() - t
+ print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+ t = rspamd_util.get_time()
+ train_classifier(train_ham, "learn_ham")
+ train_ham_time = rspamd_util.get_time() - t
+ print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
+ #train_spam, train_spam_time, #train_ham, train_ham_time))
+ end
+
+ -- Classify cross-validation files
+ local cv_files = {}
+ for _, file in ipairs(cv_spam) do
+ table.insert(cv_files, file)
+ end
+ for _, file in ipairs(cv_ham) do
+ table.insert(cv_files, file)
+ end
+
+ -- Shuffle cross-validation files
+ cv_files = split_table(cv_files, 1)
+
+ print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
+ -- Get classification results
+ local t = rspamd_util.get_time()
+ local results = classify_files(cv_files)
+ local elapsed = rspamd_util.get_time() - t
+
+ -- Evaluate results
+ evaluate_results(results, "spam", "ham",
+ known_spam_files,
+ known_ham_files,
+ #cv_files,
+ elapsed)
+
+end
+
+return {
+ name = 'classifiertest',
+ aliases = { 'classifier_test' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx
index af3276ede..5a9a654e4 100644
--- a/src/client/rspamc.cxx
+++ b/src/client/rspamc.cxx
@@ -26,6 +26,8 @@
#include <optional>
#include <algorithm>
#include <functional>
+#include <iostream>
+#include <fstream>
#include <cstdint>
#include <cstdio>
#include <cmath>
@@ -86,6 +88,7 @@ static gboolean skip_images = FALSE;
static gboolean skip_attachments = FALSE;
static const char *pubkey = nullptr;
static const char *user_agent = "rspamc";
+static const char *files_list = nullptr;
std::vector<GPid> children;
static GPatternSpec **exclude_compiled = nullptr;
@@ -176,6 +179,8 @@ static GOptionEntry entries[] =
"Skip attachments when learning/unlearning fuzzy", nullptr},
{"user-agent", 'U', 0, G_OPTION_ARG_STRING, &user_agent,
"Use specific User-Agent instead of \"rspamc\"", nullptr},
+ {"files-list", '\0', 0, G_OPTION_ARG_FILENAME, &files_list,
+ "Read one or more newline separated filenames to scan from file", nullptr},
{nullptr, 0, 0, G_OPTION_ARG_NONE, nullptr, nullptr, nullptr}};
static void rspamc_symbols_output(FILE *out, ucl_object_t *obj);
@@ -2290,7 +2295,7 @@ int main(int argc, char **argv, char **env)
add_options(kwattrs);
auto cmd = maybe_cmd.value();
- if (start_argc == argc) {
+ if (start_argc == argc && files_list == nullptr) {
/* Do command without input or with stdin */
if (empty_input) {
rspamc_process_input(event_loop, cmd, nullptr, "empty", kwattrs);
@@ -2302,29 +2307,57 @@ int main(int argc, char **argv, char **env)
else {
auto cur_req = 0;
+ /* Process files from arguments and `files_list` */
+ std::vector<std::string> files;
+ files.reserve(argc - start_argc);
+
for (auto i = start_argc; i < argc; i++) {
+ files.emplace_back(argv[i]);
+ }
+
+ /* If we have list of files, read it and enrich our list */
+ if (files_list) {
+ std::ifstream in_files(files_list);
+ if (!in_files.is_open()) {
+ rspamc_print(stderr, "cannot open file {}\n", files_list);
+ exit(EXIT_FAILURE);
+ }
+ std::string line;
+ while (std::getline(in_files, line)) {
+ /* Trim spaces before inserting */
+ line.erase(0, line.find_first_not_of(" \n\r\t"));
+ line.erase(line.find_last_not_of(" \n\r\t") + 1);
+
+ /* Ignore empty lines */
+ if (!line.empty()) {
+ files.emplace_back(line);
+ }
+ }
+ }
+
+ for (const auto &file: files) {
if (cmd.cmd == RSPAMC_COMMAND_FUZZY_DELHASH) {
- add_client_header(kwattrs, "Hash", argv[i]);
+ add_client_header(kwattrs, "Hash", file.c_str());
}
else {
struct stat st;
- if (stat(argv[i], &st) == -1) {
- rspamc_print(stderr, "cannot stat file {}\n", argv[i]);
+ if (stat(file.c_str(), &st) == -1) {
+ rspamc_print(stderr, "cannot stat file {}\n", file);
exit(EXIT_FAILURE);
}
if (S_ISDIR(st.st_mode)) {
/* Directories are processed with a separate limit */
- rspamc_process_dir(event_loop, cmd, argv[i], kwattrs);
+ rspamc_process_dir(event_loop, cmd, file.c_str(), kwattrs);
cur_req = 0;
}
else {
- in = fopen(argv[i], "r");
+ in = fopen(file.c_str(), "r");
if (in == nullptr) {
- rspamc_print(stderr, "cannot open file {}\n", argv[i]);
+ rspamc_print(stderr, "cannot open file {}\n", file);
exit(EXIT_FAILURE);
}
- rspamc_process_input(event_loop, cmd, in, argv[i], kwattrs);
+ rspamc_process_input(event_loop, cmd, in, file.c_str(), kwattrs);
cur_req++;
fclose(in);
}
diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c
index 039ead231..0bc3414a5 100644
--- a/src/libstat/tokenizers/osb.c
+++ b/src/libstat/tokenizers/osb.c
@@ -23,7 +23,7 @@
#include "libmime/lang_detection.h"
/* Size for features pipe */
-#define DEFAULT_FEATURE_WINDOW_SIZE 5
+#define DEFAULT_FEATURE_WINDOW_SIZE 2
#define DEFAULT_OSB_VERSION 2
static const int primes[] = {