about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- .overcommit.yml 3
-rw-r--r-- conf/statistic.conf 4
-rw-r--r-- lualib/lua_bayes_redis.lua 67
-rw-r--r-- lualib/redis_scripts/bayes_cache_learn.lua 17
-rw-r--r-- lualib/redis_scripts/bayes_classify.lua 75
-rw-r--r-- lualib/redis_scripts/bayes_learn.lua 55
-rw-r--r-- src/client/rspamc.cxx 116
-rw-r--r-- src/controller.c 103
-rw-r--r-- src/libserver/cfg_file.h 28
-rw-r--r-- src/libserver/cfg_rcl.cxx 193
-rw-r--r-- src/libserver/cfg_utils.cxx 186
-rw-r--r-- src/libserver/task.c 43
-rw-r--r-- src/libserver/task.h 6
-rw-r--r-- src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md 451
-rw-r--r-- src/libstat/backends/cdb_backend.cxx 13
-rw-r--r-- src/libstat/backends/mmaped_file.c 10
-rw-r--r-- src/libstat/backends/redis_backend.cxx 537
-rw-r--r-- src/libstat/backends/sqlite3_backend.c 7
-rw-r--r-- src/libstat/classifiers/bayes.c 652
-rw-r--r-- src/libstat/classifiers/classifiers.h 14
-rw-r--r-- src/libstat/learn_cache/redis_cache.cxx 84
-rw-r--r-- src/libstat/stat_api.h 54
-rw-r--r-- src/libstat/stat_config.c 11
-rw-r--r-- src/libstat/stat_process.c 620
-rw-r--r-- src/plugins/lua/bayes_expiry.lua 182
-rw-r--r-- test/functional/cases/110_statistics/300-multiclass-redis.robot 42
-rw-r--r-- test/functional/cases/110_statistics/320-multiclass-peruser.robot 31
-rw-r--r-- test/functional/cases/110_statistics/multiclass_lib.robot 169
-rw-r--r-- test/functional/configs/multiclass_bayes.conf 129
-rw-r--r-- test/functional/lib/rspamd.robot 17
-rw-r--r-- test/functional/messages/newsletter.eml 50
-rw-r--r-- test/functional/messages/transactional.eml 18
32 files changed, 3574 insertions, 413 deletions
diff --git a/.overcommit.yml b/.overcommit.yml
index d26d3de52..9212c33b3 100644
--- a/.overcommit.yml
+++ b/.overcommit.yml
@@ -29,7 +29,8 @@ PreCommit:
command: ['luacheck', 'lualib', 'src/plugins/lua']
ClangFormat:
enabled: true
- command: ['git', 'clang-format', '--diff']
+ command: ['sh', '-c', 'git clang-format --diff --quiet || (echo "Running clang-format to fix issues..." && git clang-format && git add -u && echo "Files formatted and staged.")']
+ on_warn: fail
#PostCheckout:
# ALL: # Special hook name that customizes all hooks of this type
# quiet: true # Change all post-checkout hooks to only display output on failure
diff --git a/conf/statistic.conf b/conf/statistic.conf
index 36d418935..3ba460ff3 100644
--- a/conf/statistic.conf
+++ b/conf/statistic.conf
@@ -35,11 +35,11 @@ classifier "bayes" {
statfile {
symbol = "BAYES_HAM";
- spam = false;
+ class = "ham";
}
statfile {
symbol = "BAYES_SPAM";
- spam = true;
+ class = "spam";
}
learn_condition = 'return require("lua_bayes_learn").can_learn';
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index 782e6fc47..a7af80bf1 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua
@@ -25,27 +25,44 @@ local ucl = require "ucl"
local N = "bayes"
local function gen_classify_functor(redis_params, classify_script_id)
- return function(task, expanded_key, id, is_spam, stat_tokens, callback)
-
+ return function(task, expanded_key, id, class_labels, stat_tokens, callback)
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
if err then
callback(task, false, err)
else
- callback(task, true, data[1], data[2], data[3], data[4])
+ -- Pass the raw data table to the C++ callback for processing
+ -- The C++ callback will handle both binary and multi-class formats
+ callback(task, true, data)
+ end
+ end
+
+ -- Determine class labels to send to Redis script
+ local script_class_labels
+ if type(class_labels) == "table" then
+ -- Use simple comma-separated string instead of messagepack
+ script_class_labels = "TABLE:" .. table.concat(class_labels, ",")
+ else
+ -- Single class label or boolean compatibility
+ if class_labels == true or class_labels == "true" then
+ script_class_labels = "S" -- spam
+ elseif class_labels == false or class_labels == "false" then
+ script_class_labels = "H" -- ham
+ else
+ script_class_labels = class_labels -- string class label
end
end
lua_redis.exec_redis_script(classify_script_id,
{ task = task, is_write = false, key = expanded_key },
- classify_redis_cb, { expanded_key, stat_tokens })
+ classify_redis_cb, { expanded_key, script_class_labels, stat_tokens })
end
end
local function gen_learn_functor(redis_params, learn_script_id)
- return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
+ return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
local function learn_redis_cb(err, data)
- lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
+ lua_util.debugm(N, task, 'learn redis cb: %s, %s for class %s', err, data, class_label)
if err then
callback(task, false, err)
else
@@ -53,17 +70,24 @@ local function gen_learn_functor(redis_params, learn_script_id)
end
end
+ -- Convert class_label for backward compatibility
+ local script_class_label = class_label
+ if class_label == true or class_label == "true" then
+ script_class_label = "S" -- spam
+ elseif class_label == false or class_label == "false" then
+ script_class_label = "H" -- ham
+ end
+
if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = expanded_key },
learn_redis_cb,
- { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+ { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = expanded_key },
- learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+ learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens })
end
-
end
end
@@ -112,8 +136,7 @@ end
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
-
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb)
local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
if not redis_params then
@@ -137,7 +160,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
if ev_base then
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
-
local function stat_redis_cb(err, data)
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
@@ -162,11 +184,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
end
end
+ -- Convert class_label to learn key
+ local learn_key
+ if class_label == true or class_label == "true" or class_label == "S" then
+ learn_key = "learns_spam"
+ elseif class_label == false or class_label == "false" or class_label == "H" then
+ learn_key = "learns_ham"
+ else
+ -- For other class labels, use learns_<class_label>
+ learn_key = "learns_" .. string.lower(tostring(class_label))
+ end
+
lua_redis.exec_redis_script(stat_script_id,
{ ev_base = ev_base, cfg = cfg, is_write = false },
stat_redis_cb, { tostring(cursor),
symbol,
- is_spam and "learns_spam" or "learns_ham",
+ learn_key,
tostring(max_users) })
return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
end)
@@ -178,7 +211,6 @@ end
local function gen_cache_check_functor(redis_params, check_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
return function(task, cache_id, callback)
-
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
if err then
@@ -201,17 +233,16 @@ end
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
- return function(task, cache_id, is_spam)
+ return function(task, cache_id, class_name, class_id)
local function learn_redis_cb(err, data)
lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
end
- lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
+ lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id)
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = cache_id },
learn_redis_cb,
- { cache_id, is_spam and "1" or "0", packed_conf })
-
+ { cache_id, tostring(class_id), packed_conf })
end
end
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
index 7d44a73ef..a7c9ac443 100644
--- a/lualib/redis_scripts/bayes_cache_learn.lua
+++ b/lualib/redis_scripts/bayes_cache_learn.lua
@@ -1,12 +1,15 @@
--- Lua script to perform cache checking for bayes classification
+-- Lua script to perform cache learning for bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - cache id
--- key2 - is spam (1 or 0)
+-- key2 - class_id (numeric hash of class name, computed by C side)
-- key3 - configuration table in message pack
local cache_id = KEYS[1]
-local is_spam = KEYS[2]
+local class_id = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
+
+-- Use class_id directly as cache value
+local cache_value = tostring(class_id)
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
-- Try each prefix that is in Redis (as some other instance might have set it)
@@ -15,8 +18,8 @@ for i = 0, conf.cache_max_keys do
local have = redis.call('HGET', prefix, cache_id)
if have then
- -- Already in cache, but is_spam changes when relearning
- redis.call('HSET', prefix, cache_id, is_spam)
+ -- Already in cache, but cache_value changes when relearning
+ redis.call('HSET', prefix, cache_id, cache_value)
return false
end
end
@@ -30,7 +33,7 @@ for i = 0, conf.cache_max_keys do
if count < lim then
-- We can add it to this prefix
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, cache_value)
added = true
end
end
@@ -46,7 +49,7 @@ if not added then
if exists then
if not expired then
redis.call('DEL', prefix)
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, cache_value)
-- Do not expire anything else
expired = true
diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua
index e94f645fd..d6132e631 100644
--- a/lualib/redis_scripts/bayes_classify.lua
+++ b/lualib/redis_scripts/bayes_classify.lua
@@ -1,37 +1,68 @@
--- Lua script to perform bayes classification
+-- Lua script to perform bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of strings
+-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..."
+-- key3 - set of tokens encoded in messagepack array of strings
local prefix = KEYS[1]
-local output_spam = {}
-local output_ham = {}
+local class_labels_arg = KEYS[2]
+local input_tokens = cmsgpack.unpack(KEYS[3])
-local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
-local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+-- Parse class labels ("TABLE:label1,label2,..." list, or a single legacy label)
+local class_labels = {}
+if string.match(class_labels_arg, "^TABLE:") then
+ local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
+ for label in string.gmatch(labels_str, "([^,]+)") do
+ table.insert(class_labels, label)
+ end
+else
+ -- Legacy single class - convert to array
+ class_labels = { class_labels_arg }
+end
--- Output is a set of pairs (token_index, token_count), tokens that are not
--- found are not filled.
--- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+-- Get learned counts for all classes (ordered)
+local learned_counts = {}
+for _, label in ipairs(class_labels) do
+ local key = 'learns_' .. string.lower(label)
+ -- Handle legacy keys for backward compatibility
+ if label == 'H' then
+ key = 'learns_ham'
+ elseif label == 'S' then
+ key = 'learns_spam'
+ end
+ table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0)
+end
-if learned_ham > 0 and learned_spam > 0 then
- local input_tokens = cmsgpack.unpack(KEYS[2])
- for i, token in ipairs(input_tokens) do
- local token_data = redis.call('HMGET', token, 'H', 'S')
+-- Get token data for all classes (ordered)
+local token_results = {}
+for i, _ in ipairs(class_labels) do
+ token_results[i] = {}
+end
- if token_data then
- local ham_count = token_data[1]
- local spam_count = token_data[2]
+-- Check if we have any learning data
+local has_learns = false
+for _, count in ipairs(learned_counts) do
+ if count > 0 then
+ has_learns = true
+ break
+ end
+end
- if ham_count then
- table.insert(output_ham, { i, tonumber(ham_count) })
- end
+if has_learns then
+ -- Process each token
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', token, unpack(class_labels))
- if spam_count then
- table.insert(output_spam, { i, tonumber(spam_count) })
+ if token_data then
+ for j, _ in ipairs(class_labels) do
+ local count = token_data[j]
+ if count and tonumber(count) > 0 then
+ table.insert(token_results[j], { i, tonumber(count) })
+ end
end
end
end
end
-return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file
+-- Always return ordered arrays: [learned_counts_array, token_results_array]
+return { learned_counts, token_results }
diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua
index 5456165b6..ebc798fe0 100644
--- a/lualib/redis_scripts/bayes_learn.lua
+++ b/lualib/redis_scripts/bayes_learn.lua
@@ -1,14 +1,14 @@
--- Lua script to perform bayes learning
+-- Lua script to perform bayes learning (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - boolean is_spam
+-- key2 - class label string (e.g. "S", "H", "T")
-- key3 - string symbol
-- key4 - boolean is_unlearn
-- key5 - set of tokens encoded in messagepack array of strings
-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
local prefix = KEYS[1]
-local is_spam = KEYS[2] == 'true' and true or false
+local class_label = KEYS[2]
local symbol = KEYS[3]
local is_unlearn = KEYS[4] == 'true' and true or false
local input_tokens = cmsgpack.unpack(KEYS[5])
@@ -18,15 +18,47 @@ if KEYS[6] then
text_tokens = cmsgpack.unpack(KEYS[6])
end
-local hash_key = is_spam and 'S' or 'H'
-local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+-- Handle backward compatibility for boolean values
+if class_label == 'true' then
+ class_label = 'S' -- spam
+elseif class_label == 'false' then
+ class_label = 'H' -- ham
+end
+
+local hash_key = class_label
+local learned_key = 'learns_' .. string.lower(class_label)
+
+-- Handle legacy keys for backward compatibility
+if class_label == 'S' then
+ learned_key = 'learns_spam'
+elseif class_label == 'H' then
+ learned_key = 'learns_ham'
+end
redis.call('SADD', symbol .. '_keys', prefix)
redis.call('HSET', prefix, 'version', '2') -- new schema
-redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
+
+-- Update learned count, but prevent it from going negative
+if is_unlearn then
+ local current_count = tonumber(redis.call('HGET', prefix, learned_key)) or 0
+ if current_count > 0 then
+ redis.call('HINCRBY', prefix, learned_key, -1)
+ end
+else
+ redis.call('HINCRBY', prefix, learned_key, 1)
+end
for i, token in ipairs(input_tokens) do
- redis.call('HINCRBY', token, hash_key, is_unlearn and -1 or 1)
+ -- Update token count, but prevent it from going negative
+ if is_unlearn then
+ local current_token_count = tonumber(redis.call('HGET', token, hash_key)) or 0
+ if current_token_count > 0 then
+ redis.call('HINCRBY', token, hash_key, -1)
+ end
+ else
+ redis.call('HINCRBY', token, hash_key, 1)
+ end
+
if text_tokens then
local tok1 = text_tokens[i * 2 - 1]
local tok2 = text_tokens[i * 2]
@@ -38,7 +70,14 @@ for i, token in ipairs(input_tokens) do
redis.call('HSET', token, 'tokens', tok1)
end
- redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token)
+ if is_unlearn then
+ local current_z_score = tonumber(redis.call('ZSCORE', prefix .. '_z', token)) or 0
+ if current_z_score > 0 then
+ redis.call('ZINCRBY', prefix .. '_z', -1, token)
+ end
+ else
+ redis.call('ZINCRBY', prefix .. '_z', 1, token)
+ end
end
end
end
diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx
index 404359877..1dc48faae 100644
--- a/src/client/rspamc.cxx
+++ b/src/client/rspamc.cxx
@@ -59,6 +59,7 @@ static const char *user = nullptr;
static const char *helo = nullptr;
static const char *hostname = nullptr;
static const char *classifier = nullptr;
+static const char *learn_class_name = nullptr;
static const char *local_addr = nullptr;
static const char *execute = nullptr;
static const char *sort = nullptr;
@@ -90,6 +91,9 @@ static gboolean skip_attachments = FALSE;
static const char *pubkey = nullptr;
static const char *user_agent = "rspamc";
static const char *files_list = nullptr;
+static const char *queue_id = nullptr;
+static const char *log_tag = nullptr;
+static std::string settings;
std::vector<GPid> children;
static GPatternSpec **exclude_compiled = nullptr;
@@ -102,6 +106,11 @@ static gboolean rspamc_password_callback(const char *option_name,
gpointer data,
GError **error);
+static gboolean rspamc_settings_callback(const char *option_name,
+ const char *value,
+ gpointer data,
+ GError **error);
+
static GOptionEntry entries[] =
{
{"connect", 'h', 0, G_OPTION_ARG_STRING, &connect_str,
@@ -182,6 +191,12 @@ static GOptionEntry entries[] =
"Use specific User-Agent instead of \"rspamc\"", nullptr},
{"files-list", '\0', 0, G_OPTION_ARG_FILENAME, &files_list,
"Read one or more newline separated filenames to scan from file", nullptr},
+ {"queue-id", '\0', 0, G_OPTION_ARG_STRING, &queue_id,
+ "Set Queue-ID header for the request", nullptr},
+ {"log-tag", '\0', 0, G_OPTION_ARG_STRING, &log_tag,
+ "Set Log-Tag header for the request", nullptr},
+ {"settings", '\0', 0, G_OPTION_ARG_CALLBACK, (void *) &rspamc_settings_callback,
+ "Set Settings header as JSON/UCL for the request", nullptr},
{nullptr, 0, 0, G_OPTION_ARG_NONE, nullptr, nullptr, nullptr}};
static void rspamc_symbols_output(FILE *out, ucl_object_t *obj);
@@ -198,6 +213,7 @@ enum rspamc_command_type {
RSPAMC_COMMAND_SYMBOLS,
RSPAMC_COMMAND_LEARN_SPAM,
RSPAMC_COMMAND_LEARN_HAM,
+ RSPAMC_COMMAND_LEARN_CLASS,
RSPAMC_COMMAND_FUZZY_ADD,
RSPAMC_COMMAND_FUZZY_DEL,
RSPAMC_COMMAND_FUZZY_DELHASH,
@@ -250,6 +266,15 @@ static const constexpr auto rspamc_commands = rspamd::array_of(
.need_input = TRUE,
.command_output_func = nullptr},
rspamc_command{
+ .cmd = RSPAMC_COMMAND_LEARN_CLASS,
+ .name = "learn_class",
+ .path = "learnclass",
+ .description = "learn message as class",
+ .is_controller = TRUE,
+ .is_privileged = TRUE,
+ .need_input = TRUE,
+ .command_output_func = nullptr},
+ rspamc_command{
.cmd = RSPAMC_COMMAND_FUZZY_ADD,
.name = "fuzzy_add",
.path = "fuzzyadd",
@@ -527,8 +552,7 @@ rspamc_password_callback(const char *option_name,
auto *map = (char *) locked_mmap.value().get_map();
value_view = std::string_view{map, locked_mmap->get_size()};
auto right = value_view.end() - 1;
- for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right)
- ;
+ for (; right > value_view.cbegin() && g_ascii_isspace(*right); --right);
std::string_view str{value_view.begin(), static_cast<size_t>(right - value_view.begin()) + 1};
processed_passwd.assign(std::begin(str), std::end(str));
processed_passwd.push_back('\0'); /* Null-terminate for C part */
@@ -557,6 +581,46 @@ rspamc_password_callback(const char *option_name,
return TRUE;
}
+static gboolean
+rspamc_settings_callback(const char *option_name,
+ const char *value,
+ gpointer data,
+ GError **error)
+{
+ if (value == nullptr) {
+ g_set_error(error, G_OPTION_ERROR, G_OPTION_ERROR_BAD_VALUE,
+ "Settings parameter cannot be empty");
+ return FALSE;
+ }
+
+ // Parse the settings string using UCL to validate it
+ struct ucl_parser *parser = ucl_parser_new(UCL_PARSER_KEY_LOWERCASE);
+ if (!ucl_parser_add_string(parser, value, strlen(value))) {
+ auto *ucl_error = ucl_parser_get_error(parser);
+ g_set_error(error, G_OPTION_ERROR, G_OPTION_ERROR_BAD_VALUE,
+ "Invalid JSON/UCL in settings: %s", ucl_error);
+ ucl_parser_free(parser);
+ return FALSE;
+ }
+
+ // Get the parsed object and validate it
+ auto *obj = ucl_parser_get_object(parser);
+ if (obj == nullptr) {
+ g_set_error(error, G_OPTION_ERROR, G_OPTION_ERROR_BAD_VALUE,
+ "Failed to parse settings as JSON/UCL");
+ ucl_parser_free(parser);
+ return FALSE;
+ }
+
+ // Store the validated settings string
+ settings = value;
+
+ ucl_object_unref(obj);
+ ucl_parser_free(parser);
+
+ return TRUE;
+}
+
/*
* Parse command line
*/
@@ -649,6 +713,7 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
{"report", RSPAMC_COMMAND_SYMBOLS},
{"learn_spam", RSPAMC_COMMAND_LEARN_SPAM},
{"learn_ham", RSPAMC_COMMAND_LEARN_HAM},
+ {"learn_class", RSPAMC_COMMAND_LEARN_CLASS},
{"fuzzy_add", RSPAMC_COMMAND_FUZZY_ADD},
{"fuzzy_del", RSPAMC_COMMAND_FUZZY_DEL},
{"fuzzy_delhash", RSPAMC_COMMAND_FUZZY_DELHASH},
@@ -659,10 +724,33 @@ check_rspamc_command(const char *cmd) -> std::optional<rspamc_command>
});
std::string cmd_lc = rspamd_string_tolower(cmd);
+
+ // Handle learn_class:classname syntax
+ if (cmd_lc.find("learn_class:") == 0) {
+ auto colon_pos = cmd_lc.find(':');
+ if (colon_pos != std::string::npos && colon_pos + 1 < cmd_lc.length()) {
+ auto class_name = cmd_lc.substr(colon_pos + 1);
+ // Store class name globally for later use
+ learn_class_name = g_strdup(class_name.c_str());
+ // Return the learn_class command
+ auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
+ return item.cmd == RSPAMC_COMMAND_LEARN_CLASS;
+ });
+ if (elt_it != std::end(rspamc_commands)) {
+ return *elt_it;
+ }
+ }
+ return std::nullopt;
+ }
+
auto ct = rspamd::find_map(str_map, std::string_view{cmd_lc});
+ if (!ct.has_value()) {
+ return std::nullopt;
+ }
+
auto elt_it = std::find_if(rspamc_commands.begin(), rspamc_commands.end(), [&](const auto &item) {
- return item.cmd == ct;
+ return item.cmd == ct.value();
});
if (elt_it != std::end(rspamc_commands)) {
@@ -799,6 +887,10 @@ add_options(GQueue *opts)
add_client_header(opts, "Classifier", classifier);
}
+ if (learn_class_name) {
+ add_client_header(opts, "Class", learn_class_name);
+ }
+
if (weight != 0) {
auto nstr = fmt::format("{}", weight);
add_client_header(opts, "Weight", nstr.c_str());
@@ -852,6 +944,18 @@ add_options(GQueue *opts)
hdr++;
}
+ if (queue_id != nullptr) {
+ add_client_header(opts, "Queue-Id", queue_id);
+ }
+
+ if (log_tag != nullptr) {
+ add_client_header(opts, "Log-Tag", log_tag);
+ }
+
+ if (!settings.empty()) {
+ add_client_header(opts, "Settings", settings.c_str());
+ }
+
if (!flagbuf.empty()) {
if (flagbuf.back() == ',') {
flagbuf.pop_back();
@@ -1918,7 +2022,7 @@ rspamc_client_cb(struct rspamd_client_connection *conn,
if (raw_body) {
/* We can also output the resulting json */
- rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t)(rawlen - bodylen)});
+ rspamc_print(out, "{}\n", std::string_view{raw_body, (std::size_t) (rawlen - bodylen)});
}
}
}
@@ -1950,7 +2054,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
p = strrchr(connect_str, ']');
if (p != nullptr) {
- hostbuf.assign(connect_str + 1, (std::size_t)(p - connect_str - 1));
+ hostbuf.assign(connect_str + 1, (std::size_t) (p - connect_str - 1));
p++;
}
else {
@@ -1965,7 +2069,7 @@ rspamc_process_input(struct ev_loop *ev_base, const struct rspamc_command &cmd,
if (hostbuf.empty()) {
if (p != nullptr) {
- hostbuf.assign(connect_str, (std::size_t)(p - connect_str));
+ hostbuf.assign(connect_str, (std::size_t) (p - connect_str));
}
else {
hostbuf.assign(connect_str);
diff --git a/src/controller.c b/src/controller.c
index 0550ba6b8..6e0e4cac1 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -53,6 +53,7 @@
#define PATH_HISTORY_RESET "/historyreset"
#define PATH_LEARN_SPAM "/learnspam"
#define PATH_LEARN_HAM "/learnham"
+#define PATH_LEARN_CLASS "/learnclass"
#define PATH_METRICS "/metrics"
#define PATH_READY "/ready"
#define PATH_SAVE_ACTIONS "/saveactions"
@@ -2126,6 +2127,7 @@ rspamd_controller_handle_learn_common(
struct rspamd_controller_worker_ctx *ctx;
struct rspamd_task *task;
const rspamd_ftok_t *cl_header;
+ const char *class_name;
ctx = session->ctx;
@@ -2167,7 +2169,9 @@ rspamd_controller_handle_learn_common(
goto end;
}
- rspamd_learn_task_spam(task, is_spam, session->classifier, NULL);
+ /* Use unified class-based learning approach */
+ class_name = is_spam ? "spam" : "ham";
+ rspamd_task_set_autolearn_class(task, class_name);
if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
msg_warn_session("<%s> message cannot be processed",
@@ -2212,6 +2216,96 @@ rspamd_controller_handle_learnham(
}
/*
+ * Learn class command handler:
+ * request: /learnclass
+ * headers: Password, Class, Classifier (optional)
+ * input: plaintext data
+ * reply: json {"success":true} or {"error":"error message"}
+ */
+static int
+rspamd_controller_handle_learnclass(
+ struct rspamd_http_connection_entry *conn_ent,
+ struct rspamd_http_message *msg)
+{
+ struct rspamd_controller_session *session = conn_ent->ud;
+ struct rspamd_controller_worker_ctx *ctx;
+ struct rspamd_task *task;
+ const rspamd_ftok_t *cl_header, *class_header;
+ char *class_name = NULL;
+
+ ctx = session->ctx;
+
+ if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) {
+ return 0;
+ }
+
+ if (rspamd_http_message_get_body(msg, NULL) == NULL) {
+ msg_err_session("got zero length body, cannot continue");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Empty body is not permitted");
+ return 0;
+ }
+
+ class_header = rspamd_http_message_find_header(msg, "Class");
+ if (!class_header) {
+ msg_err_session("missing Class header for multiclass learning");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Class header is required for multiclass learning");
+ return 0;
+ }
+
+ task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool,
+ session->ctx->lang_det, ctx->event_loop, FALSE);
+
+ task->resolver = ctx->resolver;
+ task->s = rspamd_session_create(session->pool,
+ rspamd_controller_learn_fin_task,
+ NULL,
+ (event_finalizer_t) rspamd_task_free,
+ task);
+ task->fin_arg = conn_ent;
+ task->http_conn = rspamd_http_connection_ref(conn_ent->conn);
+ task->sock = -1;
+ session->task = task;
+
+ cl_header = rspamd_http_message_find_header(msg, "classifier");
+ if (cl_header) {
+ session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+ }
+ else {
+ session->classifier = NULL;
+ }
+
+ if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) {
+ goto end;
+ }
+
+ /* Set multiclass learning flag and store class name */
+ class_name = rspamd_mempool_ftokdup(task->task_pool, class_header);
+ rspamd_task_set_autolearn_class(task, class_name);
+
+ if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
+ msg_warn_session("<%s> message cannot be processed",
+ MESSAGE_FIELD_CHECK(task, message_id));
+ goto end;
+ }
+
+end:
+ /* Set session spam flag for logging compatibility */
+ if (class_name) {
+ session->is_spam = (strcmp(class_name, "spam") == 0);
+ }
+ else {
+ session->is_spam = FALSE;
+ }
+ rspamd_session_pending(task->s);
+
+ return 0;
+}
+
+/*
* Scan command handler:
* request: /scan
* headers: Password
@@ -3292,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent,
rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods",
"POST, GET, OPTIONS");
rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers",
- "Classifier,Content-Type,Password,Map,Weight,Flag,Hash");
+ "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash");
rspamd_http_connection_reset(conn_ent->conn);
rspamd_http_router_insert_headers(conn_ent->rt, rep);
rspamd_http_connection_write_message(conn_ent->conn,
@@ -3456,7 +3550,7 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en
*/
static int
rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent,
- struct rspamd_http_message *msg)
+ struct rspamd_http_message *msg)
{
struct rspamd_controller_session *session = conn_ent->ud;
struct rspamd_controller_worker_ctx *ctx = session->ctx;
@@ -4049,6 +4143,9 @@ start_controller_worker(struct rspamd_worker *worker)
PATH_LEARN_HAM,
rspamd_controller_handle_learnham);
rspamd_http_router_add_path(ctx->http,
+ PATH_LEARN_CLASS,
+ rspamd_controller_handle_learnclass);
+ rspamd_http_router_add_path(ctx->http,
PATH_METRICS,
rspamd_controller_handle_metrics);
rspamd_http_router_add_path(ctx->http,
diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h
index 36941da7a..355046cac 100644
--- a/src/libserver/cfg_file.h
+++ b/src/libserver/cfg_file.h
@@ -139,7 +139,10 @@ struct rspamd_statfile_config {
char *symbol; /**< symbol of statfile */
char *label; /**< label of this statfile */
ucl_object_t *opts; /**< other options */
- gboolean is_spam; /**< spam flag */
+ char *class_name; /**< class name for multi-class classification */
+ unsigned int class_index; /**< class index for O(1) lookup during classification */
+ gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */
+ gboolean is_spam_converted; /**< TRUE if class_name was converted from is_spam flag */
struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */
gpointer data; /**< opaque data */
};
@@ -182,6 +185,8 @@ struct rspamd_classifier_config {
double min_prob_strength; /**< use only tokens with probability in [0.5 - MPS, 0.5 + MPS] */
unsigned int min_learns; /**< minimum number of learns for each statfile */
unsigned int flags;
+ GHashTable *class_labels; /**< class_name -> backend_symbol mapping for multi-class */
+ GPtrArray *class_names; /**< ordered list of class names */
};
struct rspamd_worker_bind_conf {
@@ -621,12 +626,25 @@ void rspamd_config_insert_classify_symbols(struct rspamd_config *cfg);
*/
gboolean rspamd_config_check_statfiles(struct rspamd_classifier_config *cf);
-/*
- * Find classifier config by name
+/**
+ * Multi-class configuration helpers
+ */
+gboolean rspamd_config_parse_class_labels(const ucl_object_t *obj,
+ GHashTable **class_labels);
+
+gboolean rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf);
+
+gboolean rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf,
+ GError **err);
+
+const char *rspamd_config_get_class_label(struct rspamd_classifier_config *ccf,
+ const char *class_name);
+
+/**
+ * Find classifier config by name
*/
struct rspamd_classifier_config *rspamd_config_find_classifier(
- struct rspamd_config *cfg,
- const char *name);
+ struct rspamd_config *cfg, const char *name);
void rspamd_ucl_add_conf_macros(struct ucl_parser *parser,
struct rspamd_config *cfg);
diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx
index 0a48e8a4f..da5845917 100644
--- a/src/libserver/cfg_rcl.cxx
+++ b/src/libserver/cfg_rcl.cxx
@@ -1197,31 +1197,73 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj,
st->opts = (ucl_object_t *) obj;
st->clcf = ccf;
- const auto *val = ucl_object_lookup(obj, "spam");
- if (val == nullptr) {
+ /* Handle migration from old 'spam' field to new 'class' field */
+ const auto *class_val = ucl_object_lookup(obj, "class");
+ const auto *spam_val = ucl_object_lookup(obj, "spam");
+
+ if (class_val != nullptr && spam_val != nullptr) {
+ msg_warn_config("statfile %s has both 'class' and 'spam' fields, using 'class' field",
+ st->symbol);
+ }
+
+ if (class_val == nullptr && spam_val == nullptr) {
+ /* Neither field present, try to guess by symbol name */
msg_info_config(
- "statfile %s has no explicit 'spam' setting, trying to guess by symbol",
+ "statfile %s has no explicit 'class' or 'spam' setting, trying to guess by symbol",
st->symbol);
if (rspamd_substring_search_caseless(st->symbol,
strlen(st->symbol), "spam", 4) != -1) {
st->is_spam = TRUE;
+ st->class_name = rspamd_mempool_strdup(pool, "spam");
+ st->is_spam_converted = TRUE;
}
else if (rspamd_substring_search_caseless(st->symbol,
strlen(st->symbol), "ham", 3) != -1) {
st->is_spam = FALSE;
+ st->class_name = rspamd_mempool_strdup(pool, "ham");
+ st->is_spam_converted = TRUE;
}
else {
g_set_error(err,
CFG_RCL_ERROR,
EINVAL,
- "cannot guess spam setting from %s",
+ "cannot guess class setting from %s, please specify 'class' field",
st->symbol);
return FALSE;
}
- msg_info_config("guessed that statfile with symbol %s is %s",
- st->symbol,
- st->is_spam ? "spam" : "ham");
+ msg_info_config("guessed that statfile with symbol %s has class '%s'",
+ st->symbol, st->class_name);
+ }
+ else if (class_val == nullptr && spam_val != nullptr) {
+ /* Only spam field present - migrate to class */
+ msg_warn_config("statfile %s uses deprecated 'spam' field, please use 'class' instead",
+ st->symbol);
+ if (st->is_spam) {
+ st->class_name = rspamd_mempool_strdup(pool, "spam");
+ }
+ else {
+ st->class_name = rspamd_mempool_strdup(pool, "ham");
+ }
+ st->is_spam_converted = TRUE;
}
+ else if (class_val != nullptr && spam_val == nullptr) {
+ /* Only class field present - set is_spam for backward compatibility */
+ if (st->class_name != nullptr) {
+ if (strcmp(st->class_name, "spam") == 0) {
+ st->is_spam = TRUE;
+ }
+ else if (strcmp(st->class_name, "ham") == 0) {
+ st->is_spam = FALSE;
+ }
+ else {
+ /* For non-binary classes, default to not spam */
+ st->is_spam = FALSE;
+ }
+ msg_debug_config("statfile %s with class '%s' set is_spam=%s for compatibility",
+ st->symbol, st->class_name, st->is_spam ? "true" : "false");
+ }
+ }
+ /* If both fields are present, class takes precedence and was already parsed by the default parser */
return TRUE;
}
@@ -1229,6 +1271,31 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj,
}
+/*
+ * RCL section handler for the `class_labels` object of a classifier.
+ * `ud` is treated as the classifier configuration being filled in; the
+ * parsed mapping is stored in its `class_labels` table.
+ * Returns FALSE and sets `err` if `obj` is not a UCL object or if
+ * rspamd_config_parse_class_labels() rejects its content.
+ */
static gboolean
+rspamd_rcl_class_labels_handler(rspamd_mempool_t *pool,
+ const ucl_object_t *obj,
+ const char *key,
+ gpointer ud,
+ struct rspamd_rcl_section *section,
+ GError **err)
+{
+ auto *ccf = static_cast<rspamd_classifier_config *>(ud);
+
+ /* Only an object (class name -> label pairs) is accepted here */
+ if (obj->type != UCL_OBJECT) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "class_labels must be an object");
+ return FALSE;
+ }
+
+ /* Validation of names and labels is delegated to the shared parser */
+ if (!rspamd_config_parse_class_labels(obj, &ccf->class_labels)) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "invalid class_labels configuration");
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static gboolean
rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
const ucl_object_t *obj,
const char *key,
@@ -1301,6 +1368,22 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
}
}
}
+ else if (g_ascii_strcasecmp(st_key, "class_labels") == 0) {
+ /* Parse class_labels configuration directly */
+ if (ucl_object_type(val) != UCL_OBJECT) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "class_labels must be an object");
+ ucl_object_iterate_free(it);
+ return FALSE;
+ }
+
+ if (!rspamd_config_parse_class_labels(val, &ccf->class_labels)) {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "invalid class_labels configuration");
+ ucl_object_iterate_free(it);
+ return FALSE;
+ }
+ }
}
}
@@ -1375,8 +1458,80 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool,
}
ccf->opts = (ucl_object_t *) obj;
+
+ /* Validate multi-class configuration */
+ GError *validation_err = nullptr;
+ if (!rspamd_config_validate_class_config(ccf, &validation_err)) {
+ if (validation_err) {
+ g_propagate_error(err, validation_err);
+ }
+ else {
+ g_set_error(err, CFG_RCL_ERROR, EINVAL,
+ "multi-class configuration validation failed for classifier '%s'",
+ ccf->name ? ccf->name : "unknown");
+ }
+ return FALSE;
+ }
+
cfg->classifiers = g_list_prepend(cfg->classifiers, ccf);
+ /* Populate class_names array from statfiles - only for explicit multiclass configs */
+ if (ccf->statfiles) {
+ GList *cur = ccf->statfiles;
+ gboolean has_explicit_classes = FALSE;
+
+ /* Check if any statfile uses explicit class declaration (not converted from is_spam) */
+ cur = ccf->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ msg_debug("checking statfile %s: class_name=%s, is_spam_converted=%s",
+ stcf->symbol, stcf->class_name ? stcf->class_name : "NULL",
+ stcf->is_spam_converted ? "true" : "false");
+ if (stcf->class_name && !stcf->is_spam_converted) {
+ has_explicit_classes = TRUE;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+
+ msg_debug("has_explicit_classes = %s", has_explicit_classes ? "true" : "false");
+
+ /* Only populate class_names for explicit multiclass configurations */
+ if (has_explicit_classes) {
+ msg_debug("populating class_names for multiclass configuration");
+ }
+ else {
+ msg_debug("skipping class_names population for binary configuration");
+ }
+
+ if (has_explicit_classes) {
+ ccf->class_names = g_ptr_array_new();
+
+ cur = ccf->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name) {
+ /* Check if class already exists */
+ bool found = false;
+ for (unsigned int i = 0; i < ccf->class_names->len; i++) {
+ if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) {
+ stcf->class_index = i; /* Store the index for O(1) lookup */
+ found = true;
+ break;
+ }
+ }
+
+ if (!found) {
+ /* Add new class */
+ stcf->class_index = ccf->class_names->len;
+ g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+ }
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ }
+
return TRUE;
}
@@ -2457,7 +2612,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
FALSE,
TRUE,
cfg->doc_strings,
- "CLassifier options");
+ "Classifier options");
/* Default classifier is 'bayes' for now */
sub->default_key = "bayes";
@@ -2476,7 +2631,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
rspamd_rcl_add_default_handler(sub,
"min_prob_strength",
rspamd_rcl_parse_struct_double,
- G_STRUCT_OFFSET(struct rspamd_classifier_config, min_token_hits),
+ G_STRUCT_OFFSET(struct rspamd_classifier_config, min_prob_strength),
0,
"Use only tokens with probability in [0.5 - MPS, 0.5 + MPS]");
rspamd_rcl_add_default_handler(sub,
@@ -2505,6 +2660,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
"Name of classifier");
/*
+ * Multi-class configuration
+ */
+ rspamd_rcl_add_section_doc(&top, sub,
+ "class_labels", nullptr,
+ rspamd_rcl_class_labels_handler,
+ UCL_OBJECT,
+ FALSE,
+ TRUE,
+ sub->doc_ref,
+ "Class to backend label mapping for multi-class classification");
+
+ /*
* Statfile defaults
*/
auto *ssub = rspamd_rcl_add_section_doc(&top, sub,
@@ -2522,11 +2689,17 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections)
0,
"Statfile unique label");
rspamd_rcl_add_default_handler(ssub,
+ "class",
+ rspamd_rcl_parse_struct_string,
+ G_STRUCT_OFFSET(struct rspamd_statfile_config, class_name),
+ 0,
+ "Class name for multi-class classification");
+ rspamd_rcl_add_default_handler(ssub,
"spam",
rspamd_rcl_parse_struct_boolean,
G_STRUCT_OFFSET(struct rspamd_statfile_config, is_spam),
0,
- "Sets if this statfile contains spam samples");
+ "DEPRECATED: Sets if this statfile contains spam samples (use 'class' instead)");
}
if (!(skip_sections && g_hash_table_lookup(skip_sections, "composite"))) {
diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx
index c7bb20210..c22a9b877 100644
--- a/src/libserver/cfg_utils.cxx
+++ b/src/libserver/cfg_utils.cxx
@@ -3042,3 +3042,189 @@ rspamd_ip_is_local_cfg(struct rspamd_config *cfg,
return FALSE;
}
+
+/**
+ * Parse a `class_labels` UCL object (class name -> backend label) into a hash table.
+ *
+ * Class names must be at most 32 characters long and consist of ASCII
+ * alphanumerics and underscores; a backend label may be used by one class only.
+ * Entries whose value is not a string are ignored with a warning.
+ *
+ * @param obj UCL object: keys are class names, values are string labels
+ * @param class_labels in/out table, created on demand with owned copies of
+ *        keys and values; on a validation failure the table is destroyed and
+ *        *class_labels is reset to nullptr
+ * @return TRUE if the table holds at least one mapping after parsing
+ */
+gboolean
+rspamd_config_parse_class_labels(const ucl_object_t *obj, GHashTable **class_labels)
+{
+	const ucl_object_t *cur;
+	ucl_object_iter_t it = nullptr;
+
+	if (!obj || !class_labels || ucl_object_type(obj) != UCL_OBJECT) {
+		return FALSE;
+	}
+
+	if (*class_labels == nullptr) {
+		*class_labels = g_hash_table_new_full(g_str_hash, g_str_equal,
+											  g_free, g_free);
+	}
+
+	/* Common failure path: a partially filled table must not be left behind */
+	auto fail = [class_labels]() -> gboolean {
+		g_hash_table_destroy(*class_labels);
+		*class_labels = nullptr;
+		return FALSE;
+	};
+
+	while ((cur = ucl_object_iterate(obj, &it, true)) != nullptr) {
+		const char *class_name = ucl_object_key(cur);
+		const char *label = ucl_object_tostring(cur);
+
+		if (class_name == nullptr || label == nullptr) {
+			/* Never insert NULL keys/values: g_str_hash() cannot handle them */
+			msg_warn("class_labels entry '%s' has no string label, ignoring it",
+					 class_name ? class_name : "(null)");
+			continue;
+		}
+
+		/* Validate class name: alphanumeric + underscore, max 32 chars */
+		if (strlen(class_name) > 32) {
+			msg_err("class name '%s' is too long (max 32 characters)", class_name);
+			return fail();
+		}
+
+		for (const char *p = class_name; *p; p++) {
+			if (!g_ascii_isalnum(*p) && *p != '_') {
+				msg_err("class name '%s' contains invalid character '%c'", class_name, *p);
+				return fail();
+			}
+		}
+
+		/*
+		 * Validate label uniqueness. The table is keyed by class name, so the
+		 * labels are its values and have to be scanned. The entry of the same
+		 * class is skipped: the same object may be parsed into this table more
+		 * than once, which merely replaces that entry.
+		 */
+		GHashTableIter known_iter;
+		gpointer known_class, known_label;
+
+		g_hash_table_iter_init(&known_iter, *class_labels);
+		while (g_hash_table_iter_next(&known_iter, &known_class, &known_label)) {
+			if (known_label != nullptr &&
+				strcmp((const char *) known_label, label) == 0 &&
+				strcmp((const char *) known_class, class_name) != 0) {
+				msg_err("backend label '%s' is used by multiple classes", label);
+				return fail();
+			}
+		}
+
+		g_hash_table_insert(*class_labels, g_strdup(class_name), g_strdup(label));
+	}
+
+	return g_hash_table_size(*class_labels) > 0;
+}
+
+/**
+ * Give a legacy binary statfile a class name derived from its `is_spam` flag.
+ *
+ * Statfiles that already have a class name are left untouched. Otherwise the
+ * class becomes "spam" or "ham" and the statfile is marked as converted, so
+ * that it is not mistaken for an explicit multi-class declaration.
+ *
+ * @param stcf statfile configuration to migrate (must not be NULL)
+ * @return always TRUE
+ */
+gboolean
+rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf)
+{
+	if (stcf->class_name != nullptr) {
+		/* Already migrated or using new format */
+		return TRUE;
+	}
+
+	const char *derived_class = stcf->is_spam ? "spam" : "ham";
+
+	stcf->class_name = g_strdup(derived_class);
+	/* The class was not written by the user: keep the binary code paths */
+	stcf->is_spam_converted = TRUE;
+
+	msg_info("migrated statfile '%s' from is_spam=%s to class='%s'",
+			 stcf->symbol ? stcf->symbol : "unknown",
+			 stcf->is_spam ? "true" : "false",
+			 derived_class);
+
+	return TRUE;
+}
+
+/**
+ * Validate the class layout of a classifier and rebuild its class list.
+ *
+ * Every statfile is migrated from the legacy `is_spam` flag if it has no class
+ * name, and must end up with a non-empty class. At least two distinct classes
+ * are required. `ccf->class_names` is rebuilt (in statfile order, one entry
+ * per distinct class) only if some statfile declares its class explicitly;
+ * for purely binary configurations it is released and left as NULL.
+ *
+ * @param ccf classifier configuration to validate
+ * @param err set on failure
+ * @return TRUE if the configuration is valid
+ */
+gboolean
+rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError **err)
+{
+	if (!ccf || !ccf->statfiles) {
+		g_set_error(err, g_quark_from_static_string("config"), 1,
+					"classifier has no statfiles defined");
+		return FALSE;
+	}
+
+	/* Set of the distinct class names seen so far (keys are owned copies) */
+	GHashTable *seen_classes = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, nullptr);
+	unsigned int class_count = 0;
+	gboolean has_explicit_classes = FALSE;
+
+	/* Iterate through statfiles and collect classes */
+	for (GList *cur = ccf->statfiles; cur != nullptr; cur = g_list_next(cur)) {
+		auto *stcf = (struct rspamd_statfile_config *) cur->data;
+
+		/* Migrate binary config if needed */
+		if (!rspamd_config_migrate_binary_config(stcf)) {
+			g_set_error(err, g_quark_from_static_string("config"), 1,
+						"failed to migrate binary config for statfile '%s'",
+						stcf->symbol ? stcf->symbol : "unknown");
+			g_hash_table_destroy(seen_classes);
+			return FALSE;
+		}
+
+		/* Check class name */
+		if (!stcf->class_name || strlen(stcf->class_name) == 0) {
+			g_set_error(err, g_quark_from_static_string("config"), 1,
+						"statfile '%s' has no class defined",
+						stcf->symbol ? stcf->symbol : "unknown");
+			g_hash_table_destroy(seen_classes);
+			return FALSE;
+		}
+
+		/* A class that was not converted from is_spam was declared explicitly */
+		if (!stcf->is_spam_converted) {
+			has_explicit_classes = TRUE;
+		}
+
+		/* Track unique classes */
+		if (!g_hash_table_contains(seen_classes, stcf->class_name)) {
+			g_hash_table_insert(seen_classes, g_strdup(stcf->class_name), GINT_TO_POINTER(1));
+			class_count++;
+		}
+	}
+
+	/* Validate class count */
+	if (class_count < 2) {
+		/* g_set_error() takes GLib printf formats: plain %u, not rspamd's %ud */
+		g_set_error(err, g_quark_from_static_string("config"), 1,
+					"classifier must have at least 2 classes, found %u", class_count);
+		g_hash_table_destroy(seen_classes);
+		return FALSE;
+	}
+
+	if (class_count > 20) {
+		/* rspamd logger format: %ud denotes an unsigned int here */
+		msg_warn("classifier has %ud classes, performance may be degraded above 20 classes",
+				 class_count);
+	}
+
+	/* Any previous class list is stale at this point */
+	if (ccf->class_names) {
+		g_ptr_array_unref(ccf->class_names);
+		ccf->class_names = nullptr;
+	}
+
+	/* Only populate class_names for explicit multiclass configurations */
+	if (has_explicit_classes) {
+		ccf->class_names = g_ptr_array_new_with_free_func(g_free);
+
+		/*
+		 * Emit classes in statfile order so that the resulting order is
+		 * deterministic. Removing a name from the set on its first occurrence
+		 * makes later statfiles of the same class a no-op.
+		 */
+		for (GList *cur = ccf->statfiles; cur != nullptr; cur = g_list_next(cur)) {
+			auto *stcf = (struct rspamd_statfile_config *) cur->data;
+
+			if (g_hash_table_remove(seen_classes, stcf->class_name)) {
+				g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name));
+			}
+		}
+	}
+
+	g_hash_table_destroy(seen_classes);
+	return TRUE;
+}
+
+/**
+ * Look up the backend label configured for a class of a classifier.
+ *
+ * @param ccf classifier configuration (may be NULL)
+ * @param class_name class to look up (may be NULL)
+ * @return the label stored in `ccf->class_labels` for `class_name`, or
+ *         nullptr if any argument or the table is missing or no label is set
+ */
+const char *
+rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, const char *class_name)
+{
+	if (ccf != nullptr && ccf->class_labels != nullptr && class_name != nullptr) {
+		gpointer found = g_hash_table_lookup(ccf->class_labels, class_name);
+
+		return static_cast<const char *>(found);
+	}
+
+	return nullptr;
+}
diff --git a/src/libserver/task.c b/src/libserver/task.c
index 9f5b1f00a..f655ab11b 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) &&
!RSPAMD_TASK_IS_EMPTY(task) &&
- !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) {
+ !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) {
rspamd_stat_check_autolearn(task);
}
break;
@@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
case RSPAMD_TASK_STAGE_LEARN:
case RSPAMD_TASK_STAGE_LEARN_PRE:
case RSPAMD_TASK_STAGE_LEARN_POST:
- if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) {
+ if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) {
if (task->err == NULL) {
- if (!rspamd_stat_learn(task,
- task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
- task->cfg->lua_state, task->classifier,
- st, &stat_error)) {
+ gboolean learn_result = FALSE;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ /* Multi-class learning */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ learn_result = rspamd_stat_learn_class(task, autolearn_class,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+ else {
+ g_set_error(&stat_error, g_quark_from_static_string("stat"), 500,
+ "No autolearn class specified for multi-class learning");
+ }
+ }
+ else {
+ /* Legacy binary learning */
+ learn_result = rspamd_stat_learn(task,
+ task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+
+ if (!learn_result) {
if (stat_error == NULL) {
g_set_error(&stat_error,
@@ -922,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task,
const char *classifier,
GError **err)
{
+ /* Use unified class-based approach internally */
+ const char *class_name = is_spam ? "spam" : "ham";
+
/* Disable learn auto flag to avoid bad learn codes */
task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO;
- if (is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- }
+ /* Use the unified class-based learning approach */
+ rspamd_task_set_autolearn_class(task, class_name);
task->classifier = classifier;
diff --git a/src/libserver/task.h b/src/libserver/task.h
index 1c1778fee..a1742e160 100644
--- a/src/libserver/task.h
+++ b/src/libserver/task.h
@@ -104,9 +104,9 @@ enum rspamd_task_stage {
#define RSPAMD_TASK_FLAG_LEARN_SPAM (1u << 12u)
#define RSPAMD_TASK_FLAG_LEARN_HAM (1u << 13u)
#define RSPAMD_TASK_FLAG_LEARN_AUTO (1u << 14u)
+#define RSPAMD_TASK_FLAG_LEARN_CLASS (1u << 25u)
#define RSPAMD_TASK_FLAG_BROKEN_HEADERS (1u << 15u)
-#define RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS (1u << 16u)
-#define RSPAMD_TASK_FLAG_HAS_HAM_TOKENS (1u << 17u)
+/* Removed RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS and RSPAMD_TASK_FLAG_HAS_HAM_TOKENS - not needed in multi-class */
#define RSPAMD_TASK_FLAG_EMPTY (1u << 18u)
#define RSPAMD_TASK_FLAG_PROFILE (1u << 19u)
#define RSPAMD_TASK_FLAG_GREYLISTED (1u << 20u)
@@ -114,7 +114,7 @@ enum rspamd_task_stage {
#define RSPAMD_TASK_FLAG_SSL (1u << 22u)
#define RSPAMD_TASK_FLAG_BAD_UNICODE (1u << 23u)
#define RSPAMD_TASK_FLAG_MESSAGE_REWRITE (1u << 24u)
-#define RSPAMD_TASK_FLAG_MAX_SHIFT (24u)
+#define RSPAMD_TASK_FLAG_MAX_SHIFT (25u)
/* Request has been done by a local client */
#define RSPAMD_TASK_PROTOCOL_FLAG_LOCAL_CLIENT (1u << 1u)
diff --git a/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md b/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md
new file mode 100644
index 000000000..dc8352374
--- /dev/null
+++ b/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md
@@ -0,0 +1,451 @@
+# Rspamd Multiclass Bayes Architecture
+
+## Overview
+
+This document describes the complete data flow for the multiclass Bayes classification system in Rspamd, covering the interaction between C++ core, Lua scripts, Redis backend, and the classification pipeline.
+
+## High-Level Data Flow
+
+```
+[Task Processing] → [Tokenization] → [Redis Backend] → [Lua Scripts] → [Redis Scripts] → [Results] → [Classification]
+```
+
+## 1. Classification Pipeline Entry Point
+
+### 1.1 Task Processing Start
+
+```c
+// src/libstat/stat_process.c
+rspamd_stat_classify(struct rspamd_task *task, struct rspamd_config *cfg)
+```
+
+**Flow:**
+
+1. Task arrives for classification
+2. Iterates through configured classifiers
+3. For each classifier, calls `rspamd_stat_classifiers[i].classify_func()`
+4. For Bayes: calls `bayes_classify_multiclass()`
+
+### 1.2 Bayes Classification Entry
+
+```c
+// src/libstat/classifiers/bayes.c
+gboolean bayes_classify_multiclass(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+```
+
+**Key Steps:**
+
+1. Validates `ctx->cfg->class_names` array
+2. Sets up `bayes_task_closure` with class information
+3. **Calls Redis backend to fetch token data**
+4. Processes returned token values
+5. Calculates probabilities and inserts symbols
+
+## 2. Redis Backend Data Flow
+
+### 2.1 Backend Runtime Creation
+
+```cpp
+// src/libstat/backends/redis_backend.cxx
+gpointer rspamd_redis_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn, gpointer c, int _id)
+```
+
+**Runtime Structure:**
+
+```cpp
+template<class T>
+class redis_stat_runtime {
+ struct redis_stat_ctx *ctx; // Redis connection context
+ struct rspamd_task *task; // Current task
+ struct rspamd_statfile_config *stcf; // Statfile configuration
+ const char *redis_object_expanded; // Expanded key prefix
+ int id; // Statfile ID (critical!)
+ std::optional<std::map<int, T>> results; // Token index → value mapping
+};
+```
+
+**Critical Insight: Statfile ID Mapping**
+
+- Each statfile has a unique ID (`id`)
+- Token values are stored in `tok->values[id]` array
+- **The `id` must match exactly between runtime and statfile**
+
+### 2.2 Multiple Runtime Creation (Classification Mode)
+
+For multiclass classification, the system creates multiple runtimes:
+
+```cpp
+// For each statfile in classifier
+for (cur = stcf->clcf->statfiles; cur; cur = g_list_next(cur)) {
+ auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
+
+ // Find correct statfile ID
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ int other_id = -1;
+ for (i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *st = g_ptr_array_index(st_ctx->statfiles, i);
+ if (st->stcf == other_stcf) {
+ other_id = st->id; // ← This is the critical mapping!
+ break;
+ }
+ }
+
+ // Create runtime with correct ID
+ auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ other_rt->id = other_id; // ← Must be set correctly!
+}
+```
+
+### 2.3 Token Processing Call
+
+```cpp
+gboolean rspamd_redis_process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ int id, gpointer p)
+```
+
+**Flow:**
+
+1. Serializes tokens to MessagePack format
+2. Builds class labels string (e.g., "TABLE:H,S,N,T")
+3. Calls Lua function to execute Redis script
+4. Registers callback for async result processing
+
+## 3. Lua Script Layer
+
+### 3.1 Lua Function Entry Point
+
+```lua
+-- lualib/lua_bayes_redis.lua
+local function gen_classify_functor(redis_params, classify_script_id)
+ return function(task, expanded_key, id, class_labels, stat_tokens, callback)
+ -- Executes Redis script via lua_redis
+ lua_redis.exec_redis_script(classify_script_id,
+ { task = task, is_write = false, key = expanded_key },
+ classify_redis_cb,
+ { expanded_key, class_labels, stat_tokens })
+ end
+end
+```
+
+**Key Components:**
+
+- `expanded_key`: Redis key prefix (e.g., "BAYES{user@domain}")
+- `class_labels`: "TABLE:H,S,N,T" format for multiclass
+- `stat_tokens`: MessagePack-encoded token array
+- `callback`: Function to handle Redis script results
+
+### 3.2 Class Labels Format
+
+**Critical Detail**: The class labels format determines Redis script behavior:
+
+```lua
+-- Binary mode (legacy)
+class_labels = "H" -- Single class
+
+-- Multiclass mode
+class_labels = "TABLE:H,S,N,T" -- Multiple classes with TABLE: prefix
+```
+
+## 4. Redis Script Execution
+
+### 4.1 Script Structure
+
+```lua
+-- lualib/redis_scripts/bayes_classify.lua
+local prefix = KEYS[1] -- "BAYES{user@domain}"
+local class_labels_arg = KEYS[2] -- "TABLE:H,S,N,T"
+local input_tokens = cmsgpack.unpack(KEYS[3]) -- [tok1, tok2, ...]
+```
+
+### 4.2 Class Label Parsing
+
+```lua
+local class_labels = {}
+if string.match(class_labels_arg, "^TABLE:") then
+ -- Multiclass mode
+ local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:"
+ for label in string.gmatch(labels_str, "([^,]+)") do
+ table.insert(class_labels, label) -- ["H", "S", "N", "T"]
+ end
+else
+ -- Binary mode (single label)
+ table.insert(class_labels, class_labels_arg)
+end
+```
+
+### 4.3 Redis Key Structure
+
+**Learning Counts:**
+
+```
+BAYES{user@domain}_H_learns → { learns: 1500 }
+BAYES{user@domain}_S_learns → { learns: 800 }
+BAYES{user@domain}_N_learns → { learns: 200 }
+BAYES{user@domain}_T_learns → { learns: 150 }
+```
+
+**Token Counts:**
+
+```
+BAYES{user@domain}_H_tokens → { token1: 45, token2: 12, ... }
+BAYES{user@domain}_S_tokens → { token1: 23, token2: 67, ... }
+BAYES{user@domain}_N_tokens → { token1: 5, token2: 8, ... }
+BAYES{user@domain}_T_tokens → { token1: 2, token2: 3, ... }
+```
+
+### 4.4 Token Lookup Process
+
+```lua
+-- Get learning counts for each class
+local learned_counts = {}
+for i, class_label in ipairs(class_labels) do
+ local learns_key = prefix .. "_" .. class_label .. "_learns"
+ learned_counts[i] = tonumber(redis.call('HGET', learns_key, 'learns') or '0')
+end
+
+-- Prepare one result list per class: token_data[class_index] = { {token_index, count}, ... }
+local token_data = {}
+for j, class_label in ipairs(class_labels) do
+    token_data[j] = {}
+end
+
+-- Token lookup for all classes. A Redis script already runs atomically and
+-- redis.call() returns each reply immediately, so no MULTI/EXEC is involved.
+for i, token in ipairs(input_tokens) do
+    for j, class_label in ipairs(class_labels) do
+        local token_key = prefix .. "_" .. class_label .. "_tokens"
+        local count = tonumber(redis.call('HGET', token_key, token)) or 0
+        if count > 0 then
+            table.insert(token_data[j], {i, count}) -- {token_index, count}
+        end
+    end
+end
+
+-- Return: [learned_counts, token_data]
+return {learned_counts, token_data}
+```
+
+### 4.5 Return Format
+
+**Redis Script Returns:**
+
+```lua
+{
+ [1] = {1500, 800, 200, 150}, -- learned_counts per class
+ [2] = { -- token_data per class
+ [1] = {{1,45}, {2,12}, ...}, -- Class H tokens: {token_idx, count}
+ [2] = {{1,23}, {2,67}, ...}, -- Class S tokens
+ [3] = {{1,5}, {2,8}, ...}, -- Class N tokens
+ [4] = {{1,2}, {2,3}, ...} -- Class T tokens
+ }
+}
+```
+
+## 5. Result Processing in C++
+
+### 5.1 Redis Callback Handler
+
+```cpp
+// src/libstat/backends/redis_backend.cxx
+static int rspamd_redis_classified(lua_State *L)
+{
+ auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+ bool result = lua_toboolean(L, 2);
+
+ if (result && lua_istable(L, 3)) {
+ // Process learned_counts (table index 1)
+ lua_rawgeti(L, 3, 1);
+ if (lua_istable(L, -1)) {
+ // Store learned counts (implementation detail)
+ }
+ lua_pop(L, 1);
+
+ // Process token_results (table index 2)
+ lua_rawgeti(L, 3, 2);
+ if (lua_istable(L, -1)) {
+ process_multiclass_token_results(L, rt, task);
+ }
+ lua_pop(L, 1);
+ }
+}
+```
+
+### 5.2 Token Results Processing
+
+```cpp
+static void process_multiclass_token_results(lua_State *L,
+ redis_stat_runtime<float> *rt,
+ struct rspamd_task *task)
+{
+ // L stack: token_results table at top
+ // Format: {[1] = {{1,45}, {2,12}}, [2] = {{1,23}, {2,67}}, ...}
+
+ if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+ GList *cur = rt->stcf->clcf->statfiles;
+ int class_idx = 1;
+
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *)cur->data;
+
+ // Find correct statfile ID
+ int statfile_id = find_statfile_id_for_config(stcf);
+
+ // Get or create runtime for this statfile
+ auto maybe_statfile_rt = get_runtime_for_statfile(task, stcf, statfile_id);
+ if (maybe_statfile_rt) {
+ auto *statfile_rt = maybe_statfile_rt.value();
+
+ // Get token data for this class (class_idx)
+ lua_rawgeti(L, -1, class_idx);
+ if (lua_istable(L, -1)) {
+ parse_class_token_data(L, statfile_rt);
+ }
+ lua_pop(L, 1);
+ }
+
+ cur = g_list_next(cur);
+ class_idx++;
+ }
+ }
+}
+```
+
+### 5.3 Token Value Assignment
+
+```cpp
+bool redis_stat_runtime<T>::process_tokens(GPtrArray *tokens) const
+{
+ rspamd_token_t *tok;
+
+ if (!results) {
+ return false;
+ }
+
+ // results maps: token_index → token_count
+ for (auto [token_idx, token_count] : *results) {
+ tok = (rspamd_token_t *) g_ptr_array_index(tokens, token_idx - 1);
+
+ // CRITICAL: Set tok->values[id] where id is the statfile ID
+ tok->values[id] = token_count;
+ }
+
+ return true;
+}
+```
+
+## 6. Classification Algorithm Execution
+
+### 6.1 Multiclass Processing
+
+```c
+// src/libstat/classifiers/bayes.c
+gboolean bayes_classify_multiclass(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ struct bayes_task_closure cl;
+
+ // Initialize with class information from config
+ cl.num_classes = ctx->cfg->class_names->len;
+ cl.class_names = (char**)ctx->cfg->class_names->pdata;
+
+ // Process all tokens
+ for (i = 0; i < tokens->len; i++) {
+ rspamd_token_t *tok = g_ptr_array_index(tokens, i);
+ bayes_classify_token_multiclass(ctx, tok, &cl);
+ }
+}
+```
+
+### 6.2 Token Classification
+
+```c
+static void bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
+ rspamd_token_t *tok,
+ struct bayes_task_closure *cl)
+{
+ // For each statfile, check if it has data for this token
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ int id = g_array_index(ctx->statfiles_ids, int, i);
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+ // CRITICAL: tok->values[id] must be set by Redis backend
+ double val = tok->values[id];
+
+ if (val > 0) {
+ // Find which class this statfile belongs to
+ for (j = 0; j < cl->num_classes; j++) {
+ if (strcmp(st->stcf->class_name, cl->class_names[j]) == 0) {
+ // Accumulate token evidence for this class
+ process_token_for_class(cl, j, val, st);
+ break;
+ }
+ }
+ }
+ }
+}
+```
+
+## 7. Critical Data Mapping
+
+### 7.1 Statfile ID Assignment
+
+**The Core Problem**: Ensuring correct mapping between:
+
+1. **Redis script class order**: `[H, S, N, T]` (array indices 1,2,3,4)
+2. **Statfile IDs**: Global statfile IDs assigned by `rspamd_stat_get_ctx()`
+3. **Runtime IDs**: Must match statfile IDs for `tok->values[id]` assignment
+
+### 7.2 Configuration to Runtime Mapping
+
+```c
+// Configuration defines classes
+statfile "BAYES_HAM" { class = "ham"; symbol = "BAYES_HAM"; } // Gets ID=0
+statfile "BAYES_SPAM" { class = "spam"; symbol = "BAYES_SPAM"; } // Gets ID=1
+statfile "BAYES_NEWS" { class = "news"; symbol = "BAYES_NEWS"; } // Gets ID=2
+
+// Redis backend maps: class_name → backend_label
+class_labels = {
+ "ham" = "H"; // Maps to Redis "H"
+ "spam" = "S"; // Maps to Redis "S"
+ "news" = "N"; // Maps to Redis "N"
+}
+
+// Redis script processes in label order: ["H", "S", "N"]
+// Returns data in same order: [ham_data, spam_data, news_data]
+
+// C++ must map:
+// redis_result[1] → statfile_id=0 (ham)    (Lua tables are 1-based)
+// redis_result[2] → statfile_id=1 (spam)
+// redis_result[3] → statfile_id=2 (news)
+```
+
+### 7.3 Token Array Structure
+
+```c
+// For each token in message
+struct rspamd_token {
+ uint64_t data; // Token hash
+ float values[MAX_STATFILES]; // Values per statfile ID
+ // ...
+};
+
+// After Redis processing:
+// tok->values[0] = ham_count (from redis_result[1], Lua tables are 1-based)
+// tok->values[1] = spam_count (from redis_result[2])
+// tok->values[2] = news_count (from redis_result[3])
+```
diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx
index 0f55a725c..f6ca9c12d 100644
--- a/src/libstat/backends/cdb_backend.cxx
+++ b/src/libstat/backends/cdb_backend.cxx
@@ -393,7 +393,6 @@ rspamd_cdb_process_tokens(struct rspamd_task *task,
gpointer runtime)
{
auto *cdbp = CDB_FROM_RAW(runtime);
- bool seen_values = false;
for (auto i = 0u; i < tokens->len; i++) {
rspamd_token_t *tok;
@@ -403,21 +402,13 @@ rspamd_cdb_process_tokens(struct rspamd_task *task,
if (res) {
tok->values[id] = res.value();
- seen_values = true;
}
else {
tok->values[id] = 0;
}
}
- if (seen_values) {
- if (cdbp->is_spam()) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
return true;
}
@@ -488,4 +479,4 @@ void rspamd_cdb_close(gpointer ctx)
{
auto *cdbp = CDB_FROM_RAW(ctx);
delete cdbp;
-} \ No newline at end of file
+}
diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c
index 4430bb9a4..a6423a1e6 100644
--- a/src/libstat/backends/mmaped_file.c
+++ b/src/libstat/backends/mmaped_file.c
@@ -85,8 +85,7 @@ typedef struct {
#define RSPAMD_STATFILE_VERSION \
{ \
- '1', '2' \
- }
+ '1', '2'}
#define BACKUP_SUFFIX ".old"
static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
@@ -958,12 +957,7 @@ rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens,
tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2);
}
- if (mf->cf->is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
return TRUE;
}
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 7137904e9..302778bcb 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -22,6 +22,7 @@
#include "contrib/fmt/include/fmt/base.h"
#include "libutil/cxx/error.hxx"
+#include <map>
#include <string>
#include <cstdint>
@@ -121,9 +122,9 @@ public:
}
static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
- bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+ const char *class_label) -> std::optional<redis_stat_runtime<T> *>
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
if (res) {
@@ -147,9 +148,15 @@ public:
rspamd_token_t *tok;
if (!results) {
+ msg_debug_bayes("process_tokens: no results available for statfile id=%d", id);
return false;
}
+ if (results->size() > 0) {
+ msg_debug_bayes("processing %uz tokens for statfile id=%d, class=%s",
+ results->size(), id, stcf->class_name ? stcf->class_name : "unknown");
+ }
+
for (auto [idx, val]: *results) {
tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
tok->values[id] = val;
@@ -158,12 +165,14 @@ public:
return true;
}
- auto save_in_mempool(bool is_spam) const
+ auto save_in_mempool(const char *class_label) const
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name =
+ rspamd_mempool_strdup(task->task_pool,
+ fmt::format("{}_{}", redis_object_expanded, class_label).c_str()); /* name is duplicated into the task pool - NOTE(review): presumably so it outlives the temporary fmt string; confirm against rspamd_mempool_set_variable */
/* Register this runtime in the task pool under "<redis_object_expanded>_<class_label>", the same name maybe_recover_from_mempool() builds for its lookup. We do not set destructor for the variable, as it should be already added on creation */
- rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
- msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
+ rspamd_mempool_set_variable(task->task_pool, var_name, (gpointer) this, nullptr);
+ msg_debug_bayes("saved runtime in mempool at %s", var_name);
}
};
@@ -178,6 +187,26 @@ rspamd_redis_stat_quark(void)
}
/*
+ * Get the class label for a statfile (for multi-class support).
+ *
+ * The label is what this backend uses to tell classes apart: it is the
+ * suffix of the mempool variable name a runtime is stored under, and it is
+ * the value pushed to the Lua init/learn scripts in place of the former
+ * is_spam boolean.
+ *
+ * @param stcf statfile configuration; dereferenced without a NULL check
+ * @return the label mapped to stcf->class_name if one exists, otherwise the
+ *         class name itself, otherwise the literal "S" or "H" chosen by
+ *         stcf->is_spam. The pointer is borrowed (it may be a string
+ *         literal), so callers must not free it.
+ */
+static const char *
+get_class_label(struct rspamd_statfile_config *stcf)
+{
+ /* Try to get the label from the classifier config first; this needs a classifier config, a class_labels table and a class name */
+ if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) {
+ const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name);
+ if (label) {
+ return label;
+ }
+ /* If no label mapping found, use class name directly */
+ return stcf->class_name;
+ }
+
+ /* Fallback to legacy binary classification: reached whenever clcf, class_labels or class_name is missing, even if the other two are set */
+ return stcf->is_spam ? "S" : "H";
+}
+
+/*
* Non-static for lua unit testing
*/
gsize rspamd_redis_expand_object(const char *pattern,
@@ -235,6 +264,11 @@ gsize rspamd_redis_expand_object(const char *pattern,
if (rcpt) {
rspamd_mempool_set_variable(task->task_pool, "stat_user",
(gpointer) rcpt, nullptr);
+ msg_debug_bayes("redis expansion: found recipient '%s'", rcpt);
+ }
+ else {
+ msg_debug_bayes("redis expansion: no recipient found (deliver_to=%s)",
+ task->deliver_to ? task->deliver_to : "null");
}
}
@@ -448,6 +482,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
"users_enabled", nullptr);
+ msg_debug_bayes_cfg("per-user lookup: users_enabled=%p", users_enabled);
if (users_enabled != nullptr) {
if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
backend->enable_users = ucl_object_toboolean(users_enabled);
@@ -485,9 +520,16 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
/* Default non-users statistics */
if (backend->enable_users || backend->cbref_user != -1) {
backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
+ msg_debug_bayes_cfg("using per-user Redis pattern: %s (enable_users=%s, cbref_user=%d)",
+ backend->redis_object, backend->enable_users ? "true" : "false",
+ backend->cbref_user);
}
else {
backend->redis_object = REDIS_DEFAULT_OBJECT;
+ msg_debug_bayes_cfg("using default Redis pattern: %s (enable_users=%s, cbref_user=%d)",
+ backend->redis_object,
+ backend->enable_users ? "true" : "false",
+ backend->cbref_user);
}
}
else {
@@ -541,7 +583,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx,
ucl_object_push_lua(L, st->classifier->cfg->opts, false);
ucl_object_push_lua(L, st->stcf->opts, false);
lua_pushstring(L, backend->stcf->symbol);
- lua_pushboolean(L, backend->stcf->is_spam);
+ lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */
/* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
if (ctx->event_loop) {
@@ -606,11 +648,20 @@ rspamd_redis_runtime(struct rspamd_task *task,
stcf->symbol);
return nullptr;
}
+ else {
+ msg_debug_bayes("redis object expanded: pattern='%s' -> expanded='%s' (learn=%s, symbol=%s)",
+ ctx->redis_object ? ctx->redis_object : "default",
+ object_expanded,
+ learn ? "true" : "false",
+ stcf->symbol);
+ }
+
+ const char *class_label = get_class_label(stcf);
/* Look for the cached results */
if (!learn) {
auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, stcf->is_spam);
+ object_expanded, class_label);
if (maybe_existing) {
auto *rt = maybe_existing.value();
@@ -624,24 +675,62 @@ rspamd_redis_runtime(struct rspamd_task *task,
/* No cached result (or learn), create new one */
auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- if (!learn) {
- /*
- * For check, we also need to create the opposite class runtime to avoid
- * double call for Redis scripts.
- * This runtime will be filled later.
- */
- auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded,
- !stcf->is_spam);
-
- if (!maybe_opposite_rt) {
- auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- opposite_rt->save_in_mempool(!stcf->is_spam);
- opposite_rt->need_redis_call = false;
+ /* The main runtime takes the statfile ID supplied by the caller; no lookup is needed here */
+ int main_id = _id; /* Use the passed _id parameter */
+ rt->id = main_id;
+ rt->stcf = stcf;
+
+ /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */
+ if (!learn && stcf->clcf && stcf->clcf->statfiles) {
+ GList *cur = stcf->clcf->statfiles;
+
+ while (cur) {
+ auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *other_label = get_class_label(other_stcf);
+
+ /* Find the statfile ID by searching through all statfiles */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ int other_id = -1;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (st->stcf == other_stcf) {
+ other_id = st->id;
+ msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d",
+ st->stcf->symbol, other_label, other_id);
+ break;
+ }
+ }
+
+ if (other_id == -1) {
+ msg_debug_bayes("statfile not found for class %s, skipping", other_label);
+ /* Skip if statfile not found */
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ if (other_stcf == stcf) {
+ /* This is the main statfile, use the main runtime */
+ rt->save_in_mempool(other_label);
+ msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d",
+ stcf->symbol, other_label, rt->id);
+ }
+ else {
+ /* Create additional runtime for other statfile */
+ auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ other_rt->id = other_id;
+ other_rt->stcf = other_stcf;
+ other_rt->need_redis_call = false;
+ other_rt->save_in_mempool(other_label);
+ msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d",
+ other_stcf->symbol, other_label, other_id);
+ }
+
+ cur = g_list_next(cur);
}
}
-
- rt->save_in_mempool(stcf->is_spam);
+ else {
+ rt->save_in_mempool(class_label);
+ }
return rt;
}
@@ -816,77 +905,306 @@ rspamd_redis_classified(lua_State *L)
if (rt == nullptr) {
msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
return 0;
}
bool result = lua_toboolean(L, 2);
if (result) {
- /* Indexes:
- * 3 - learned_ham (int)
- * 4 - learned_spam (int)
- * 5 - ham_tokens (pair<int, int>)
- * 6 - spam_tokens (pair<int, int>)
- */
-
- /*
- * We need to fill our runtime AND the opposite runtime
- */
- auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
- rt->learned = learned;
- redis_stat_runtime<float>::result_type *res;
-
- res = new redis_stat_runtime<float>::result_type();
-
- for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
- lua_rawgeti(L, -1, 1);
- auto idx = lua_tointeger(L, -1);
- lua_pop(L, 1);
-
- lua_rawgeti(L, -1, 2);
- auto value = lua_tonumber(L, -1);
- lua_pop(L, 1);
-
- res->emplace_back(idx, value);
- }
-
- rt->set_results(res);
- };
-
- auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- !rt->stcf->is_spam);
+ /* Check we have enough arguments and the result data is a table */
+ if (lua_gettop(L) < 3 || !lua_istable(L, 3)) {
+ msg_err_task("internal error: expected table result from Redis script, got %s",
+ lua_typename(L, lua_type(L, 3)));
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
+ return 0;
+ }
- if (!opposite_rt_maybe) {
- msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
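+ /* Redis returns [learned_counts_array, token_results_array].
+ * Both are read positionally below: slot k+1 belongs to class_names[k]
+ * in multi-class mode, or to the k-th entry of clcf->statfiles otherwise.
+ * NOTE(review): the binary request built in rspamd_redis_process_tokens
+ * always sends the labels as "H" then "S"; if the reply follows that
+ * order, this only lines up when the ham statfile is listed first - confirm. */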
+ size_t result_len = rspamd_lua_table_size(L, 3);
+ msg_debug_bayes("Redis result array length: %uz", result_len);
+ if (result_len != 2) {
+ msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len);
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
return 0;
}
- if (rt->stcf->is_spam) {
- filler_func(rt, L, lua_tointeger(L, 4), 6);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ /* Get learned_counts_array and token_results_array */
+ lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */
+ lua_rawgeti(L, 3, 2); /* token_results -> position 5 */
+
+ /* First, process learned_counts */
+ if (lua_istable(L, 4) && rt->stcf->clcf) {
+ if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: use class_names order */
+ for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+ /* Find statfile with this class name */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ const char *class_label = get_class_label(stcf);
+
+ /* Get the runtime for this statfile */
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+
+ /* Extract learned count using class index (1-based for Lua) */
+ lua_rawgeti(L, 4, class_idx + 1);
+ if (lua_isnumber(L, -1)) {
+ statfile_rt->learned = lua_tointeger(L, -1);
+ msg_debug_bayes("set learned count for class %s (label %s): %L",
+ class_name, class_label, statfile_rt->learned);
+ }
+ lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */
+ }
+ break; /* Found the statfile for this class */
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ }
+ else {
+ /* Binary classification: process statfiles in order */
+ GList *cur = rt->stcf->clcf->statfiles;
+ unsigned int statfile_idx = 0;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ /* Get the runtime for this statfile */
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+
+ /* Extract learned count using statfile index (1-based for Lua) */
+ lua_rawgeti(L, 4, statfile_idx + 1);
+ if (lua_isnumber(L, -1)) {
+ statfile_rt->learned = lua_tointeger(L, -1);
+ msg_debug_bayes("set learned count for statfile %s (label %s): %L",
+ stcf->symbol, class_label, statfile_rt->learned);
+ }
+ lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */
+ }
+ cur = g_list_next(cur);
+ statfile_idx++;
+ }
+ }
}
- else {
- filler_func(rt, L, lua_tointeger(L, 3), 5);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+
+ /* Process token results */
+ if (lua_istable(L, 5) && rt->stcf->clcf) {
+ if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: use class_names order */
+ for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+ /* Find statfile with this class name */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ const char *class_label = get_class_label(stcf);
+
+ /* Find the statfile ID */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ struct rspamd_statfile *st = nullptr;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (candidate->stcf == stcf) {
+ st = candidate;
+ break;
+ }
+ }
+
+ if (!st) {
+ msg_debug_bayes("statfile not found for class %s, skipping", class_name);
+ break;
+ }
+
+ /* Look up the runtime for this statfile; nothing is created here, a missing runtime means the class is skipped */
+ auto *statfile_rt = rt; /* Use current runtime if it matches */
+ if (stcf != rt->stcf) {
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ statfile_rt = maybe_rt.value();
+ }
+ else {
+ msg_debug_bayes("runtime not found for class %s, skipping", class_label);
+ break;
+ }
+ }
+
+ /* Ensure correct statfile ID assignment */
+ statfile_rt->id = st->id;
+
+ /* Process token results using class index (1-based for Lua) */
+ lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */
+ if (lua_istable(L, -1)) {
+ /* Parse token results into statfile runtime */
+ auto *res = new std::vector<std::pair<int, float>>();
+
+ lua_pushnil(L); /* First key for iteration */
+ while (lua_next(L, -2) != 0) {
+ if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+ lua_rawgeti(L, -1, 1); /* token_index */
+ lua_rawgeti(L, -2, 2); /* token_count */
+
+ if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+ int token_idx = lua_tointeger(L, -2);
+ float token_count = lua_tonumber(L, -1);
+ res->emplace_back(token_idx, token_count);
+ }
+
+ lua_pop(L, 2); /* Pop token_index and token_count */
+ }
+ lua_pop(L, 1); /* Pop value, keep key for next iteration */
+ }
+
+ statfile_rt->set_results(res);
+ }
+ lua_pop(L, 1); /* Pop token_results[class_idx + 1] */
+ break; /* Found the statfile for this class */
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ }
+ else {
+ /* Binary classification: process statfiles in order */
+ GList *cur = rt->stcf->clcf->statfiles;
+ unsigned int statfile_idx = 0;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ /* Find the statfile ID */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ struct rspamd_statfile *st = nullptr;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (candidate->stcf == stcf) {
+ st = candidate;
+ break;
+ }
+ }
+
+ if (!st) {
+ msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol);
+ cur = g_list_next(cur);
+ statfile_idx++;
+ continue;
+ }
+
+ /* Look up the runtime for this statfile; nothing is created here, a missing runtime means the statfile is skipped */
+ auto *statfile_rt = rt; /* Use current runtime if it matches */
+ if (stcf != rt->stcf) {
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ statfile_rt = maybe_rt.value();
+ }
+ else {
+ msg_debug_bayes("runtime not found for %s, skipping", class_label);
+ cur = g_list_next(cur);
+ statfile_idx++;
+ continue;
+ }
+ }
+
+ /* Ensure correct statfile ID assignment */
+ statfile_rt->id = st->id;
+
+ /* Process token results using statfile index (1-based for Lua) */
+ lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */
+ if (lua_istable(L, -1)) {
+ /* Parse token results into statfile runtime */
+ auto *res = new std::vector<std::pair<int, float>>();
+
+ lua_pushnil(L); /* First key for iteration */
+ while (lua_next(L, -2) != 0) {
+ if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+ lua_rawgeti(L, -1, 1); /* token_index */
+ lua_rawgeti(L, -2, 2); /* token_count */
+
+ if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+ int token_idx = lua_tointeger(L, -2);
+ float token_count = lua_tonumber(L, -1);
+ res->emplace_back(token_idx, token_count);
+ }
+
+ lua_pop(L, 2); /* Pop token_index and token_count */
+ }
+ lua_pop(L, 1); /* Pop value, keep key for next iteration */
+ }
+
+ statfile_rt->set_results(res);
+ msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)",
+ res->size(), stcf->symbol, class_label, st->id);
+ }
+ lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */
+
+ cur = g_list_next(cur);
+ statfile_idx++;
+ }
+ }
}
- /* Mark task as being processed */
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ /* Clean up stack */
+ lua_pop(L, 2); /* Pop learned_counts and token_results */
- /* Process all tokens */
+ /* Process tokens for all runtimes */
g_assert(rt->tokens != nullptr);
- rt->process_tokens(rt->tokens);
- opposite_rt_maybe.value()->process_tokens(rt->tokens);
+
+ /* Process tokens for all statfiles */
+ if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+ GList *cur = rt->stcf->clcf->statfiles;
+
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+ statfile_rt->process_tokens(rt->tokens);
+ }
+
+ cur = g_list_next(cur);
+ }
+ }
+ else {
+ /* Fallback: just process the main runtime */
+ rt->process_tokens(rt->tokens);
+ }
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot classify task: %s",
- err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot classify task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot classify task: unknown Redis script error");
+ }
}
return 0;
@@ -929,7 +1247,42 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+
+ /* Send all class labels for multi-class support */
+ if (rt->stcf->clcf && rt->stcf->clcf->class_names &&
+ rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: send array of class labels in deterministic order */
+ lua_createtable(L, rt->stcf->clcf->class_names->len, 0);
+ for (unsigned int i = 0; i < rt->stcf->clcf->class_names->len; i++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, i);
+ const char *class_label = nullptr;
+
+ /* Find the class label for this class name from any statfile with this class */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ class_label = get_class_label(stcf);
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+
+ if (class_label) {
+ lua_pushstring(L, class_label);
+ lua_rawseti(L, -2, i + 1); /* Lua arrays are 1-indexed */
+ }
+ }
+ }
+ else {
+ /* Binary classification: always send both labels, "H" (ham) first and "S" (spam) second, regardless of the order of the configured statfiles */
+ lua_createtable(L, 2, 0);
+ lua_pushstring(L, "H"); /* ham */
+ lua_rawseti(L, -2, 1);
+ lua_pushstring(L, "S"); /* spam */
+ lua_rawseti(L, -2, 2);
+ }
+
lua_new_text(L, tokens_buf, tokens_len, false);
/* Store rt in random cookie */
@@ -979,13 +1332,31 @@ rspamd_redis_learned(lua_State *L)
bool result = lua_toboolean(L, 2);
if (result) {
- /* TODO: write it */
+ /* Learning successful - no complex data to process like in classification */
+ msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s",
+ rt->stcf->symbol, get_class_label(rt->stcf));
+
+ /* Clear any previous error state */
+ rt->err = std::nullopt;
+
+ /* Learning operations don't return data structures to process,
+ * they just update Redis state. Success means the Redis script
+ * completed without errors. */
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot learn task: %s", err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot learn task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot learn task: unknown Redis script error");
+ }
}
return 0;
@@ -1028,7 +1399,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+ lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */
lua_pushstring(L, rt->stcf->symbol);
/* Detect unlearn */
@@ -1056,6 +1427,8 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
lua_new_text(L, text_tokens_buf, text_tokens_len, false);
}
+ msg_debug_bayes("called lua learn script for %s (cookie=%s)", rt->stcf->symbol, cookie);
+
if (lua_pcall(L, nargs, 0, err_idx) != 0) {
msg_err_task("call to script failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c
index 973dc30a7..8f29a3b4e 100644
--- a/src/libstat/backends/sqlite3_backend.c
+++ b/src/libstat/backends/sqlite3_backend.c
@@ -589,12 +589,7 @@ rspamd_sqlite3_process_tokens(struct rspamd_task *task,
}
}
- if (rt->cf->is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
- }
+ /* No longer need to set flags - multi-class handles missing data naturally */
}
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 93b5149da..dbae98cc2 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -1,11 +1,11 @@
-/*-
- * Copyright 2016 Vsevolod Stakhov
+/*
+ * Copyright 2025 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -53,10 +53,26 @@ static double
inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
{
double prob, sum, m;
+ double log_prob, log_m;
int i;
errno = 0;
m = -value;
+
+ /* Handle extreme negative values that would cause exp() underflow */
+ if (value < -700) {
+ /* Very strong confidence, return 0 */
+ msg_debug_bayes("extreme negative value: %f, returning 0", value);
+ return 0.0;
+ }
+
+ /* Handle extreme positive values that would cause overflow */
+ if (value > 700) {
+ /* No confidence, return 1 */
+ msg_debug_bayes("extreme positive value: %f, returning 1", value);
+ return 1.0;
+ }
+
prob = exp(value);
if (errno == ERANGE) {
@@ -75,6 +91,8 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
}
sum = prob;
+ log_prob = value; /* log of current prob term */
+ log_m = log(fabs(m)); /* log of |m| for numerical stability */
msg_debug_bayes("m: %f, probability: %g", m, prob);
@@ -83,24 +101,60 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg)
* prob is e ^ x (small value since x is normally less than zero
* So we integrate over degrees of freedom and produce the total result
* from 1.0 (no confidence) to 0.0 (full confidence)
+ * Use logarithmic arithmetic to prevent overflow
*/
for (i = 1; i < freedom_deg; i++) {
- prob *= m / (double) i;
+ /* Calculate next term using logarithms to prevent overflow */
+ log_prob += log_m - log((double) i);
+
+ /* Check if the log probability is too negative (term becomes negligible) */
+ if (log_prob < -700) {
+ msg_debug_bayes("term %d became negligible, stopping series", i);
+ break;
+ }
+
+ /* Check if the log probability is too positive (would cause overflow) */
+ if (log_prob > 700) {
+ msg_debug_bayes("series diverging at term %d, returning 1.0", i);
+ return 1.0;
+ }
+
+ prob = exp(log_prob);
sum += prob;
- msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum);
+ msg_debug_bayes("i=%d, log_prob: %g, probability: %g, sum: %g", i, log_prob, prob, sum);
+
+ /* Early termination if sum is getting too large */
+ if (sum > 1e10) {
+ msg_debug_bayes("sum too large (%g), returning 1.0", sum);
+ return 1.0;
+ }
}
return MIN(1.0, sum);
}
struct bayes_task_closure {
- double ham_prob;
- double spam_prob;
+ double ham_prob; /* Kept for the binary (two-class spam/ham) classifier path; "binary" here does not mean ABI compatibility */
+ double spam_prob; /* Kept for the binary (two-class spam/ham) classifier path; "binary" here does not mean ABI compatibility */
+ double meta_skip_prob;
+ uint64_t processed_tokens;
+ uint64_t total_hits;
+ uint64_t text_tokens; /* NOTE(review): presumably the count of non-meta tokens, as incremented in bayes_classify_token - confirm */
+ struct rspamd_task *task;
+};
+
+/* Multi-class classification closure: per-run state handed to bayes_classify_token_multiclass for every token */
+struct bayes_multiclass_closure {
+ double *class_log_probs; /* Array of log probabilities for each class */
+ uint64_t *class_learns; /* Learning counts for each class */
+ char **class_names; /* Array of class names */
+ unsigned int num_classes; /* Number of classes; a statfile's class_index must be below this to be counted (presumably also the length of the arrays above - confirm) */
double meta_skip_prob; /* A meta token is skipped when a random draw is <= this value; values <= 0 disable skipping */
uint64_t processed_tokens;
uint64_t total_hits; /* Running sum of the positive per-statfile token values that were accepted */
uint64_t text_tokens;
struct rspamd_task *task; /* Task this classification run belongs to */
+ struct rspamd_classifier_config *cfg;
};
/*
@@ -122,7 +176,6 @@ bayes_classify_token(struct rspamd_classifier *ctx,
unsigned int spam_count = 0, ham_count = 0, total_count = 0;
struct rspamd_statfile *st;
struct rspamd_task *task;
- const char *token_type = "txt";
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
ham_prob, fw, w, val;
@@ -211,41 +264,379 @@ bayes_classify_token(struct rspamd_classifier *ctx,
if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
cl->text_tokens++;
}
+ }
+}
+
+/*
+ * Multinomial token classification for multi-class Bayes.
+ *
+ * Folds one token into the running per-class log probabilities kept in `cl`.
+ * For every statfile of the classifier the token's stored value is added to
+ * the counter of that statfile's class; if the summed count reaches
+ * ctx->cfg->min_token_hits, log(class_prob) is added to
+ * cl->class_log_probs[j] for each class that passes the min_learns and
+ * min_prob_strength filters.
+ *
+ * Side effects on `cl`: total_hits, processed_tokens, text_tokens and
+ * class_log_probs are updated in place.
+ */
+static void
+bayes_classify_token_multiclass(struct rspamd_classifier *ctx,
+ rspamd_token_t *tok,
+ struct bayes_multiclass_closure *cl)
+{
+ unsigned int i, j;
+ int id;
+ struct rspamd_statfile *st;
+ struct rspamd_task *task;
+ double val, fw, w;
+ guint64 *class_counts;
+ guint64 total_count = 0;
+
+ /* NOTE(review): `task` is not referenced directly below; presumably the msg_* logging macros need it - confirm */
+ task = cl->task;
+
+ /* Skip meta tokens probabilistically if configured */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
+ val = rspamd_random_double_fast();
+ if (val <= cl->meta_skip_prob) {
+ return;
+ }
+ }
+
+ /* Allocate array for class counts (g_alloca: no explicit free needed) */
+ class_counts = g_alloca(cl->num_classes * sizeof(guint64));
+ memset(class_counts, 0, cl->num_classes * sizeof(guint64));
+
+ /* Collect counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+ val = tok->values[id];
+
+ if (val > 0) {
+ /* Direct O(1) class index lookup instead of O(N) string comparison */
+ if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) {
+ unsigned int class_idx = st->stcf->class_index;
+ class_counts[class_idx] += val;
+ total_count += val;
+ cl->total_hits += val;
+ }
+ else {
+ /* Also reached when the statfile has no class_name at all, not only for an out-of-range index */
+ msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s",
+ st->stcf->class_index, cl->num_classes, st->stcf->symbol);
+ }
+ }
+ }
+
+ /* Calculate multinomial probability for this token */
+ if (total_count >= ctx->cfg->min_token_hits) {
+ /* Feature weight calculation */
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+ fw = 1.0;
+ }
else {
- token_type = "meta";
+ fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)];
}
- if (tok->t1 && tok->t2) {
- msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
- (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
+ /* Weight grows towards 1 as the weighted hit count grows */
+ w = (fw * total_count) / (1.0 + fw * total_count);
+
+ /* Apply multinomial model for each class */
+ for (j = 0; j < cl->num_classes; j++) {
+ /* Skip classes with insufficient learns; their class_log_probs entry is left untouched */
+ if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) {
+ continue;
+ }
+
+ /* MAX(1.0, ...) guards against division by zero for a class with no learns */
+ double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]);
+ double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes);
+
+ /* Ensure probability is properly bounded [0, 1] */
+ class_prob = MAX(0.0, MIN(1.0, class_prob));
+
+ /* Skip probabilities too close to uniform (1/num_classes) */
+ double uniform_prior = 1.0 / cl->num_classes;
+ if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) {
+ continue;
+ }
+
+ cl->class_log_probs[j] += log(class_prob);
+ }
+
+ /* The token counts as processed even if every class was skipped above */
+ cl->processed_tokens++;
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ cl->text_tokens++;
+ }
+
+ /* Per-token debug logging removed to reduce verbosity */
+ }
+}
+
+/*
+ * Multinomial Bayes classification with Fisher confidence.
+ *
+ * Steps:
+ *  1. collect per-class learn counts from the statfile backends;
+ *  2. fold every token into per-class log probabilities;
+ *  3. pick the viable class with the largest log probability and normalise
+ *     the viable classes with a softmax;
+ *  4. store an rspamd_multiclass_result_t on the task and, if the confidence
+ *     exceeds 0.05, insert the symbol of the winning class' statfile.
+ *
+ * A class is "viable" when min_learns is disabled or the class has at least
+ * min_learns learns. Non-viable classes never receive token evidence, so
+ * their log probability stays at its initial 0.0; they are therefore given
+ * probability 0 instead of taking part in the softmax.
+ *
+ * Always returns TRUE; when classification is skipped no result is stored.
+ * The caller returns this value directly, so there is no fallback to the
+ * binary classifier from here.
+ */
+static gboolean
+bayes_classify_multiclass(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ struct bayes_multiclass_closure cl;
+ rspamd_token_t *tok;
+ unsigned int i, j, text_tokens = 0;
+ int id;
+ struct rspamd_statfile *st;
+ rspamd_multiclass_result_t *result;
+ double *normalized_probs;
+ double max_log_prob = -INFINITY;
+ unsigned int winning_class_idx = 0;
+ double confidence;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+
+ /* Initialize multi-class closure */
+ memset(&cl, 0, sizeof(cl));
+ cl.task = task;
+ cl.cfg = ctx->cfg;
+
+ /* Get class information from classifier config */
+ if (!ctx->cfg->class_names) {
+ msg_debug_bayes("no class_names array in classifier config");
+ return TRUE; /* Nothing to classify; no result is stored */
+ }
+ if (ctx->cfg->class_names->len < 2) {
+ msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len);
+ return TRUE; /* Nothing to classify; no result is stored */
+ }
+ if (!ctx->cfg->class_names->pdata) {
+ msg_debug_bayes("class_names->pdata is NULL");
+ return TRUE; /* Nothing to classify; no result is stored */
+ }
+
+ cl.num_classes = ctx->cfg->class_names->len;
+ cl.class_names = (char **) ctx->cfg->class_names->pdata;
+
+ /* Debug: verify class names are accessible */
+ msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p",
+ ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata);
+ msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p",
+ cl.num_classes, cl.class_names);
+ cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double));
+ cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t));
+
+ /* Initialize probabilities and get learning counts */
+ for (i = 0; i < cl.num_classes; i++) {
+ cl.class_log_probs[i] = 0.0;
+ cl.class_learns[i] = 0;
+ }
+
+ /* Collect learning counts for each class */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ for (j = 0; j < cl.num_classes; j++) {
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[j]) == 0) {
+ cl.class_learns[j] += st->backend->total_learns(task,
+ g_ptr_array_index(task->stat_runtimes, id), ctx->ctx);
+ break;
+ }
+ }
+ }
+
+ /* Check minimum learns requirement - count viable classes */
+ unsigned int viable_classes = 0;
+ if (ctx->cfg->min_learns > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ if (cl.class_learns[i] >= ctx->cfg->min_learns) {
+ viable_classes++;
+ }
+ else {
+ msg_info_task("class %s excluded from classification: %uL learns < %ud minimum",
+ cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+ }
+ }
+
+ if (viable_classes == 0) {
+ msg_info_task("no classes have sufficient training samples for classification");
+ return TRUE;
+ }
+
+ msg_info_bayes("multiclass classification: %ud/%ud classes have sufficient learns",
+ viable_classes, cl.num_classes);
+ }
+
+ /* Count text tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ text_tokens++;
+ }
+ }
+
+ if (text_tokens == 0) {
+ msg_info_task("skipped classification as there are no text tokens. "
+ "Total tokens: %ud",
+ tokens->len);
+ return TRUE;
+ }
+
+ /* Set meta token skip probability: only skip when meta tokens are not outnumbered by text tokens */
+ if (text_tokens > tokens->len - text_tokens) {
+ cl.meta_skip_prob = 0.0;
+ }
+ else {
+ cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len;
+ }
+
+ /* Process all tokens */
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ bayes_classify_token_multiclass(ctx, tok, &cl);
+ }
+
+ if (cl.processed_tokens == 0) {
+ /* Debug: check why no tokens were processed */
+ msg_debug_bayes("examining token values for debugging:");
+ for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */
+ tok = g_ptr_array_index(tokens, i);
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ if (tok->values[id] > 0) {
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)",
+ i, id, tok->values[id],
+ st->stcf->class_name ? st->stcf->class_name : "unknown",
+ st->stcf->symbol);
+ }
+ }
+ }
+
+ msg_info_bayes("no tokens found in bayes database "
+ "(%ud total tokens, %ud text tokens), ignore stats",
+ tokens->len, text_tokens);
+ return TRUE;
+ }
+
+ if (ctx->cfg->min_tokens > 0 &&
+ cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) {
+ msg_info_bayes("ignore bayes probability since we have "
+ "found too few text tokens: %uL (of %ud checked), "
+ "at least %d required",
+ cl.text_tokens, text_tokens,
+ (int) (ctx->cfg->min_tokens * 0.1));
+ return TRUE;
+ }
+
+ /* Normalize probabilities using softmax */
+ normalized_probs = g_alloca(cl.num_classes * sizeof(double));
+
+ /* Find maximum for numerical stability - only consider classes with sufficient training */
+ for (i = 0; i < cl.num_classes; i++) {
+ msg_debug_bayes("class %s, log_prob: %.2f", cl.class_names[i], cl.class_log_probs[i]);
+ /* Only consider classes that have sufficient training data */
+ if (ctx->cfg->min_learns > 0 && cl.class_learns[i] < ctx->cfg->min_learns) {
+ msg_debug_bayes("skipping class %s in winner selection: %uL learns < %ud minimum",
+ cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns);
+ continue;
+ }
+ if (cl.class_log_probs[i] > max_log_prob) {
+ max_log_prob = cl.class_log_probs[i];
+ winning_class_idx = i;
+ }
+ }
+
+ /*
+ * Apply softmax normalization over viable classes only.
+ * An excluded class still has log_prob == 0.0, which is larger than any
+ * real (non-positive) score; including it would hand it almost all of the
+ * probability mass and could overflow exp() for strongly negative maxima.
+ */
+ double sum_exp = 0.0;
+ for (i = 0; i < cl.num_classes; i++) {
+ if (ctx->cfg->min_learns > 0 && cl.class_learns[i] < ctx->cfg->min_learns) {
+ normalized_probs[i] = 0.0;
+ continue;
+ }
+ normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob);
+ sum_exp += normalized_probs[i];
+ }
+
+ if (sum_exp > 0) {
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] /= sum_exp;
+ }
+ }
+ else {
+ /* Fallback to uniform distribution */
+ for (i = 0; i < cl.num_classes; i++) {
+ normalized_probs[i] = 1.0 / cl.num_classes;
+ }
+ }
+
+ /* Calculate confidence using Fisher method for the winning class */
+ if (max_log_prob > -300) {
+ if (max_log_prob > 0) {
+ /* Defensive only: every term is log(p) with p <= 1, so the sum should not be positive */
+ confidence = 0.95; /* High confidence for positive log probabilities */
+ msg_debug_bayes("positive log_prob (%g), setting high confidence", max_log_prob);
}
else {
- msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, "
- "total_count: %ud, "
- "spam_count: %ud, ham_count: %ud,"
- "spam_prob: %.3f, ham_prob: %.3f, "
- "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
- "current spam probability: %.3f, current ham probability: %.3f",
- token_type,
- tok->data,
- fw, w, total_count, spam_count, ham_count,
- spam_prob, ham_prob,
- bayes_spam_prob, bayes_ham_prob,
- cl->spam_prob, cl->ham_prob);
+ /* Negative log prob - use Fisher method as intended */
+ double fisher_result = inv_chi_square(task, max_log_prob, cl.processed_tokens);
+ confidence = 1.0 - fisher_result;
+
+ /* The logged thresholds mirror the condition tested right below */
+ msg_debug_bayes("fisher_result: %g, max_log_prob: %g, condition check: fisher_result > 0.999 = %s, max_log_prob > -100 = %s",
+ fisher_result, max_log_prob,
+ fisher_result > 0.999 ? "true" : "false",
+ max_log_prob > -100 ? "true" : "false");
+
+ /* Handle case where Fisher method indicates extreme confidence */
+ if (fisher_result > 0.999 && max_log_prob > -100) {
+ /* Large magnitude negative log prob means strong evidence */
+ confidence = 0.90;
+ msg_debug_bayes("extreme negative log_prob (%g), setting high confidence", max_log_prob);
+ }
}
}
+ else {
+ /* Too negative for the Fisher path: use the softmax share of the winner */
+ confidence = normalized_probs[winning_class_idx];
+ }
+
+ /* Create and store multiclass result */
+ result = g_new0(rspamd_multiclass_result_t, 1);
+ result->class_names = g_new(char *, cl.num_classes);
+ result->probabilities = g_new(double, cl.num_classes);
+ result->num_classes = cl.num_classes;
+ result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */
+ result->confidence = confidence;
+
+ for (i = 0; i < cl.num_classes; i++) {
+ result->class_names[i] = g_strdup(cl.class_names[i]);
+ result->probabilities[i] = normalized_probs[i];
+ }
+
+ rspamd_task_set_multiclass_result(task, result);
+
+ msg_info_bayes("MULTICLASS_RESULT: winning_class='%s', confidence=%.3f, normalized_prob=%.3f, tokens=%uL",
+ cl.class_names[winning_class_idx], confidence,
+ normalized_probs[winning_class_idx], cl.processed_tokens);
+
+ /* Insert symbol for winning class if confidence is significant */
+ if (confidence > 0.05) {
+ char sumbuf[32];
+ double final_prob = rspamd_normalize_probability(confidence, 0.5);
+
+ rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0);
+
+ /* Find the statfile for the winning class to get the symbol */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, int, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+ if (st->stcf->class_name &&
+ strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) {
+ msg_info_bayes("SYMBOL_INSERT: symbol='%s', final_prob=%.3f, confidence_display='%s'",
+ st->stcf->symbol, final_prob, sumbuf);
+ rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf);
+ break;
+ }
+ }
+
+ msg_debug_bayes("multiclass classification: winning class '%s' with "
+ "probability %.3f, confidence %.3f, %uL tokens processed",
+ cl.class_names[winning_class_idx],
+ normalized_probs[winning_class_idx],
+ confidence, cl.processed_tokens);
+ }
+ else {
+ msg_info_bayes("SYMBOL_SKIPPED: confidence=%.3f <= 0.05, no symbol inserted", confidence);
+ }
+
+ return TRUE;
}
@@ -279,6 +670,37 @@ bayes_classify(struct rspamd_classifier *ctx,
g_assert(ctx != NULL);
g_assert(tokens != NULL);
+ /* Check if this is a multi-class classifier */
+ msg_debug_bayes("classification check: class_names=%p, len=%uz",
+ ctx->cfg->class_names,
+ ctx->cfg->class_names ? ctx->cfg->class_names->len : 0);
+
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) {
+ /* Verify that at least one statfile has class_name set (indicating new multi-class config) */
+ gboolean has_class_names = FALSE;
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ int id = g_array_index(ctx->statfiles_ids, int, i);
+ struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ msg_debug_bayes("checking statfile %s: class_name=%s, is_spam_converted=%s",
+ st->stcf->symbol,
+ st->stcf->class_name ? st->stcf->class_name : "NULL",
+ st->stcf->is_spam_converted ? "true" : "false");
+ if (st->stcf->class_name) {
+ has_class_names = TRUE;
+ }
+ }
+
+ msg_debug_bayes("has_class_names=%s", has_class_names ? "true" : "false");
+
+ if (has_class_names) {
+ msg_debug_bayes("using multiclass classification with %ud classes",
+ (unsigned int) ctx->cfg->class_names->len);
+ return bayes_classify_multiclass(ctx, tokens, task);
+ }
+ }
+
+ /* Fall back to binary classification */
+ msg_debug_bayes("using binary classification");
memset(&cl, 0, sizeof(cl));
cl.task = task;
@@ -286,14 +708,14 @@ bayes_classify(struct rspamd_classifier *ctx,
if (ctx->cfg->min_learns > 0) {
if (ctx->ham_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as ham. The ham class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->ham_learns, ctx->cfg->min_learns);
return TRUE;
}
if (ctx->spam_learns < ctx->cfg->min_learns) {
msg_info_task("not classified as spam. The spam class needs more "
- "training samples. Currently: %ul; minimum %ud required",
+ "training samples. Currently: %uL; minimum %ud required",
ctx->spam_learns, ctx->cfg->min_learns);
return TRUE;
@@ -374,7 +796,7 @@ bayes_classify(struct rspamd_classifier *ctx,
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes(
"got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens;"
+ " %uL tokens processed of %ud total tokens;"
" %uL text tokens found of %ud text tokens)",
cl.ham_prob,
h,
@@ -549,3 +971,155 @@ bayes_learn_spam(struct rspamd_classifier *ctx,
return TRUE;
}
+
+/*
+ * Find the position of a statfile's class inside the classifier's class list.
+ * Returns -1 when either argument is NULL or the name is not listed.
+ */
+static int
+bayes_learn_class_slot(GPtrArray *class_names, const char *statfile_class)
+{
+ unsigned int k;
+
+ if (class_names == NULL || statfile_class == NULL) {
+ return -1;
+ }
+
+ for (k = 0; k < class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(class_names, k);
+
+ if (strcmp(statfile_class, check_class) == 0) {
+ return (int) k;
+ }
+ }
+
+ return -1;
+}
+
+/*
+ * Update token values so that `tokens` are learned as `class_name`.
+ *
+ * For every token and every statfile of the classifier:
+ *  - a statfile of the target class gets its value incremented (or set to 1
+ *    for incrementing backends);
+ *  - with `unlearn` set, any other statfile with a positive value gets it
+ *    decremented (or set to -1 for incrementing backends);
+ *  - otherwise an incrementing backend has the value reset to 0.
+ *
+ * Statfiles without a class_name are matched through the legacy mapping
+ * "spam"/"S" -> is_spam, "ham"/"H" -> !is_spam.
+ *
+ * The per-token counters exist for debug logging only. They are signed
+ * because the unlearn path can contribute negative values.
+ *
+ * @param err not written by this function
+ * @return always TRUE
+ */
+gboolean
+bayes_learn_class(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err)
+{
+ unsigned int i, j, k;
+ int id, slot, total_cnt;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
+ gboolean incrementing, is_target_class, touched;
+ int *class_counts = NULL;
+ unsigned int num_classes = 0;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+ g_assert(class_name != NULL);
+
+ msg_info_bayes("LEARN_CLASS: class='%s', unlearn=%s, tokens=%ud",
+ class_name, unlearn ? "true" : "false", tokens->len);
+
+ incrementing = (ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND) != 0;
+
+ /* Prepare the per-class debug counters for multi-class configurations */
+ if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) {
+ num_classes = ctx->cfg->class_names->len;
+ class_counts = g_alloca(num_classes * sizeof(int));
+ }
+
+ for (i = 0; i < tokens->len; i++) {
+ total_cnt = 0;
+ tok = g_ptr_array_index(tokens, i);
+
+ /* Reset class counts for this token */
+ if (num_classes > 0) {
+ memset(class_counts, 0, num_classes * sizeof(int));
+ }
+
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, int, j);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ /* Determine if this statfile matches our target class */
+ if (st->stcf->class_name) {
+ /* Multi-class: exact class name match */
+ is_target_class = (strcmp(st->stcf->class_name, class_name) == 0);
+ }
+ else if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ /* Legacy binary: spam statfile */
+ is_target_class = st->stcf->is_spam;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ /* Legacy binary: ham statfile */
+ is_target_class = !st->stcf->is_spam;
+ }
+ else {
+ is_target_class = FALSE;
+ }
+
+ touched = FALSE;
+
+ if (is_target_class) {
+ /* Learning: increment the target class */
+ if (incrementing) {
+ tok->values[id] = 1;
+ }
+ else {
+ tok->values[id]++;
+ }
+ touched = TRUE;
+ }
+ else if (tok->values[id] > 0 && unlearn) {
+ /* Unlearning: decrement other classes if unlearn flag is set */
+ if (incrementing) {
+ tok->values[id] = -1;
+ }
+ else {
+ tok->values[id]--;
+ }
+ touched = TRUE;
+ }
+ else if (incrementing) {
+ tok->values[id] = 0;
+ }
+
+ /* Track counts for debugging; only statfiles changed above contribute */
+ if (touched) {
+ total_cnt += (int) tok->values[id];
+
+ if (num_classes > 0) {
+ slot = bayes_learn_class_slot(ctx->cfg->class_names, st->stcf->class_name);
+ if (slot >= 0) {
+ class_counts[slot] += (int) tok->values[id];
+ }
+ }
+ }
+ }
+
+ /* Debug logging */
+ if (tok->t1 && tok->t2) {
+ if (num_classes > 0) {
+ GString *debug_str = g_string_new("");
+ for (k = 0; k < num_classes; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k);
+ g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]);
+ }
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class_counts: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, debug_str->str);
+ g_string_free(debug_str, TRUE);
+ }
+ else {
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+ else {
+ msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
+ "class: %s",
+ tok->data,
+ tok->window_idx, total_cnt, class_name);
+ }
+ }
+
+ return TRUE;
+}
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index 22978e673..cab658146 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -54,6 +54,13 @@ struct rspamd_stat_classifier {
gboolean unlearn,
GError **err);
+ gboolean (*learn_class_func)(struct rspamd_classifier *ctx,
+ GPtrArray *input,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err);
+
void (*fin_func)(struct rspamd_classifier *cl);
};
@@ -73,6 +80,13 @@ gboolean bayes_learn_spam(struct rspamd_classifier *ctx,
gboolean unlearn,
GError **err);
+gboolean bayes_learn_class(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ const char *class_name,
+ gboolean unlearn,
+ GError **err);
+
void bayes_fin(struct rspamd_classifier *);
/* Generic lua classifier */
diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx
index 0de5cd094..afefeadcd 100644
--- a/src/libstat/learn_cache/redis_cache.cxx
+++ b/src/libstat/learn_cache/redis_cache.cxx
@@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
return (void *) ctx;
}
+/*
+ * Map a class name to the numeric ID kept in the learn cache.
+ *
+ * Reserved IDs: 0 for "ham"/"H" (and for a NULL name), 1 for "spam"/"S".
+ * Any other class gets an ID derived from rspamd_cryptobox_fast_hash.
+ *
+ * The ID is handed to Lua with lua_pushinteger and read back with
+ * lua_tointeger before being compared for equality. Where Lua numbers are
+ * doubles, a full 64-bit hash is not exactly representable, so the hash is
+ * reduced to 32 bits to keep the round trip lossless.
+ * NOTE(review): the Lua script that persists the value is not visible here;
+ * confirm it stores the number without further formatting loss.
+ */
+static uint64_t
+rspamd_stat_cache_get_class_id(const char *class_name)
+{
+ if (!class_name) {
+ return 0;
+ }
+
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ return 1;
+ }
+
+ if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ return 0;
+ }
+
+ /* For other classes, use the low 32 bits of rspamd_cryptobox_fast_hash */
+ uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0) & 0xffffffffULL;
+
+ /* Ensure we don't get 0 or 1 (reserved for ham/spam) */
+ if (hash == 0 || hash == 1) {
+ hash += 2;
+ }
+
+ return hash;
+}
+
+
static int
rspamd_stat_cache_checked(lua_State *L)
{
@@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L)
if (res) {
auto val = lua_tointeger(L, 3);
- if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
- (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
- /* Already learned */
- msg_info_task("<%s> has been already "
- "learned as %s, ignore it",
- MESSAGE_FIELD(task, message_id),
- (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
- task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ /* Get the class being learned */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flags for backward compatibility */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ autolearn_class = "spam";
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ autolearn_class = "ham";
+ }
}
- else {
- /* Unlearn flag */
- task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+
+ if (autolearn_class) {
+ /* Compare the cached ID with the ID of the class being learned now */
+ uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class);
+
+ if ((uint64_t) val == expected_id) {
+ /* Already learned */
+ msg_info_task("<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ }
+ else {
+ /* Different class learned, unlearn flag */
+ /* %L / %uL are the logger's 64-bit specifiers; the arguments are cast/typed to match */
+ msg_debug_task("<%s> cached value %L != expected %uL for class %s, will unlearn",
+ MESSAGE_FIELD(task, message_id),
+ (int64_t) val, expected_id, autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ }
+ }
}
- /* Ignore errors for now, as we can do nothing about them at the moment */
-
return 0;
}
@@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
rspamd_lua_task_push(L, task);
lua_pushstring(L, h);
- lua_pushboolean(L, is_spam);
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ /* Get the class being learned - prefer multiclass over binary */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flag for backward compatibility */
+ autolearn_class = is_spam ? "spam" : "ham";
+ }
+
+ /* Push class name and class ID */
+ lua_pushstring(L, autolearn_class);
+ uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class);
+ lua_pushinteger(L, class_id);
+
+ if (lua_pcall(L, 4, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return RSPAMD_LEARN_IGNORE;
diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h
index 811566ad3..aa6111a8b 100644
--- a/src/libstat/stat_api.h
+++ b/src/libstat/stat_api.h
@@ -108,6 +108,23 @@ rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task,
GError **err);
/**
+ * Learn task as a specific class, task must be processed prior to this call
+ * @param task task to learn
+ * @param class_name name of the class to learn (e.g., "spam", "ham", "transactional")
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param stage learning stage
+ * @param err error returned
+ * @return TRUE if task has been learned
+ */
+rspamd_stat_result_t rspamd_stat_learn_class(struct rspamd_task *task,
+ const char *class_name,
+ lua_State *L,
+ const char *classifier,
+ unsigned int stage,
+ GError **err);
+
+/**
* Get the overall statistics for all statfile backends
* @param cfg configuration
* @param total_learns the total number of learns is stored here
@@ -120,6 +137,43 @@ rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task,
void rspamd_stat_unload(void);
+/**
+ * Multi-class classification result structure.
+ *
+ * class_names and probabilities are parallel arrays of num_classes entries:
+ * probabilities[i] is the normalised probability of class_names[i].
+ */
+typedef struct {
+ char **class_names; /**< Array of class names */
+ double *probabilities; /**< Array of probabilities for each class */
+ unsigned int num_classes; /**< Number of classes */
+ const char *winning_class; /**< Name of the winning class (reference, not owned) */
+ double confidence; /**< Confidence of the winning class */
+} rspamd_multiclass_result_t;
+
+/**
+ * Set multi-class classification result for a task
+ */
+void rspamd_task_set_multiclass_result(struct rspamd_task *task,
+ rspamd_multiclass_result_t *result);
+
+/**
+ * Get multi-class classification result from a task
+ */
+rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task);
+
+/**
+ * Free multi-class result structure
+ */
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result);
+
+/**
+ * Set autolearn class for a task
+ */
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name);
+
+/**
+ * Get autolearn class from a task
+ */
+const char *rspamd_task_get_autolearn_class(struct rspamd_task *task);
+
#ifdef __cplusplus
}
#endif
diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c
index 8a5313df2..5ada7d468 100644
--- a/src/libstat/stat_config.c
+++ b/src/libstat/stat_config.c
@@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = {
.init_func = lua_classifier_init,
.classify_func = lua_classifier_classify,
.learn_spam_func = lua_classifier_learn_spam,
+ .learn_class_func = NULL, /* TODO: implement lua multi-class learning */
.fin_func = NULL,
};
@@ -37,6 +38,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = {
.init_func = bayes_init,
.classify_func = bayes_classify,
.learn_spam_func = bayes_learn_spam,
+ .learn_class_func = bayes_learn_class,
.fin_func = bayes_fin,
}};
@@ -68,8 +70,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = {
.dec_learns = rspamd_##eltn##_dec_learns, \
.get_stat = rspamd_##eltn##_get_stat, \
.load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
- .close = rspamd_##eltn##_close \
- }
+ .close = rspamd_##eltn##_close}
#define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \
{ \
.name = #nam, \
@@ -85,8 +86,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = {
.dec_learns = NULL, \
.get_stat = rspamd_##eltn##_get_stat, \
.load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
- .close = rspamd_##eltn##_close \
- }
+ .close = rspamd_##eltn##_close}
static struct rspamd_stat_backend stat_backends[] = {
RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file),
@@ -101,8 +101,7 @@ static struct rspamd_stat_backend stat_backends[] = {
.runtime = rspamd_stat_cache_##eltn##_runtime, \
.check = rspamd_stat_cache_##eltn##_check, \
.learn = rspamd_stat_cache_##eltn##_learn, \
- .close = rspamd_stat_cache_##eltn##_close \
- }
+ .close = rspamd_stat_cache_##eltn##_close}
static struct rspamd_stat_cache stat_caches[] = {
RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3),
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 176064087..11b31decc 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -32,6 +32,78 @@
static const double similarity_threshold = 80.0;
+/*
+ * Attach a multi-class result to the task under the pool variable
+ * "multiclass_bayes_result", registering rspamd_multiclass_result_free as
+ * its destructor.
+ */
+void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result)
+{
+ rspamd_mempool_destruct_t dtor = (rspamd_mempool_destruct_t) rspamd_multiclass_result_free;
+
+ g_assert(task != NULL);
+ g_assert(result != NULL);
+
+ rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result, dtor);
+}
+
+/*
+ * Fetch the result stored under the pool variable "multiclass_bayes_result";
+ * the value comes straight from rspamd_mempool_get_variable.
+ */
+rspamd_multiclass_result_t *
+rspamd_task_get_multiclass_result(struct rspamd_task *task)
+{
+ gpointer stored;
+
+ g_assert(task != NULL);
+
+ stored = rspamd_mempool_get_variable(task->task_pool, "multiclass_bayes_result");
+
+ return (rspamd_multiclass_result_t *) stored;
+}
+
+/*
+ * Release a multi-class result. Accepts NULL.
+ *
+ * The entries of class_names are individual copies made with g_strdup when
+ * the result is built, so each one is freed before the array itself.
+ * winning_class only references a name owned elsewhere and is left alone.
+ */
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result)
+{
+ unsigned int i;
+
+ if (result == NULL) {
+ return;
+ }
+
+ if (result->class_names != NULL) {
+ for (i = 0; i < result->num_classes; i++) {
+ g_free(result->class_names[i]);
+ }
+ }
+
+ g_free(result->class_names);
+ g_free(result->probabilities);
+ /* winning_class is a reference, not owned - don't free */
+ g_free(result);
+}
+
+/*
+ * Record the class a task should be learned as.
+ *
+ * A copy of the name, allocated from the task pool, is stored under the pool
+ * variable "autolearn_class" and RSPAMD_TASK_FLAG_LEARN_CLASS is raised.
+ * The names "spam"/"S" and "ham"/"H" additionally raise the matching legacy
+ * binary learn flag.
+ */
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name)
+{
+ char *pool_copy;
+
+ g_assert(task != NULL);
+ g_assert(class_name != NULL);
+
+ pool_copy = rspamd_mempool_strdup(task->task_pool, class_name);
+ rspamd_mempool_set_variable(task->task_pool, "autolearn_class",
+ (gpointer) pool_copy, NULL);
+
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS;
+
+ /* Legacy binary flags for the two built-in classes */
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ return;
+ }
+
+ if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ }
+}
+
+/*
+ * Return the class the task is being learned as.
+ *
+ * With RSPAMD_TASK_FLAG_LEARN_CLASS set, the pool variable "autolearn_class"
+ * is returned. Otherwise the legacy flags map to "spam" or "ham" (spam is
+ * checked first). NULL means no learn flag is set.
+ */
+const char *
+rspamd_task_get_autolearn_class(struct rspamd_task *task)
+{
+ const char *legacy_name = NULL;
+
+ g_assert(task != NULL);
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class");
+ }
+
+ /* Binary flags kept for backward compatibility */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ legacy_name = "spam";
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ legacy_name = "ham";
+ }
+
+ return legacy_name;
+}
+
static void
rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task)
@@ -394,18 +466,9 @@ rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
}
/*
- * Do not classify a message if some class is missing
+ * Multi-class approach: don't check for missing classes
+ * Missing tokens naturally result in 0 probability
*/
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
- msg_info_task("skip statistics as SPAM class is missing");
-
- return;
- }
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
- msg_info_task("skip statistics as HAM class is missing");
-
- return;
- }
for (i = 0; i < st_ctx->classifiers->len; i++) {
cl = g_ptr_array_index(st_ctx->classifiers, i);
@@ -565,7 +628,24 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
if (sel->cache && sel->cachecf) {
rt = cl->cache->runtime(task, sel->cachecf, FALSE);
- learn_res = cl->cache->check(task, spam, rt);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ learn_res = cl->cache->check(task, cache_spam, rt);
}
if (learn_res == RSPAMD_LEARN_IGNORE) {
@@ -658,9 +738,63 @@ rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
continue;
}
- if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
- task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
- learned = TRUE;
+ /* Check if classifier supports multi-class learning and if we should use it */
+ if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) {
+ /* Multi-class learning: determine class name from task flags or autolearn result */
+ const char *class_name = NULL;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ /* Find spam class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is spam */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "spam"; /* fallback */
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ /* Find ham class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is ham */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "ham"; /* fallback */
+ }
+ else {
+ /* Fallback to spam/ham based on the spam parameter */
+ class_name = spam ? "spam" : "ham";
+ }
+
+ if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+ else {
+ /* Binary learning: use existing function */
+ if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
}
}
@@ -759,9 +893,26 @@ rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx,
backend_found = TRUE;
if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) {
- if (!!spam != !!st->stcf->is_spam) {
- /* If we are not unlearning, then do not touch another class */
- continue;
+ /* For multiclass learning, check if this statfile has any tokens to learn */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ /* Multiclass learning: only process statfiles that have tokens set up by the classifier */
+ gboolean has_tokens = FALSE;
+ for (unsigned int k = 0; k < task->tokens->len && !has_tokens; k++) {
+ rspamd_token_t *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, k);
+ if (tok->values[id] != 0) {
+ has_tokens = TRUE;
+ }
+ }
+ if (!has_tokens) {
+ continue;
+ }
+ }
+ else {
+ /* Binary learning: use traditional spam/ham check */
+ if (!!spam != !!st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
+ continue;
+ }
}
}
@@ -870,7 +1021,24 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
if (cl->cache) {
cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
- cl->cache->learn(task, spam, cache_run);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ cl->cache->learn(task, cache_spam, cache_run);
}
}
@@ -879,6 +1047,218 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
return res;
}
+/**
+ * Feed the task's tokens to the matching classifiers for the given class.
+ *
+ * Classifiers are filtered by name (case-insensitive) when `classifier` is not
+ * NULL, otherwise every classifier in the context is tried. A classifier is
+ * skipped when the task's token count violates its min_tokens/max_tokens
+ * limits. Classifiers that provide learn_class_func are called with the class
+ * name; others fall back to learn_spam_func, which only accepts "spam"/"S" and
+ * "ham"/"H" (any other class is skipped for them).
+ *
+ * @param st_ctx statistics context holding the classifiers
+ * @param task task whose tokens are learned
+ * @param classifier classifier name to restrict learning to, or NULL for all
+ * @param class_name class to learn the message as
+ * @param err error target: code 208 when the task is already learned, 404 when
+ *            no classifier matched, 204 when nothing was learned because of a
+ *            token limit
+ * @return TRUE if at least one classifier reported successful learning
+ */
+static gboolean
+rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx,
+									struct rspamd_task *task,
+									const char *classifier,
+									const char *class_name,
+									GError **err)
+{
+	struct rspamd_classifier *cl, *sel = NULL;
+	unsigned int i;
+	gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
+
+	if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
+		*err == NULL) {
+		/* Do not learn twice */
+		g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
+												   "learned as %s, ignore it",
+					MESSAGE_FIELD(task, message_id),
+					class_name);
+
+		return FALSE;
+	}
+
+	/* Check whether we have learned that file */
+	for (i = 0; i < st_ctx->classifiers->len; i++) {
+		cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+		/* Skip other classifiers if they are not needed */
+		if (classifier != NULL && (cl->cfg->name == NULL ||
+								   g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+			continue;
+		}
+
+		/* Remember the last classifier that passed the name filter */
+		sel = cl;
+
+		/* Now check max and min tokens */
+		if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+			msg_info_task(
+				"<%s> contains less tokens than required for %s classifier: "
+				"%ud < %ud",
+				MESSAGE_FIELD(task, message_id),
+				cl->cfg->name,
+				task->tokens->len,
+				cl->cfg->min_tokens);
+			too_small = TRUE;
+			continue;
+		}
+		else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+			msg_info_task(
+				"<%s> contains more tokens than allowed for %s classifier: "
+				"%ud > %ud",
+				MESSAGE_FIELD(task, message_id),
+				cl->cfg->name,
+				task->tokens->len,
+				cl->cfg->max_tokens);
+			too_large = TRUE;
+			continue;
+		}
+
+		/* Use the new multi-class learning function if available */
+		if (cl->subrs->learn_class_func) {
+			if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+											task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+				learned = TRUE;
+			}
+		}
+		else {
+			/* Fallback to binary learning with class name mapping */
+			gboolean is_spam;
+			if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+				is_spam = TRUE;
+			}
+			else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+				is_spam = FALSE;
+			}
+			else {
+				/* For unknown classes with binary classifier, skip */
+				msg_info_task("skipping class '%s' for binary classifier %s",
+							  class_name, cl->cfg->name);
+				continue;
+			}
+
+			if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam,
+										   task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+				learned = TRUE;
+			}
+		}
+	}
+
+	if (sel == NULL) {
+		if (classifier) {
+			g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+													   "with name %s",
+						classifier);
+		}
+		else {
+			g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+		}
+
+		return FALSE;
+	}
+
+	if (!learned && err && *err == NULL) {
+		/*
+		 * Token counts and limits are unsigned (the log lines above format
+		 * them as unsigned too), hence %u here rather than %d.
+		 * Note: the limits reported are those of `sel`, the last classifier
+		 * that matched the name filter.
+		 */
+		if (too_large) {
+			g_set_error(err, rspamd_stat_quark(), 204,
+						"<%s> contains more tokens than allowed for %s classifier: "
+						"%u > %u",
+						MESSAGE_FIELD(task, message_id),
+						sel->cfg->name,
+						task->tokens->len,
+						sel->cfg->max_tokens);
+		}
+		else if (too_small) {
+			g_set_error(err, rspamd_stat_quark(), 204,
+						"<%s> contains less tokens than required for %s classifier: "
+						"%u < %u",
+						MESSAGE_FIELD(task, message_id),
+						sel->cfg->name,
+						task->tokens->len,
+						sel->cfg->min_tokens);
+		}
+	}
+
+	return learned;
+}
+
+/**
+ * Run one learning stage for a task using a named class.
+ *
+ * The stage selects the work done:
+ *  - RSPAMD_TASK_STAGE_LEARN_PRE: preprocess the task and check the learn cache;
+ *  - RSPAMD_TASK_STAGE_LEARN: learn in classifiers, then store on backends;
+ *  - RSPAMD_TASK_STAGE_LEARN_POST: run the backends' post-learn step.
+ * Helpers that still take a boolean receive TRUE only for the class names
+ * "spam" and "S".
+ *
+ * On success, and on the "no classifiers" / "empty message" early exits, the
+ * stage is added to task->processed_stages. When a stage helper fails the
+ * function returns before that point, so the stage is not marked as processed.
+ *
+ * @param task task to learn; it must already be classified (asserted)
+ * @param class_name class to learn the message as
+ * @param L Lua state (not used by the code in this function)
+ * @param classifier classifier name to restrict learning to, or NULL
+ * @param stage one of the learn stages listed above
+ * @param err error target, filled when a stage fails and no error is set yet
+ * @return RSPAMD_STAT_PROCESS_OK or RSPAMD_STAT_PROCESS_ERROR
+ */
+rspamd_stat_result_t
+rspamd_stat_learn_class(struct rspamd_task *task,
+						const char *class_name,
+						lua_State *L,
+						const char *classifier,
+						unsigned int stage,
+						GError **err)
+{
+	struct rspamd_stat_ctx *st_ctx;
+	rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+	/*
+	 * We assume now that a task has been already classified before
+	 * coming to learn
+	 */
+	g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+
+	st_ctx = rspamd_stat_get_ctx();
+	g_assert(st_ctx != NULL);
+
+	msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name);
+
+	if (st_ctx->classifiers->len == 0) {
+		msg_debug_bayes("no classifiers defined");
+		task->processed_stages |= stage;
+		return ret;
+	}
+
+	if (task->message == NULL) {
+		ret = RSPAMD_STAT_PROCESS_ERROR;
+		if (err && *err == NULL) {
+			g_set_error(err, rspamd_stat_quark(), 500,
+						"Trying to learn an empty message");
+		}
+
+		task->processed_stages |= stage;
+		return ret;
+	}
+
+	if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+		/* Process classifiers - determine spam boolean for compatibility */
+		gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+		rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
+
+		if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+			msg_debug_bayes("cache check failed, skip learning");
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+	else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+		/*
+		 * task->classifier may be unset; passing NULL to a "%s" conversion is
+		 * undefined behaviour, so substitute a placeholder for error messages.
+		 */
+		const char *cl_name = task->classifier ? task->classifier : "(unset)";
+
+		/* Process classifiers */
+		if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier,
+												 class_name, err)) {
+			if (err && *err == NULL) {
+				g_set_error(err, rspamd_stat_quark(), 500,
+							"Unknown statistics error, found when learning classifiers;"
+							" classifier: %s",
+							cl_name);
+			}
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+
+		/* Process backends - determine spam boolean for compatibility */
+		gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+		if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
+			if (err && *err == NULL) {
+				g_set_error(err, rspamd_stat_quark(), 500,
+							"Unknown statistics error, found when storing data on backend;"
+							" classifier: %s",
+							cl_name);
+			}
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+	else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+		/* Process backends - determine spam boolean for compatibility */
+		gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+		if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+
+	task->processed_stages |= stage;
+
+	return ret;
+}
+
rspamd_stat_result_t
rspamd_stat_learn(struct rspamd_task *task,
gboolean spam, lua_State *L, const char *classifier, unsigned int stage,
@@ -1039,12 +1419,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
if (mres) {
if (mres->score > rspamd_task_get_required_score(task, mres)) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score < 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
@@ -1076,12 +1455,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
if (mres) {
if (mres->score >= spam_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score <= ham_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
@@ -1117,11 +1495,16 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
/* We can have immediate results */
if (lua_ret) {
if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
ret = TRUE;
}
}
@@ -1139,79 +1522,138 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
}
}
else if (ucl_object_type(obj) == UCL_OBJECT) {
- /* Try to find autolearn callback */
- if (cl->autolearn_cbref == 0) {
- /* We don't have preprocessed cb id, so try to get it */
- if (!rspamd_lua_require_function(L, "lua_bayes_learn",
- "autolearn")) {
- msg_err_task("cannot get autolearn library from "
- "`lua_bayes_learn`");
- }
- else {
- cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* Check if this is a multi-class autolearn configuration */
+ const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass");
+
+ if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) {
+ /* Multi-class threshold-based autolearn */
+ const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds");
+
+ if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) {
+ /* Iterate through class thresholds */
+ ucl_object_iter_t it = NULL;
+ const ucl_object_t *class_obj;
+ const char *class_name;
+
+ while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) {
+ class_name = ucl_object_key(class_obj);
+
+ if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) {
+ /* [min_score, max_score] for this class */
+ const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0);
+ const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1);
+
+ if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) &&
+ (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) {
+
+ double min_score = ucl_object_todouble(min_elt);
+ double max_score = ucl_object_todouble(max_elt);
+
+ if (mres && mres->score >= min_score && mres->score <= max_score) {
+ rspamd_task_set_autolearn_class(task, class_name);
+ ret = TRUE;
+ msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]",
+ mres->score, class_name, min_score, max_score);
+ break; /* Stop at first matching class */
+ }
+ }
+ }
+ }
}
}
-
- if (cl->autolearn_cbref != -1) {
- lua_pushcfunction(L, &rspamd_lua_traceback);
- err_idx = lua_gettop(L);
- lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
-
- ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
- *ptask = task;
- rspamd_lua_setclass(L, rspamd_task_classname, -1);
- /* Push the whole object as well */
- ucl_object_push_lua(L, obj, true);
-
- if (lua_pcall(L, 2, 1, err_idx) != 0) {
- msg_err_task("call to autolearn script failed: "
- "%s",
- lua_tostring(L, -1));
+ else {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function(L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
}
- else {
- lua_ret = lua_tostring(L, -1);
- if (lua_ret) {
- if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- ret = TRUE;
- }
- else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- ret = TRUE;
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, rspamd_task_classname, -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua(L, obj, true);
+
+ if (lua_pcall(L, 2, 1, err_idx) != 0) {
+ msg_err_task("call to autolearn script failed: "
+ "%s",
+ lua_tostring(L, -1));
+ }
+ else {
+ lua_ret = lua_tostring(L, -1);
+
+ if (lua_ret) {
+ if (strcmp(lua_ret, "ham") == 0) {
+ rspamd_task_set_autolearn_class(task, "ham");
+ ret = TRUE;
+ }
+ else if (strcmp(lua_ret, "spam") == 0) {
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
+ ret = TRUE;
+ }
}
}
- }
- lua_settop(L, err_idx - 1);
+ lua_settop(L, err_idx - 1);
+ }
}
- }
- if (ret) {
- /* Do not autolearn if we have this symbol already */
- if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
- ret = FALSE;
- task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
- RSPAMD_TASK_FLAG_LEARN_SPAM);
- }
- else if (mres != NULL) {
- if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
- msg_info_task("<%s>: autolearn ham for classifier "
- "'%s' as message's "
- "score is negative: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
- }
- else {
- msg_info_task("<%s>: autolearn spam for classifier "
- "'%s' as message's "
- "action is reject, score: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
+ if (ret) {
+ /* Do not autolearn if we have this symbol already */
+ if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
+ ret = FALSE;
+ task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
+ RSPAMD_TASK_FLAG_LEARN_SPAM |
+ RSPAMD_TASK_FLAG_LEARN_CLASS);
+ /* Clear the autolearn class from mempool */
+ rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL);
}
+ else if (mres != NULL) {
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "ham") == 0) {
+ msg_info_task("<%s>: autolearn ham for classifier "
+ "'%s' as message's "
+ "score is negative: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else if (strcmp(autolearn_class, "spam") == 0) {
+ msg_info_task("<%s>: autolearn spam for classifier "
+ "'%s' as message's "
+ "action is reject, score: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else {
+ msg_info_task("<%s>: autolearn class '%s' for classifier "
+ "'%s', score: %.2f",
+ MESSAGE_FIELD(task, message_id), autolearn_class,
+ cl->cfg->name, mres->score);
+ }
+ }
- task->classifier = cl->cfg->name;
- break;
+ task->classifier = cl->cfg->name;
+ break;
+ }
}
}
}
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua
index 44ff9dafa..0d78f2272 100644
--- a/src/plugins/lua/bayes_expiry.lua
+++ b/src/plugins/lua/bayes_expiry.lua
@@ -41,32 +41,38 @@ local template = {}
local function check_redis_classifier(cls, cfg)
-- Skip old classifiers
if cls.new_schema then
- local symbol_spam, symbol_ham
+ local class_symbols = {}
+ local class_labels = {}
local expiry = (cls.expiry or cls.expire)
if type(expiry) == 'table' then
expiry = expiry[1]
end
- -- Load symbols from statfiles
+ -- Extract class_labels mapping from classifier config
+ if cls.class_labels then
+ class_labels = cls.class_labels
+ end
+ -- Load symbols from statfiles for multi-class support
local function check_statfile_table(tbl, def_sym)
local symbol = tbl.symbol or def_sym
-
- local spam
- if tbl.spam then
- spam = tbl.spam
- else
- if string.match(symbol:upper(), 'SPAM') then
- spam = true
+ local class_name = tbl.class
+
+ -- Handle legacy spam/ham detection for backward compatibility
+ if not class_name then
+ if tbl.spam ~= nil then
+ class_name = tbl.spam and 'spam' or 'ham'
+ elseif string.match(tostring(symbol):upper(), 'SPAM') then
+ class_name = 'spam'
+ elseif string.match(tostring(symbol):upper(), 'HAM') then
+ class_name = 'ham'
else
- spam = false
+ class_name = def_sym
end
end
- if spam then
- symbol_spam = symbol
- else
- symbol_ham = symbol
+ if class_name then
+ class_symbols[class_name] = symbol
end
end
@@ -87,10 +93,9 @@ local function check_redis_classifier(cls, cfg)
end
end
- if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then
+ if next(class_symbols) == nil or type(expiry) ~= 'number' then
logger.debugm(N, rspamd_config,
- 'disable expiry for classifier %s: no expiry %s',
- symbol_spam, cls)
+ 'disable expiry for classifier: no class symbols or expiry configured')
return
end
-- Now try to load redis_params if needed
@@ -108,17 +113,16 @@ local function check_redis_classifier(cls, cfg)
end
if redis_params['read_only'] then
- logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration',
- symbol_spam)
+ logger.infox(rspamd_config, 'disable expiry for classifier: read only redis configuration')
return
end
- logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry",
- symbol_spam, symbol_ham, expiry)
+ logger.debugm(N, rspamd_config, "enabled expiry for classes %s -> %s expiry",
+ table.concat(lutil.keys(class_symbols), ', '), expiry)
table.insert(settings.classifiers, {
- symbol_spam = symbol_spam,
- symbol_ham = symbol_ham,
+ class_symbols = class_symbols,
+ class_labels = class_labels,
redis_params = redis_params,
expiry = expiry
})
@@ -249,12 +253,11 @@ local expiry_script = [[
local keys = ret[2]
local tokens = {}
- -- Tokens occurrences distribution counters
+ -- Dynamic occurrence tracking for all classes
local occur = {
- ham = {},
- spam = {},
total = {}
}
+ local classes_found = {}
-- Expiry step statistics counters
local nelts, extended, discriminated, sum, sum_squares, common, significant,
@@ -264,24 +267,44 @@ local expiry_script = [[
for _,key in ipairs(keys) do
local t = redis.call('TYPE', key)["ok"]
if t == 'hash' then
- local values = redis.call('HMGET', key, 'H', 'S')
- local ham = tonumber(values[1]) or 0
- local spam = tonumber(values[2]) or 0
+ -- Get all hash fields to support multi-class
+ local hash_data = redis.call('HGETALL', key)
+ local class_counts = {}
+ local total = 0
local ttl = redis.call('TTL', key)
+
+ -- Parse hash data into class counts
+ for i = 1, #hash_data, 2 do
+ local class_label = hash_data[i]
+ local count = tonumber(hash_data[i + 1]) or 0
+ class_counts[class_label] = count
+ total = total + count
+
+ -- Track classes we've seen
+ if not classes_found[class_label] then
+ classes_found[class_label] = true
+ occur[class_label] = {}
+ end
+ end
+
tokens[key] = {
- ham,
- spam,
- ttl
+ class_counts = class_counts,
+ total = total,
+ ttl = ttl
}
- local total = spam + ham
+
sum = sum + total
sum_squares = sum_squares + total * total
nelts = nelts + 1
- for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do
- if tonumber(v) > 19 then v = 20 end
- occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1
+ -- Update occurrence counters for all classes and total
+ for class_label, count in pairs(class_counts) do
+ local bucket = count > 19 and 20 or count
+ occur[class_label][bucket] = (occur[class_label][bucket] or 0) + 1
end
+
+ local total_bucket = total > 19 and 20 or total
+ occur.total[total_bucket] = (occur.total[total_bucket] or 0) + 1
end
end
@@ -293,9 +316,10 @@ local expiry_script = [[
end
for key,token in pairs(tokens) do
- local ham, spam, ttl = token[1], token[2], tonumber(token[3])
+ local class_counts = token.class_counts
+ local total = token.total
+ local ttl = tonumber(token.ttl)
local threshold = mean
- local total = spam + ham
local function set_ttl()
if expire < 0 then
@@ -310,14 +334,39 @@ local expiry_script = [[
return 0
end
- if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
+ -- Check if token is common (balanced across classes)
+ local is_common = false
+ if total == 0 then
+ is_common = true
+ else
+ -- For multi-class, check if any class dominates significantly
+ local max_count = 0
+ for _, count in pairs(class_counts) do
+ if count > max_count then
+ max_count = count
+ end
+ end
+ -- Token is common if no class has more than (1 - epsilon) of total
+ is_common = (max_count / total) <= (1 - ${epsilon_common})
+ end
+
+ if is_common then
common = common + 1
if ttl > ${common_ttl} then
discriminated = discriminated + 1
redis.call('EXPIRE', key, ${common_ttl})
end
elseif total >= threshold and total > 0 then
- if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
+ -- Check if any class is significant
+ local is_significant = false
+ for _, count in pairs(class_counts) do
+ if count / total > ${significant_factor} then
+ is_significant = true
+ break
+ end
+ end
+
+ if is_significant then
significant = significant + 1
if ttl ~= -1 then
redis.call('PERSIST', key)
@@ -361,33 +410,50 @@ local expiry_script = [[
redis.call('DEL', lock_key)
local occ_distr = {}
- for _,cl in pairs({'ham', 'spam', 'total'}) do
+
+ -- Process all classes found plus total
+ local all_classes = {'total'}
+ for class_label in pairs(classes_found) do
+ table.insert(all_classes, class_label)
+ end
+
+ for _, cl in ipairs(all_classes) do
local occur_key = pattern_sha1 .. '_occurrence_' .. cl
if cursor ~= 0 then
- local n
- for i,v in ipairs(redis.call('HGETALL', occur_key)) do
- if i % 2 == 1 then
- n = tonumber(v)
- else
- occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v
+ local existing_data = redis.call('HGETALL', occur_key)
+ if #existing_data > 0 then
+ for i = 1, #existing_data, 2 do
+ local bucket = tonumber(existing_data[i])
+ local count = tonumber(existing_data[i + 1])
+ if occur[cl] and occur[cl][bucket] then
+ occur[cl][bucket] = occur[cl][bucket] + count
+ elseif occur[cl] then
+ occur[cl][bucket] = count
+ end
end
end
- local str = ''
- if occur[cl][0] ~= nil then
- str = '0:' .. occur[cl][0] .. ','
- end
- for k,v in ipairs(occur[cl]) do
- if k == 20 then k = '>19' end
- str = str .. k .. ':' .. v .. ','
+ if occur[cl] and next(occur[cl]) then
+ local str = ''
+ if occur[cl][0] then
+ str = '0:' .. occur[cl][0] .. ','
+ end
+ for k = 1, 20 do
+ if occur[cl][k] then
+ local label = k == 20 and '>19' or tostring(k)
+ str = str .. label .. ':' .. occur[cl][k] .. ','
+ end
+ end
+ table.insert(occ_distr, cl .. '=' .. str)
+ else
+ table.insert(occ_distr, cl .. '=no_data')
end
- table.insert(occ_distr, str)
else
redis.call('DEL', occur_key)
end
- if next(occur[cl]) ~= nil then
+ if occur[cl] and next(occur[cl]) then
redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl])))
end
end
@@ -446,8 +512,8 @@ local function expire_step(cls, ev_base, worker)
'%s infrequent (%s %s), %s mean, %s std',
lutil.unpack(d))
if cycle then
- for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do
- logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i])
+ for _, distr_info in ipairs(occ_distr) do
+ logger.infox(rspamd_config, 'tokens occurrences: {%s}', distr_info)
end
end
end
diff --git a/test/functional/cases/110_statistics/300-multiclass-redis.robot b/test/functional/cases/110_statistics/300-multiclass-redis.robot
new file mode 100644
index 000000000..278f7e0a0
--- /dev/null
+++ b/test/functional/cases/110_statistics/300-multiclass-redis.robot
@@ -0,0 +1,42 @@
+*** Settings ***
+Documentation Multiclass Bayes Classification Tests with Redis Backend
+Suite Setup Rspamd Redis Setup
+Suite Teardown Rspamd Redis Teardown
+Test Setup Set Test Hash Documentation
+Resource multiclass_lib.robot
+
+*** Variables ***
+${RSPAMD_REDIS_SERVER} ${RSPAMD_REDIS_ADDR}:${RSPAMD_REDIS_PORT}
+${RSPAMD_STATS_HASH} siphash
+${CONFIG} ${RSPAMD_TESTDIR}/configs/multiclass_bayes.conf
+
+*** Test Cases ***
+Multiclass Basic Learning and Classification
+ [Documentation] Test basic multiclass learning and classification
+ [Tags] multiclass basic learning
+ Multiclass Basic Learn Test
+
+Multiclass Legacy Compatibility
+ [Documentation] Test that old learn_spam/learn_ham commands still work
+ [Tags] multiclass compatibility legacy
+ Multiclass Legacy Compatibility Test
+
+Multiclass Relearn
+ [Documentation] Test reclassifying messages to different classes
+ [Tags] multiclass relearn
+ Multiclass Relearn Test
+
+Multiclass Cross-Class Learning
+ [Documentation] Test learning message as different class than expected
+ [Tags] multiclass cross-learn
+ Multiclass Cross-Learn Test
+
+Multiclass Unlearn
+ [Documentation] Test unlearning (learning message as different class)
+ [Tags] multiclass unlearn
+ Multiclass Unlearn Test
+
+Multiclass Statistics
+ [Documentation] Test that statistics show all class information
+ [Tags] multiclass statistics
+ Multiclass Stats Test \ No newline at end of file
diff --git a/test/functional/cases/110_statistics/320-multiclass-peruser.robot b/test/functional/cases/110_statistics/320-multiclass-peruser.robot
new file mode 100644
index 000000000..e8ca34616
--- /dev/null
+++ b/test/functional/cases/110_statistics/320-multiclass-peruser.robot
@@ -0,0 +1,31 @@
+*** Settings ***
+Suite Setup Rspamd Redis Setup
+Suite Teardown Rspamd Redis Teardown
+Test Setup Set Test Hash Documentation
+Resource multiclass_lib.robot
+
+*** Variables ***
+${CONFIG} ${RSPAMD_TESTDIR}/configs/multiclass_bayes.conf
+${REDIS_SCOPE} Suite
+${RSPAMD_REDIS_SERVER} ${RSPAMD_REDIS_ADDR}:${RSPAMD_REDIS_PORT}
+${RSPAMD_SCOPE} Suite
+${RSPAMD_STATS_BACKEND} redis
+${RSPAMD_STATS_HASH} null
+${RSPAMD_STATS_KEY} null
+${RSPAMD_STATS_PER_USER} true
+
+*** Test Cases ***
+Multiclass Per-User Basic Learn Test
+ Multiclass Basic Learn Test test@example.com
+
+Multiclass Per-User Legacy Compatibility Test
+ Multiclass Legacy Compatibility Test test@example.com
+
+Multiclass Per-User Relearn Test
+ Multiclass Relearn Test test@example.com
+
+Multiclass Per-User Cross-Learn Test
+ Multiclass Cross-Learn Test test@example.com
+
+Multiclass Per-User Unlearn Test
+ Multiclass Unlearn Test test@example.com \ No newline at end of file
diff --git a/test/functional/cases/110_statistics/multiclass_lib.robot b/test/functional/cases/110_statistics/multiclass_lib.robot
new file mode 100644
index 000000000..9f70e05fb
--- /dev/null
+++ b/test/functional/cases/110_statistics/multiclass_lib.robot
@@ -0,0 +1,169 @@
+*** Settings ***
+Resource    lib.robot
+Library     OperatingSystem
+
+*** Variables ***
+# Defaults for suites that import this resource; a suite file may override
+# any of them in its own Variables section.
+${CONFIG}                   ${RSPAMD_TESTDIR}/configs/multiclass_bayes.conf
+# Message fixtures, one per trained class
+${MESSAGE_HAM}              ${RSPAMD_TESTDIR}/messages/ham.eml
+${MESSAGE_SPAM}             ${RSPAMD_TESTDIR}/messages/spam_message.eml
+${MESSAGE_NEWSLETTER}       ${RSPAMD_TESTDIR}/messages/newsletter.eml
+${REDIS_SCOPE}              Suite
+${RSPAMD_REDIS_SERVER}      null
+${RSPAMD_SCOPE}             Suite
+${RSPAMD_STATS_BACKEND}     redis
+${RSPAMD_STATS_HASH}        null
+${RSPAMD_STATS_KEY}         null
+# Empty by default, i.e. per-user statistics are off unless a suite sets it
+${RSPAMD_STATS_PER_USER}    ${EMPTY}
+
+*** Keywords ***
+Learn Multiclass
+    [Documentation]    Learn ``message`` as ``class`` with the rspamc command
+    ...    ``learn_class:<class>`` sent to the controller, then verify the
+    ...    result with ``Check Rspamc``. A non-empty ``user`` is passed to
+    ...    rspamc via ``-d``; an empty one omits the option.
+    [Arguments]    ${user}    ${class}    ${message}
+    # The former Split Path call was removed: its results were never used.
+    IF    "${user}"
+        ${result} =    Run Rspamc    -d    ${user}    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    learn_class:${class}    ${message}
+    ELSE
+        ${result} =    Run Rspamc    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    learn_class:${class}    ${message}
+    END
+    Check Rspamc    ${result}
+
+Learn Multiclass Legacy
+    [Documentation]    Backward-compatibility variant of the learn step: uses
+    ...    the old ``learn_<class>`` rspamc commands (``learn_spam`` /
+    ...    ``learn_ham``) and verifies the result with ``Check Rspamc``.
+    ...    A non-empty ``user`` is passed to rspamc via ``-d``.
+    [Arguments]    ${user}    ${class}    ${message}
+    # The former Split Path call was removed: its results were never used.
+    IF    "${user}"
+        ${result} =    Run Rspamc    -d    ${user}    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    learn_${class}    ${message}
+    ELSE
+        ${result} =    Run Rspamc    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    learn_${class}    ${message}
+    END
+    Check Rspamc    ${result}
+
+Multiclass Basic Learn Test
+    [Documentation]    Learn one message per class (spam, ham, newsletter),
+    ...    then scan each message and expect the symbol of its own class.
+    ...    ``RSPAMD_STATS_LEARNTEST`` is 0 while running and 1 once all
+    ...    steps have passed.
+    [Arguments]    ${user}=${EMPTY}
+    Set Suite Variable    ${RSPAMD_STATS_LEARNTEST}    0
+    Set Test Variable    ${kwargs}    &{EMPTY}
+    IF    "${user}"
+        Set To Dictionary    ${kwargs}    Deliver-To=${user}
+    END
+
+    # Learn all classes
+    FOR    ${class}    ${message}    IN
+    ...    spam          ${MESSAGE_SPAM}
+    ...    ham           ${MESSAGE_HAM}
+    ...    newsletter    ${MESSAGE_NEWSLETTER}
+        Learn Multiclass    ${user}    ${class}    ${message}
+    END
+
+    # Test classification
+    FOR    ${message}    ${symbol}    IN
+    ...    ${MESSAGE_SPAM}          BAYES_SPAM
+    ...    ${MESSAGE_HAM}           BAYES_HAM
+    ...    ${MESSAGE_NEWSLETTER}    BAYES_NEWSLETTER
+        Scan File    ${message}    &{kwargs}
+        Expect Symbol    ${symbol}
+    END
+
+    Set Suite Variable    ${RSPAMD_STATS_LEARNTEST}    1
+
+Multiclass Legacy Compatibility Test
+    [Documentation]    Learn spam and ham through the legacy learn commands
+    ...    and expect each message to get the symbol of its own class.
+    [Arguments]    ${user}=${EMPTY}
+    Set Test Variable    ${kwargs}    &{EMPTY}
+    IF    "${user}"
+        Set To Dictionary    ${kwargs}    Deliver-To=${user}
+    END
+
+    # Legacy learn_spam and learn_ham commands must still work
+    FOR    ${class}    ${message}    IN
+    ...    spam    ${MESSAGE_SPAM}
+    ...    ham     ${MESSAGE_HAM}
+        Learn Multiclass Legacy    ${user}    ${class}    ${message}
+    END
+
+    # Classification must still be correct
+    FOR    ${message}    ${symbol}    IN
+    ...    ${MESSAGE_SPAM}    BAYES_SPAM
+    ...    ${MESSAGE_HAM}     BAYES_HAM
+        Scan File    ${message}    &{kwargs}
+        Expect Symbol    ${symbol}
+    END
+
+Multiclass Relearn Test
+    [Documentation]    Relearn the spam message as ham. Passes immediately if
+    ...    the message then gets BAYES_HAM; otherwise it must at least not
+    ...    get BAYES_SPAM. Fails up front when ``RSPAMD_STATS_LEARNTEST`` is 0.
+    [Arguments]    ${user}=${EMPTY}
+    IF    ${RSPAMD_STATS_LEARNTEST} == 0
+        Fail    "Learn test was not run"
+    END
+
+    Set Test Variable    ${kwargs}    &{EMPTY}
+    IF    "${user}"
+        Set To Dictionary    ${kwargs}    Deliver-To=${user}
+    END
+
+    # Relearn spam message as ham
+    Learn Multiclass    ${user}    ham    ${MESSAGE_SPAM}
+
+    # Best case: now ham. Minimum requirement: no longer spam.
+    Scan File    ${MESSAGE_SPAM}    &{kwargs}
+    ${pass} =    Run Keyword And Return Status    Expect Symbol    BAYES_HAM
+    Pass Execution If    ${pass}    Successfully reclassified spam as ham
+    Do Not Expect Symbol    BAYES_SPAM
+
+Multiclass Cross-Learn Test
+    [Documentation]    Learn the newsletter message as ham and expect it to
+    ...    be classified as ham rather than newsletter afterwards.
+    [Arguments]    ${user}=${EMPTY}
+    Set Test Variable    ${kwargs}    &{EMPTY}
+    IF    "${user}"
+        Set To Dictionary    ${kwargs}    Deliver-To=${user}
+    END
+
+    # Cross-class learning: newsletter fixture, ham class
+    Learn Multiclass    ${user}    ham    ${MESSAGE_NEWSLETTER}
+
+    # The class it was trained as must win over its original class
+    Scan File    ${MESSAGE_NEWSLETTER}    &{kwargs}
+    Expect Symbol    BAYES_HAM
+    Do Not Expect Symbol    BAYES_NEWSLETTER
+
+Multiclass Unlearn Test
+    [Documentation]    Learn the spam message as spam and expect BAYES_SPAM,
+    ...    then learn the same message as ham and expect BAYES_SPAM to be gone.
+    [Arguments]    ${user}=${EMPTY}
+    Set Test Variable    ${kwargs}    &{EMPTY}
+    IF    "${user}"
+        Set To Dictionary    ${kwargs}    Deliver-To=${user}
+    END
+
+    # Step 1: learn as spam and confirm
+    Learn Multiclass    ${user}    spam    ${MESSAGE_SPAM}
+    Scan File    ${MESSAGE_SPAM}    &{kwargs}
+    Expect Symbol    BAYES_SPAM
+
+    # Step 2: "unlearn" by learning the same message as ham
+    Learn Multiclass    ${user}    ham    ${MESSAGE_SPAM}
+
+    # Step 3: the spam symbol must no longer be present
+    Scan File    ${MESSAGE_SPAM}    &{kwargs}
+    Do Not Expect Symbol    BAYES_SPAM
+
+Check Multiclass Results
+    [Documentation]    Assert that ``result.stdout`` mentions the symbol
+    ...    ``BAYES_<EXPECTED_CLASS>`` and that, later on the same line, the
+    ...    class name appears in square brackets.
+    [Arguments]    ${result}    ${expected_class}
+    ${symbol} =    Set Variable    BAYES_${expected_class.upper()}
+    # Symbol must be present at all
+    Should Contain    ${result.stdout}    ${symbol}
+    # Multiclass result format: symbol followed by [class_name]
+    Should Match Regexp    ${result.stdout}    ${symbol}.*\\[${expected_class}\\]
+
+Multiclass Stats Test
+    [Documentation]    Run ``rspamc stat`` against the controller and expect
+    ...    exit code 0 and the symbols of all three classes in the output.
+    ${result} =    Run Rspamc    -h    ${RSPAMD_LOCAL_ADDR}:${RSPAMD_PORT_CONTROLLER}    stat
+    # Don't use Check Rspamc for stat command as it expects JSON success format
+    Should Be Equal As Integers    ${result.rc}    0
+
+    # Statistics must be reported for every class
+    FOR    ${symbol}    IN    BAYES_SPAM    BAYES_HAM    BAYES_NEWSLETTER
+        Should Contain    ${result.stdout}    ${symbol}
+    END
+
+Multiclass Configuration Migration Test
+    [Documentation]    Run ``rspamc stat`` with the old binary classifier
+    ...    config and expect it to succeed while stderr mentions a
+    ...    deprecation (matched case-insensitively).
+    Set Test Variable    ${binary_config}    ${RSPAMD_TESTDIR}/configs/stats.conf
+
+    # Start with binary config
+    ${result} =    Run Rspamc    --config    ${binary_config}    stat
+    # Only the exit code is checked: Check Rspamc expects the JSON success
+    # format, which the stat command does not print.
+    Should Be Equal As Integers    ${result.rc}    0
+
+    # Should show deprecation warning but work
+    Should Contain    ${result.stderr}    deprecated    ignore_case=True
+
diff --git a/test/functional/configs/multiclass_bayes.conf b/test/functional/configs/multiclass_bayes.conf
new file mode 100644
index 000000000..278aeeee9
--- /dev/null
+++ b/test/functional/configs/multiclass_bayes.conf
@@ -0,0 +1,129 @@
+# Global options for the multiclass Bayes functional tests.
+# {= env.NAME =} placeholders are filled in from the test environment when
+# the config template is rendered.
+options = {
+ filters = ["spf", "dkim", "regexp"]
+ url_tld = "{= env.TESTDIR =}/../lua/unit/test_tld.dat"
+ pidfile = "{= env.TMPDIR =}/rspamd.pid"
+ dns {
+ retransmits = 10;
+ timeout = 2s;
+ # Canned DNS answer: example.net publishes an SPF policy of "-all"
+ fake_records = [{
+ name = "example.net";
+ type = txt;
+ replies = ["v=spf1 -all"];
+ }]
+ }
+}
+
+# Debug-level logging into a file under the per-run temporary directory
+logging = {
+ type = "file",
+ level = "debug"
+ filename = "{= env.TMPDIR =}/rspamd.log"
+}
+
+# Default metric. NOTE(review): the reject score of 100500 looks deliberately
+# unreachable so that test messages are never rejected - confirm.
+metric = {
+ name = "default",
+ actions = {
+ reject = 100500,
+ }
+ unknown_weight = 1
+}
+
+# Scanner ("normal") worker: single process bound to LOCAL_ADDR:PORT_NORMAL,
+# with a keypair taken from the environment and a 60s task timeout.
+worker {
+ type = normal
+ bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_NORMAL =}"
+ count = 1
+ keypair {
+ pubkey = "{= env.KEY_PUB1 =}";
+ privkey = "{= env.KEY_PVT1 =}";
+ }
+ task_timeout = 60s;
+}
+
+# Controller worker bound to LOCAL_ADDR:PORT_CONTROLLER; this is the address
+# the multiclass test keywords pass to rspamc via -h for learn and stat.
+worker {
+ type = controller
+ bind_socket = "{= env.LOCAL_ADDR =}:{= env.PORT_CONTROLLER =}"
+ count = 1
+ keypair {
+ pubkey = "{= env.KEY_PUB1 =}";
+ privkey = "{= env.KEY_PVT1 =}";
+ }
+ # Loopback addresses (IPv4 and IPv6) are listed as secure
+ secure_ip = ["127.0.0.1", "::1"];
+ stats_path = "{= env.TMPDIR =}/stats.ucl";
+}
+
+# Multi-class Bayes classifier configuration
+# Three classes (spam, ham, newsletter), each with its own statfile and symbol.
+classifier {
+ languages_enabled = true;
+ tokenizer {
+ name = "osb";
+ # hash/key are substituted unquoted, so a value of `null` in the
+ # environment renders as a literal null here
+ hash = {= env.STATS_HASH =};
+ key = {= env.STATS_KEY =};
+ }
+ backend = "{= env.STATS_BACKEND =}";
+
+ # Multi-class statfiles
+ # One statfile per class; `symbol` is the name the tests expect in scan
+ # results (BAYES_SPAM / BAYES_HAM / BAYES_NEWSLETTER).
+ statfile {
+ class = "spam";
+ symbol = BAYES_SPAM;
+ server = {= env.REDIS_SERVER =}
+ }
+ statfile {
+ class = "ham";
+ symbol = BAYES_HAM;
+ server = {= env.REDIS_SERVER =}
+ }
+ statfile {
+ class = "newsletter";
+ symbol = BAYES_NEWSLETTER;
+ server = {= env.REDIS_SERVER =}
+ }
+
+ # Backend class labels for Redis
+ # Maps each class name to a one-letter label.
+ class_labels = {
+ "spam" = "S";
+ "ham" = "H";
+ "newsletter" = "N";
+ }
+
+ # Learn cache uses the same Redis server as the statfiles
+ cache {
+ server = {= env.REDIS_SERVER =}
+ }
+
+ # Multi-class autolearn configuration
+ # Per-class thresholds; spam and ham also carry a verdict_mapping, while
+ # newsletter lists symbols instead.
+ # NOTE(review): exact semantics of threshold / verdict_mapping / symbols
+ # are defined by the classifier code, not visible here - confirm.
+ autolearn = {
+ classes = {
+ spam = {
+ threshold = 15.0;
+ verdict_mapping = { spam = true };
+ };
+ ham = {
+ threshold = -5.0;
+ verdict_mapping = { ham = true };
+ };
+ newsletter = {
+ symbols = ["NEWSLETTER_HEADER", "BULK_MAIL", "UNSUBSCRIBE_LINK"];
+ threshold = 8.0;
+ };
+ };
+
+ check_balance = true;
+ max_class_ratio = 0.6;
+ skip_threshold = 0.95;
+ }
+
+ # Minimums are set to 1 - presumably so that a single learned message
+ # per class is enough for the tests to classify; TODO confirm
+ min_learns = 1;
+ min_tokens = 1;
+ min_token_hits = 1;
+ min_prob_strength = 0.05;
+
+ # Per-user statistics: the block below is emitted only when STATS_PER_USER
+ # is non-empty; the Lua function returns the task's principal recipient.
+ {% if env.STATS_PER_USER ~= '' %}
+ per_user = <<EOD
+return function(task)
+ return task:get_principal_recipient()
+end
+EOD;
+ {% endif %}
+}
+
+# Extra Lua file loaded from the test directory (test_coverage.lua)
+lua = "{= env.TESTDIR =}/lua/test_coverage.lua";
+
+# Empty settings section
+settings {}
diff --git a/test/functional/lib/rspamd.robot b/test/functional/lib/rspamd.robot
index 5d23e3ceb..f61998f46 100644
--- a/test/functional/lib/rspamd.robot
+++ b/test/functional/lib/rspamd.robot
@@ -419,10 +419,23 @@ Run Nginx
${nginx_log} = Get File ${RSPAMD_TMPDIR}/nginx.log
Log ${nginx_log}
+Set Test Hash Documentation
+    [Documentation]    Log a short tag for the running test: the first 8 hex
+    ...    characters of the MD5 of the test name, followed by the name itself.
+    ...    Despite its name, this keyword only logs; it does not modify the
+    ...    test documentation.
+    # $TEST_NAME hands the value to Python as an object instead of pasting it
+    # into a quoted literal, so names containing quotes cannot break the expression.
+    ${log_tag} =    Evaluate    __import__('hashlib').md5($TEST_NAME.encode()).hexdigest()[:8]
+    Log    TEST CONTEXT: [${log_tag}] ${TEST NAME}    console=True
+
+
Run Rspamc
[Arguments] @{args}
- ${result} = Run Process ${RSPAMC} -t 60 --header Queue-ID\=${TEST NAME}
- ... @{args} env:LD_LIBRARY_PATH=${RSPAMD_TESTDIR}/../../contrib/aho-corasick
+    # Runs rspamc with a 60 second timeout and a per-test log tag (first 8 hex
+    # chars of the MD5 of the test name). The test name is used as the queue id
+    # unless the caller already supplies --queue-id.
+    # Variables are referenced as $name inside Evaluate so their values are
+    # never pasted into quoted Python literals (safe for embedded quotes).
+    ${log_tag} =    Evaluate    __import__('hashlib').md5($TEST_NAME.encode()).hexdigest()[:8]
+    # Check if --queue-id is already provided in the arguments
+    ${has_queue_id} =    Evaluate    '--queue-id' in ' '.join(map(str, $args))
+    IF    ${has_queue_id}
+        ${result} =    Run Process    ${RSPAMC}    -t    60    --log-tag    ${log_tag}
+        ...    @{args}    env:LD_LIBRARY_PATH=${RSPAMD_TESTDIR}/../../contrib/aho-corasick
+    ELSE
+        ${result} =    Run Process    ${RSPAMC}    -t    60    --queue-id    ${TEST NAME}    --log-tag    ${log_tag}
+        ...    @{args}    env:LD_LIBRARY_PATH=${RSPAMD_TESTDIR}/../../contrib/aho-corasick
+    END
Log ${result.stdout}
[Return] ${result}
diff --git a/test/functional/messages/newsletter.eml b/test/functional/messages/newsletter.eml
new file mode 100644
index 000000000..93c996956
--- /dev/null
+++ b/test/functional/messages/newsletter.eml
@@ -0,0 +1,50 @@
+From: "Marketing Team" <newsletter@example.com>
+To: user@example.org
+Subject: 🎉 Monthly Newsletter - Exclusive Deals & Product Updates!
+Date: Fri, 21 Jul 2023 10:00:00 +0000
+Message-ID: <newsletter-123@example.com>
+MIME-Version: 1.0
+Content-Type: text/html; charset=utf-8
+List-Unsubscribe: <https://example.com/unsubscribe?id=123>
+Precedence: bulk
+X-Mailer: MailChimp/Pro 12.345
+
+<!DOCTYPE html>
+<html>
+<head>
+ <meta charset="utf-8">
+ <title>Monthly Newsletter</title>
+</head>
+<body>
+ <h1>🎉 Exclusive Monthly Offers!</h1>
+
+ <p>Dear Valued Subscriber,</p>
+
+ <p>This month we're excited to bring you our <strong>BIGGEST SALE</strong> of the year!</p>
+
+ <h2>🔥 Hot Deals This Month:</h2>
+ <ul>
+ <li>50% OFF all premium products</li>
+ <li>FREE shipping on orders over $50</li>
+ <li>Buy 2 Get 1 FREE on selected items</li>
+ </ul>
+
+ <p><a href="https://example.com/shop?utm_source=newsletter&utm_campaign=monthly">SHOP NOW</a></p>
+
+ <h2>📱 New Product Launch</h2>
+ <p>Check out our revolutionary new gadget that everyone is talking about!</p>
+
+ <h2>🎁 Refer a Friend</h2>
+ <p>Share this newsletter and both you and your friend get $10 credit!</p>
+
+ <hr>
+
+ <p><small>
+ You're receiving this because you subscribed to our newsletter.<br>
+ <a href="https://example.com/unsubscribe?id=123">Unsubscribe here</a> |
+ <a href="https://example.com/preferences">Update preferences</a><br>
+ Marketing Team, Example Corp<br>
+ 123 Business St, City, State 12345
+ </small></p>
+</body>
+</html> \ No newline at end of file
diff --git a/test/functional/messages/transactional.eml b/test/functional/messages/transactional.eml
new file mode 100644
index 000000000..e227aaa77
--- /dev/null
+++ b/test/functional/messages/transactional.eml
@@ -0,0 +1,18 @@
+From: noreply@example.com
+To: user@example.org
+Subject: Password Reset Request
+Date: Fri, 21 Jul 2023 11:00:00 +0000
+Message-ID: <pwd-reset-456@example.com>
+MIME-Version: 1.0
+Content-Type: text/plain
+
+Hello,
+
+You have requested a password reset for your account.
+
+Click here to reset your password: https://example.com/reset?token=abc123
+
+This link expires in 24 hours.
+
+Best regards,
+Security Team \ No newline at end of file