diff options
Diffstat (limited to 'src/libstat')
-rw-r--r-- | src/libstat/CMakeLists.txt | 33 | ||||
-rw-r--r-- | src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md | 451 | ||||
-rw-r--r-- | src/libstat/backends/cdb_backend.cxx | 13 | ||||
-rw-r--r-- | src/libstat/backends/mmaped_file.c | 10 | ||||
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 552 | ||||
-rw-r--r-- | src/libstat/backends/sqlite3_backend.c | 7 | ||||
-rw-r--r-- | src/libstat/classifiers/bayes.c | 652 | ||||
-rw-r--r-- | src/libstat/classifiers/classifiers.h | 14 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.cxx | 84 | ||||
-rw-r--r-- | src/libstat/stat_api.h | 81 | ||||
-rw-r--r-- | src/libstat/stat_config.c | 11 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 654 | ||||
-rw-r--r-- | src/libstat/tokenizers/custom_tokenizer.h | 177 | ||||
-rw-r--r-- | src/libstat/tokenizers/osb.c | 9 | ||||
-rw-r--r-- | src/libstat/tokenizers/rspamd_tokenizer_types.h | 89 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizer_manager.c | 500 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.c | 202 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.h | 33 |
18 files changed, 3179 insertions, 393 deletions
diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt index 64d572a57..eddf64e49 100644 --- a/src/libstat/CMakeLists.txt +++ b/src/libstat/CMakeLists.txt @@ -1,25 +1,26 @@ # Librspamdserver -SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c - ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) +SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c + ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) -SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c - ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) +SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizer_manager.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) -SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c - ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) +SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c + ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) -SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) +SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) -SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c +SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) SET(RSPAMD_STAT ${LIBSTATSRC} - ${TOKENIZERSSRC} - ${CLASSIFIERSSRC} - ${BACKENDSSRC} - ${CACHESSRC} PARENT_SCOPE) + ${TOKENIZERSSRC} + ${CLASSIFIERSSRC} + ${BACKENDSSRC} + ${CACHESSRC} PARENT_SCOPE) diff --git a/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md b/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md new file mode 100644 index 000000000..dc8352374 --- /dev/null +++ b/src/libstat/MULTICLASS_BAYES_ARCHITECTURE.md @@ -0,0 +1,451 @@ +# Rspamd Multiclass Bayes Architecture + +## Overview + +This document describes the complete data flow for the multiclass Bayes classification system in Rspamd, covering the interaction between C++ core, Lua scripts, Redis backend, and the classification pipeline. + +## High-Level Data Flow + +``` +[Task Processing] → [Tokenization] → [Redis Backend] → [Lua Scripts] → [Redis Scripts] → [Results] → [Classification] +``` + +## 1. Classification Pipeline Entry Point + +### 1.1 Task Processing Start + +```c +// src/libstat/stat_process.c +rspamd_stat_classify(struct rspamd_task *task, struct rspamd_config *cfg) +``` + +**Flow:** + +1. Task arrives for classification +2. Iterates through configured classifiers +3. For each classifier, calls `rspamd_stat_classifiers[i].classify_func()` +4. For Bayes: calls `bayes_classify_multiclass()` + +### 1.2 Bayes Classification Entry + +```c +// src/libstat/classifiers/bayes.c +gboolean bayes_classify_multiclass(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task) +``` + +**Key Steps:** + +1. Validates `ctx->cfg->class_names` array +2. Sets up `bayes_task_closure` with class information +3. **Calls Redis backend to fetch token data** +4. Processes returned token values +5. Calculates probabilities and inserts symbols + +## 2. Redis Backend Data Flow + +### 2.1 Backend Runtime Creation + +```cpp +// src/libstat/backends/redis_backend.cxx +gpointer rspamd_redis_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, gpointer c, int _id) +``` + +**Runtime Structure:** + +```cpp +template<class T> +class redis_stat_runtime { + struct redis_stat_ctx *ctx; // Redis connection context + struct rspamd_task *task; // Current task + struct rspamd_statfile_config *stcf; // Statfile configuration + const char *redis_object_expanded; // Expanded key prefix + int id; // Statfile ID (critical!) + std::optional<std::map<int, T>> results; // Token index → value mapping +}; +``` + +**Critical Insight: Statfile ID Mapping** + +- Each statfile has a unique ID (`id`) +- Token values are stored in `tok->values[id]` array +- **The `id` must match exactly between runtime and statfile** + +### 2.2 Multiple Runtime Creation (Classification Mode) + +For multiclass classification, the system creates multiple runtimes: + +```cpp +// For each statfile in classifier +for (cur = stcf->clcf->statfiles; cur; cur = g_list_next(cur)) { + auto *other_stcf = (struct rspamd_statfile_config *) cur->data; + + // Find correct statfile ID + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + int other_id = -1; + for (i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *st = g_ptr_array_index(st_ctx->statfiles, i); + if (st->stcf == other_stcf) { + other_id = st->id; // ← This is the critical mapping! + break; + } + } + + // Create runtime with correct ID + auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + other_rt->id = other_id; // ← Must be set correctly! +} +``` + +### 2.3 Token Processing Call + +```cpp +gboolean rspamd_redis_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + int id, gpointer p) +``` + +**Flow:** + +1. Serializes tokens to MessagePack format +2. Builds class labels string (e.g., "TABLE:H,S,N,T") +3. Calls Lua function to execute Redis script +4. Registers callback for async result processing + +## 3. Lua Script Layer + +### 3.1 Lua Function Entry Point + +```lua +-- lualib/lua_bayes_redis.lua +local function gen_classify_functor(redis_params, classify_script_id) + return function(task, expanded_key, id, stat_tokens, callback) + -- Executes Redis script via lua_redis + lua_redis.exec_redis_script(classify_script_id, + { task = task, is_write = false, key = expanded_key }, + classify_redis_cb, + { expanded_key, class_labels, stat_tokens }) + end +end +``` + +**Key Components:** + +- `expanded_key`: Redis key prefix (e.g., "BAYES{user@domain}") +- `class_labels`: "TABLE:H,S,N,T" format for multiclass +- `stat_tokens`: MessagePack-encoded token array +- `callback`: Function to handle Redis script results + +### 3.2 Class Labels Format + +**Critical Detail**: The class labels format determines Redis script behavior: + +```lua +-- Binary mode (legacy) +class_labels = "H" -- Single class + +-- Multiclass mode +class_labels = "TABLE:H,S,N,T" -- Multiple classes with TABLE: prefix +``` + +## 4. Redis Script Execution + +### 4.1 Script Structure + +```lua +-- lualib/redis_scripts/bayes_classify.lua +local prefix = KEYS[1] -- "BAYES{user@domain}" +local class_labels_arg = KEYS[2] -- "TABLE:H,S,N,T" +local input_tokens = cmsgpack.unpack(KEYS[3]) -- [tok1, tok2, ...] +``` + +### 4.2 Class Label Parsing + +```lua +local class_labels = {} +if string.match(class_labels_arg, "^TABLE:") then + -- Multiclass mode + local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" + for label in string.gmatch(labels_str, "([^,]+)") do + table.insert(class_labels, label) -- ["H", "S", "N", "T"] + end +else + -- Binary mode (single label) + table.insert(class_labels, class_labels_arg) +end +``` + +### 4.3 Redis Key Structure + +**Learning Counts:** + +``` +BAYES{user@domain}_H_learns → { learns: 1500 } +BAYES{user@domain}_S_learns → { learns: 800 } +BAYES{user@domain}_N_learns → { learns: 200 } +BAYES{user@domain}_T_learns → { learns: 150 } +``` + +**Token Counts:** + +``` +BAYES{user@domain}_H_tokens → { token1: 45, token2: 12, ... } +BAYES{user@domain}_S_tokens → { token1: 23, token2: 67, ... } +BAYES{user@domain}_N_tokens → { token1: 5, token2: 8, ... } +BAYES{user@domain}_T_tokens → { token1: 2, token2: 3, ... } +``` + +### 4.4 Token Lookup Process + +```lua +-- Get learning counts for each class +local learned_counts = {} +for i, class_label in ipairs(class_labels) do + local learns_key = prefix .. "_" .. class_label .. "_learns" + learned_counts[i] = tonumber(redis.call('HGET', learns_key, 'learns') or '0') +end + +-- Batch token lookup for all classes +local pipe = redis.call('MULTI') +for i, token in ipairs(input_tokens) do + for j, class_label in ipairs(class_labels) do + local token_key = prefix .. "_" .. class_label .. "_tokens" + redis.call('HGET', token_key, token) + end +end +local token_results = redis.call('EXEC') + +-- Parse results into ordered arrays +local token_data = {} +for j, class_label in ipairs(class_labels) do + token_data[j] = {} -- token_data[class_index][token_index] = count +end + +local result_idx = 1 +for i, token in ipairs(input_tokens) do + for j, class_label in ipairs(class_labels) do + local count = tonumber(token_results[result_idx]) or 0 + if count > 0 then + table.insert(token_data[j], {i, count}) -- {token_index, count} + end + result_idx = result_idx + 1 + end +end + +-- Return: [learned_counts, token_data] +return {learned_counts, token_data} +``` + +### 4.5 Return Format + +**Redis Script Returns:** + +```lua +{ + [1] = {1500, 800, 200, 150}, -- learned_counts per class + [2] = { -- token_data per class + [1] = {{1,45}, {2,12}, ...}, -- Class H tokens: {token_idx, count} + [2] = {{1,23}, {2,67}, ...}, -- Class S tokens + [3] = {{1,5}, {2,8}, ...}, -- Class N tokens + [4] = {{1,2}, {2,3}, ...} -- Class T tokens + } +} +``` + +## 5. Result Processing in C++ + +### 5.1 Redis Callback Handler + +```cpp +// src/libstat/backends/redis_backend.cxx +static int rspamd_redis_classified(lua_State *L) +{ + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + bool result = lua_toboolean(L, 2); + + if (result && lua_istable(L, 3)) { + // Process learned_counts (table index 1) + lua_rawgeti(L, 3, 1); + if (lua_istable(L, -1)) { + // Store learned counts (implementation detail) + } + lua_pop(L, 1); + + // Process token_results (table index 2) + lua_rawgeti(L, 3, 2); + if (lua_istable(L, -1)) { + process_multiclass_token_results(L, rt, task); + } + lua_pop(L, 1); + } +} +``` + +### 5.2 Token Results Processing + +```cpp +static void process_multiclass_token_results(lua_State *L, + redis_stat_runtime<float> *rt, + struct rspamd_task *task) +{ + // L stack: token_results table at top + // Format: {[1] = {{1,45}, {2,12}}, [2] = {{1,23}, {2,67}}, ...} + + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + GList *cur = rt->stcf->clcf->statfiles; + int class_idx = 1; + + while (cur) { + auto *stcf = (struct rspamd_statfile_config *)cur->data; + + // Find correct statfile ID + int statfile_id = find_statfile_id_for_config(stcf); + + // Get or create runtime for this statfile + auto maybe_statfile_rt = get_runtime_for_statfile(task, stcf, statfile_id); + if (maybe_statfile_rt) { + auto *statfile_rt = maybe_statfile_rt.value(); + + // Get token data for this class (class_idx) + lua_rawgeti(L, -1, class_idx); + if (lua_istable(L, -1)) { + parse_class_token_data(L, statfile_rt); + } + lua_pop(L, 1); + } + + cur = g_list_next(cur); + class_idx++; + } + } +} +``` + +### 5.3 Token Value Assignment + +```cpp +bool redis_stat_runtime<T>::process_tokens(GPtrArray *tokens) const +{ + rspamd_token_t *tok; + + if (!results) { + return false; + } + + // results maps: token_index → token_count + for (auto [token_idx, token_count] : *results) { + tok = (rspamd_token_t *) g_ptr_array_index(tokens, token_idx - 1); + + // CRITICAL: Set tok->values[id] where id is the statfile ID + tok->values[id] = token_count; + } + + return true; +} +``` + +## 6. Classification Algorithm Execution + +### 6.1 Multiclass Processing + +```c +// src/libstat/classifiers/bayes.c +gboolean bayes_classify_multiclass(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task) +{ + struct bayes_task_closure cl; + + // Initialize with class information from config + cl.num_classes = ctx->cfg->class_names->len; + cl.class_names = (char**)ctx->cfg->class_names->pdata; + + // Process all tokens + for (i = 0; i < tokens->len; i++) { + rspamd_token_t *tok = g_ptr_array_index(tokens, i); + bayes_classify_token_multiclass(ctx, tok, &cl); + } +} +``` + +### 6.2 Token Classification + +```c +static void bayes_classify_token_multiclass(struct rspamd_classifier *ctx, + rspamd_token_t *tok, + struct bayes_task_closure *cl) +{ + // For each statfile, check if it has data for this token + for (i = 0; i < ctx->statfiles_ids->len; i++) { + int id = g_array_index(ctx->statfiles_ids, int, i); + struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + + // CRITICAL: tok->values[id] must be set by Redis backend + double val = tok->values[id]; + + if (val > 0) { + // Find which class this statfile belongs to + for (j = 0; j < cl->num_classes; j++) { + if (strcmp(st->stcf->class_name, cl->class_names[j]) == 0) { + // Accumulate token evidence for this class + process_token_for_class(cl, j, val, st); + break; + } + } + } + } +} +``` + +## 7. Critical Data Mapping + +### 7.1 Statfile ID Assignment + +**The Core Problem**: Ensuring correct mapping between: + +1. **Redis script class order**: `[H, S, N, T]` (array indices 1,2,3,4) +2. **Statfile IDs**: Global statfile IDs assigned by `rspamd_stat_get_ctx()` +3. **Runtime IDs**: Must match statfile IDs for `tok->values[id]` assignment + +### 7.2 Configuration to Runtime Mapping + +```c +// Configuration defines classes +statfile "BAYES_HAM" { class = "ham"; symbol = "BAYES_HAM"; } // Gets ID=0 +statfile "BAYES_SPAM" { class = "spam"; symbol = "BAYES_SPAM"; } // Gets ID=1 +statfile "BAYES_NEWS" { class = "news"; symbol = "BAYES_NEWS"; } // Gets ID=2 + +// Redis backend maps: class_name → backend_label +class_labels = { + "ham" = "H"; // Maps to Redis "H" + "spam" = "S"; // Maps to Redis "S" + "news" = "N"; // Maps to Redis "N" +} + +// Redis script processes in label order: ["H", "S", "N"] +// Returns data in same order: [ham_data, spam_data, news_data] + +// C++ must map: +// redis_result[0] → statfile_id=0 (ham) +// redis_result[1] → statfile_id=1 (spam) +// redis_result[2] → statfile_id=2 (news) +``` + +### 7.3 Token Array Structure + +```c +// For each token in message +struct rspamd_token { + uint64_t data; // Token hash + float values[MAX_STATFILES]; // Values per statfile ID + // ... +}; + +// After Redis processing: +// tok->values[0] = ham_count (from redis_result[0]) +// tok->values[1] = spam_count (from redis_result[1]) +// tok->values[2] = news_count (from redis_result[2]) +``` diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx index 0f55a725c..f6ca9c12d 100644 --- a/src/libstat/backends/cdb_backend.cxx +++ b/src/libstat/backends/cdb_backend.cxx @@ -393,7 +393,6 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, gpointer runtime) { auto *cdbp = CDB_FROM_RAW(runtime); - bool seen_values = false; for (auto i = 0u; i < tokens->len; i++) { rspamd_token_t *tok; @@ -403,21 +402,13 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, if (res) { tok->values[id] = res.value(); - seen_values = true; } else { tok->values[id] = 0; } } - if (seen_values) { - if (cdbp->is_spam()) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } - } + /* No longer need to set flags - multi-class handles missing data naturally */ return true; } @@ -488,4 +479,4 @@ void rspamd_cdb_close(gpointer ctx) { auto *cdbp = CDB_FROM_RAW(ctx); delete cdbp; -}
\ No newline at end of file +} diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c index 4430bb9a4..a6423a1e6 100644 --- a/src/libstat/backends/mmaped_file.c +++ b/src/libstat/backends/mmaped_file.c @@ -85,8 +85,7 @@ typedef struct { #define RSPAMD_STATFILE_VERSION \ { \ - '1', '2' \ - } + '1', '2'} #define BACKUP_SUFFIX ".old" static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool, @@ -958,12 +957,7 @@ rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens, tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2); } - if (mf->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ return TRUE; } diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 7137904e9..3a78de1dd 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -22,6 +22,7 @@ #include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/error.hxx" +#include <map> #include <string> #include <cstdint> @@ -121,9 +122,9 @@ public: } static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, - bool is_spam) -> std::optional<redis_stat_runtime<T> *> + const char *class_label) -> std::optional<redis_stat_runtime<T> *> { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label); auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str()); if (res) { @@ -147,9 +148,15 @@ public: rspamd_token_t *tok; if (!results) { + msg_debug_bayes("process_tokens: no results available for statfile id=%d", id); return false; } + if (results->size() > 0) { + msg_debug_bayes("processing %uz tokens for statfile id=%d, class=%s", + results->size(), id, stcf->class_name ? stcf->class_name : "unknown"); + } + for (auto [idx, val]: *results) { tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1); tok->values[id] = val; @@ -158,12 +165,14 @@ public: return true; } - auto save_in_mempool(bool is_spam) const + auto save_in_mempool(const char *class_label) const { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = + rspamd_mempool_strdup(task->task_pool, + fmt::format("{}_{}", redis_object_expanded, class_label).c_str()); /* We do not set destructor for the variable, as it should be already added on creation */ - rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); - msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str()); + rspamd_mempool_set_variable(task->task_pool, var_name, (gpointer) this, nullptr); + msg_debug_bayes("saved runtime in mempool at %s", var_name); } }; @@ -178,6 +187,26 @@ rspamd_redis_stat_quark(void) } /* + * Get the class label for a statfile (for multi-class support) + */ +static const char * +get_class_label(struct rspamd_statfile_config *stcf) +{ + /* Try to get the label from the classifier config first */ + if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) { + const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name); + if (label) { + return label; + } + /* If no label mapping found, use class name directly */ + return stcf->class_name; + } + + /* Fallback to legacy binary classification */ + return stcf->is_spam ? "S" : "H"; +} + +/* * Non-static for lua unit testing */ gsize rspamd_redis_expand_object(const char *pattern, @@ -235,6 +264,11 @@ gsize rspamd_redis_expand_object(const char *pattern, if (rcpt) { rspamd_mempool_set_variable(task->task_pool, "stat_user", (gpointer) rcpt, nullptr); + msg_debug_bayes("redis expansion: found recipient '%s'", rcpt); + } + else { + msg_debug_bayes("redis expansion: no recipient found (deliver_to=%s)", + task->deliver_to ? task->deliver_to : "null"); } } @@ -448,6 +482,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, users_enabled = ucl_object_lookup_any(classifier_obj, "per_user", "users_enabled", nullptr); + msg_debug_bayes_cfg("per-user lookup: users_enabled=%p", users_enabled); if (users_enabled != nullptr) { if (ucl_object_type(users_enabled) == UCL_BOOLEAN) { backend->enable_users = ucl_object_toboolean(users_enabled); @@ -485,9 +520,16 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, /* Default non-users statistics */ if (backend->enable_users || backend->cbref_user != -1) { backend->redis_object = REDIS_DEFAULT_USERS_OBJECT; + msg_debug_bayes_cfg("using per-user Redis pattern: %s (enable_users=%s, cbref_user=%d)", + backend->redis_object, backend->enable_users ? "true" : "false", + backend->cbref_user); } else { backend->redis_object = REDIS_DEFAULT_OBJECT; + msg_debug_bayes_cfg("using default Redis pattern: %s (enable_users=%s, cbref_user=%d)", + backend->redis_object, + backend->enable_users ? "true" : "false", + backend->cbref_user); } } else { @@ -541,7 +583,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx, ucl_object_push_lua(L, st->classifier->cfg->opts, false); ucl_object_push_lua(L, st->stcf->opts, false); lua_pushstring(L, backend->stcf->symbol); - lua_pushboolean(L, backend->stcf->is_spam); + lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */ /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */ if (ctx->event_loop) { @@ -606,11 +648,20 @@ rspamd_redis_runtime(struct rspamd_task *task, stcf->symbol); return nullptr; } + else { + msg_debug_bayes("redis object expanded: pattern='%s' -> expanded='%s' (learn=%s, symbol=%s)", + ctx->redis_object ? ctx->redis_object : "default", + object_expanded, + learn ? "true" : "false", + stcf->symbol); + } + + const char *class_label = get_class_label(stcf); /* Look for the cached results */ if (!learn) { auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - object_expanded, stcf->is_spam); + object_expanded, class_label); if (maybe_existing) { auto *rt = maybe_existing.value(); @@ -624,24 +675,62 @@ rspamd_redis_runtime(struct rspamd_task *task, /* No cached result (or learn), create new one */ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); - if (!learn) { - /* - * For check, we also need to create the opposite class runtime to avoid - * double call for Redis scripts. - * This runtime will be filled later. - */ - auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - object_expanded, - !stcf->is_spam); - - if (!maybe_opposite_rt) { - auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); - opposite_rt->save_in_mempool(!stcf->is_spam); - opposite_rt->need_redis_call = false; + /* Find the statfile ID for the main runtime */ + int main_id = _id; /* Use the passed _id parameter */ + rt->id = main_id; + rt->stcf = stcf; + + /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */ + if (!learn && stcf->clcf && stcf->clcf->statfiles) { + GList *cur = stcf->clcf->statfiles; + + while (cur) { + auto *other_stcf = (struct rspamd_statfile_config *) cur->data; + const char *other_label = get_class_label(other_stcf); + + /* Find the statfile ID by searching through all statfiles */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + int other_id = -1; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (st->stcf == other_stcf) { + other_id = st->id; + msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d", + st->stcf->symbol, other_label, other_id); + break; + } + } + + if (other_id == -1) { + msg_debug_bayes("statfile not found for class %s, skipping", other_label); + /* Skip if statfile not found */ + cur = g_list_next(cur); + continue; + } + + if (other_stcf == stcf) { + /* This is the main statfile, use the main runtime */ + rt->save_in_mempool(other_label); + msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d", + stcf->symbol, other_label, rt->id); + } + else { + /* Create additional runtime for other statfile */ + auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + other_rt->id = other_id; + other_rt->stcf = other_stcf; + other_rt->need_redis_call = false; + other_rt->save_in_mempool(other_label); + msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d", + other_stcf->symbol, other_label, other_id); + } + + cur = g_list_next(cur); } } - - rt->save_in_mempool(stcf->is_spam); + else { + rt->save_in_mempool(class_label); + } return rt; } @@ -816,77 +905,306 @@ rspamd_redis_classified(lua_State *L) if (rt == nullptr) { msg_err_task("internal error: cannot find runtime for cookie %s", cookie); - return 0; } bool result = lua_toboolean(L, 2); if (result) { - /* Indexes: - * 3 - learned_ham (int) - * 4 - learned_spam (int) - * 5 - ham_tokens (pair<int, int>) - * 6 - spam_tokens (pair<int, int>) - */ - - /* - * We need to fill our runtime AND the opposite runtime - */ - auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) { - rt->learned = learned; - redis_stat_runtime<float>::result_type *res; - - res = new redis_stat_runtime<float>::result_type(); - - for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { - lua_rawgeti(L, -1, 1); - auto idx = lua_tointeger(L, -1); - lua_pop(L, 1); - - lua_rawgeti(L, -1, 2); - auto value = lua_tonumber(L, -1); - lua_pop(L, 1); - - res->emplace_back(idx, value); - } - - rt->set_results(res); - }; - - auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - !rt->stcf->is_spam); + /* Check we have enough arguments and the result data is a table */ + if (lua_gettop(L) < 3 || !lua_istable(L, 3)) { + msg_err_task("internal error: expected table result from Redis script, got %s", + lua_typename(L, lua_type(L, 3))); + rt->err = rspamd::util::error("invalid Redis script result format", 500); + return 0; + } - if (!opposite_rt_maybe) { - msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + /* Redis returns [learned_counts_array, token_results_array] + * Both ordered the same way as statfiles in classifier */ + size_t result_len = rspamd_lua_table_size(L, 3); + msg_debug_bayes("Redis result array length: %uz", result_len); + if (result_len != 2) { + msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len); + rt->err = rspamd::util::error("invalid Redis script result format", 500); return 0; } - if (rt->stcf->is_spam) { - filler_func(rt, L, lua_tointeger(L, 4), 6); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + /* Get learned_counts_array and token_results_array */ + lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */ + lua_rawgeti(L, 3, 2); /* token_results -> position 5 */ + + /* First, process learned_counts */ + if (lua_istable(L, 4) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using class index (1-based for Lua) */ + lua_rawgeti(L, 4, class_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for class %s (label %s): %L", + class_name, class_label, statfile_rt->learned); + } + lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */ + } + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ + GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using statfile index (1-based for Lua) */ + lua_rawgeti(L, 4, statfile_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for statfile %s (label %s): %L", + stcf->symbol, class_label, statfile_rt->learned); + } + lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */ + } + cur = g_list_next(cur); + statfile_idx++; + } + } } - else { - filler_func(rt, L, lua_tointeger(L, 3), 5); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + + /* Process token results */ + if (lua_istable(L, 5) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } + + if (!st) { + msg_debug_bayes("statfile not found for class %s, skipping", class_name); + break; + } + + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for class %s, skipping", class_label); + break; + } + } + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using class index (1-based for Lua) */ + lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector<std::pair<int, float>>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } + + statfile_rt->set_results(res); + } + lua_pop(L, 1); /* Pop token_results[class_idx + 1] */ + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ + GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } + + if (!st) { + msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol); + cur = g_list_next(cur); + statfile_idx++; + continue; + } + + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for %s, skipping", class_label); + cur = g_list_next(cur); + statfile_idx++; + continue; + } + } + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using statfile index (1-based for Lua) */ + lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector<std::pair<int, float>>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } + + statfile_rt->set_results(res); + msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)", + res->size(), stcf->symbol, class_label, st->id); + } + lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */ + + cur = g_list_next(cur); + statfile_idx++; + } + } } - /* Mark task as being processed */ - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + /* Clean up stack */ + lua_pop(L, 2); /* Pop learned_counts and token_results */ - /* Process all tokens */ + /* Process tokens for all runtimes */ g_assert(rt->tokens != nullptr); - rt->process_tokens(rt->tokens); - opposite_rt_maybe.value()->process_tokens(rt->tokens); + + /* Process tokens for all statfiles */ + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + GList *cur = rt->stcf->clcf->statfiles; + + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + statfile_rt->process_tokens(rt->tokens); + } + + cur = g_list_next(cur); + } + } + else { + /* Fallback: just process the main runtime */ + rt->process_tokens(rt->tokens); + } } else { /* Error message is on index 3 */ - const auto *err_msg = lua_tostring(L, 3); - rt->err = rspamd::util::error(err_msg, 500); - msg_err_task("cannot classify task: %s", - err_msg); + const char *err_msg = nullptr; + if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) { + err_msg = lua_tostring(L, 3); + } + if (err_msg) { + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot classify task: %s", err_msg); + } + else { + rt->err = rspamd::util::error("unknown Redis script error", 500); + msg_err_task("cannot classify task: unknown Redis script error"); + } } return 0; @@ -929,7 +1247,57 @@ rspamd_redis_process_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + + /* Send all class labels for multi-class support */ + if (rt->stcf->clcf && rt->stcf->clcf->class_names && + rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: send array of class labels in deterministic order */ + lua_createtable(L, rt->stcf->clcf->class_names->len, 0); + for (unsigned int i = 0; i < rt->stcf->clcf->class_names->len; i++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, i); + const char *class_label = nullptr; + + /* Find the class label for this class name from any statfile with this class */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + class_label = get_class_label(stcf); + break; + } + cur = g_list_next(cur); + } + + if (class_label) { + lua_pushstring(L, class_label); + lua_rawseti(L, -2, i + 1); /* Lua arrays are 1-indexed */ + } + } + } + else { + /* Binary classification: send labels in statfiles order to match parsing order */ + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + lua_createtable(L, 0, 0); + GList *cur = rt->stcf->clcf->statfiles; + int lbl_idx = 1; + + while (cur) { + auto *sf = (struct rspamd_statfile_config *) cur->data; + lua_pushstring(L, get_class_label(sf)); + lua_rawseti(L, -2, lbl_idx++); + cur = g_list_next(cur); + } + } + else { + /* Fallback to the legacy order if statfiles are not available */ + lua_createtable(L, 2, 0); + lua_pushstring(L, "H"); /* ham */ + lua_rawseti(L, -2, 1); + lua_pushstring(L, "S"); /* spam */ + lua_rawseti(L, -2, 2); + } + } + lua_new_text(L, tokens_buf, tokens_len, false); /* Store rt in random cookie */ @@ -979,13 +1347,31 @@ rspamd_redis_learned(lua_State *L) bool result = lua_toboolean(L, 2); if (result) { - /* TODO: write it */ + /* Learning successful - no complex data to process like in classification */ + msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s", + rt->stcf->symbol, get_class_label(rt->stcf)); + + /* Clear any previous error state */ + rt->err = std::nullopt; + + /* Learning operations don't return data structures to process, + * they just update Redis state. Success means the Redis script + * completed without errors. */ } else { /* Error message is on index 3 */ - const auto *err_msg = lua_tostring(L, 3); - rt->err = rspamd::util::error(err_msg, 500); - msg_err_task("cannot learn task: %s", err_msg); + const char *err_msg = nullptr; + if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) { + err_msg = lua_tostring(L, 3); + } + if (err_msg) { + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot learn task: %s", err_msg); + } + else { + rt->err = rspamd::util::error("unknown Redis script error", 500); + msg_err_task("cannot learn task: unknown Redis script error"); + } } return 0; @@ -1028,7 +1414,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */ lua_pushstring(L, rt->stcf->symbol); /* Detect unlearn */ @@ -1056,6 +1442,8 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, lua_new_text(L, text_tokens_buf, text_tokens_len, false); } + msg_debug_bayes("called lua learn script for %s (cookie=%s)", rt->stcf->symbol, cookie); + if (lua_pcall(L, nargs, 0, err_idx) != 0) { msg_err_task("call to script failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c index 973dc30a7..8f29a3b4e 100644 --- a/src/libstat/backends/sqlite3_backend.c +++ b/src/libstat/backends/sqlite3_backend.c @@ -589,12 +589,7 @@ rspamd_sqlite3_process_tokens(struct rspamd_task *task, } } - if (rt->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ } diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 93b5149da..dbae98cc2 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -53,10 +53,26 @@ static double inv_chi_square(struct rspamd_task *task, double value, int freedom_deg) { double prob, sum, m; + double log_prob, log_m; int i; errno = 0; m = -value; + + /* Handle extreme negative values that would cause exp() underflow */ + if (value < -700) { + /* Very strong confidence, return 0 */ + msg_debug_bayes("extreme negative value: %f, returning 0", value); + return 0.0; + } + + /* Handle extreme positive values that would cause overflow */ + if (value > 700) { + /* No confidence, return 1 */ + msg_debug_bayes("extreme positive value: %f, returning 1", value); + return 1.0; + } + prob = exp(value); if (errno == ERANGE) { @@ -75,6 +91,8 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg) } sum = prob; + log_prob = value; /* log of current prob term */ + log_m = log(fabs(m)); /* log of |m| for numerical stability */ msg_debug_bayes("m: %f, probability: %g", m, prob); @@ -83,24 +101,60 @@ inv_chi_square(struct rspamd_task *task, double value, int freedom_deg) * prob is e ^ x (small value since x is normally less than zero * So we integrate over degrees of freedom and produce the total result * from 1.0 (no confidence) to 0.0 (full confidence) + * Use logarithmic arithmetic to prevent overflow */ for (i = 1; i < freedom_deg; i++) { - prob *= m / (double) i; + /* Calculate next term using logarithms to prevent overflow */ + log_prob += log_m - log((double) i); + + /* Check if the log probability is too negative (term becomes negligible) */ + if (log_prob < -700) { + msg_debug_bayes("term %d became negligible, stopping series", i); + break; + } + + /* Check if the log probability is too positive (would cause overflow) */ + if (log_prob > 700) { + msg_debug_bayes("series diverging at term %d, returning 1.0", i); + return 1.0; + } + + prob = exp(log_prob); sum += prob; - msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum); + msg_debug_bayes("i=%d, log_prob: %g, probability: %g, sum: %g", i, log_prob, prob, sum); + + /* Early termination if sum is getting too large */ + if (sum > 1e10) { + msg_debug_bayes("sum too large (%g), returning 1.0", sum); + return 1.0; + } } return MIN(1.0, sum); } struct bayes_task_closure { - double ham_prob; - double spam_prob; + double ham_prob; /* Kept for binary compatibility */ + double spam_prob; /* Kept for binary compatibility */ + double meta_skip_prob; + uint64_t processed_tokens; + uint64_t total_hits; + uint64_t text_tokens; + struct rspamd_task *task; +}; + +/* Multi-class classification closure */ +struct bayes_multiclass_closure { + double *class_log_probs; /* Array of log probabilities for each class */ + uint64_t *class_learns; /* Learning counts for each class */ + char **class_names; /* Array of class names */ + unsigned int num_classes; /* Number of classes */ double meta_skip_prob; uint64_t processed_tokens; uint64_t total_hits; uint64_t text_tokens; struct rspamd_task *task; + struct rspamd_classifier_config *cfg; }; /* @@ -122,7 +176,6 @@ bayes_classify_token(struct rspamd_classifier *ctx, unsigned int spam_count = 0, ham_count = 0, total_count = 0; struct rspamd_statfile *st; struct rspamd_task *task; - const char *token_type = "txt"; double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob, ham_prob, fw, w, val; @@ -211,41 +264,379 @@ bayes_classify_token(struct rspamd_classifier *ctx, if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { cl->text_tokens++; } + } +} + +/* + * Multinomial token classification for multi-class Bayes + */ +static void +bayes_classify_token_multiclass(struct rspamd_classifier *ctx, + rspamd_token_t *tok, + struct bayes_multiclass_closure *cl) +{ + unsigned int i, j; + int id; + struct rspamd_statfile *st; + struct rspamd_task *task; + double val, fw, w; + guint64 *class_counts; + guint64 total_count = 0; + + task = cl->task; + + /* Skip meta tokens probabilistically if configured */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) { + val = rspamd_random_double_fast(); + if (val <= cl->meta_skip_prob) { + return; + } + } + + /* Allocate array for class counts */ + class_counts = g_alloca(cl->num_classes * sizeof(guint64)); + memset(class_counts, 0, cl->num_classes * sizeof(guint64)); + + /* Collect counts for each class */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + val = tok->values[id]; + + if (val > 0) { + /* Direct O(1) class index lookup instead of O(N) string comparison */ + if (st->stcf->class_name && st->stcf->class_index < cl->num_classes) { + unsigned int class_idx = st->stcf->class_index; + class_counts[class_idx] += val; + total_count += val; + cl->total_hits += val; + } + else { + msg_debug_bayes("invalid class_index %ud >= %ud for statfile %s", + st->stcf->class_index, cl->num_classes, st->stcf->symbol); + } + } + } + + /* Calculate multinomial probability for this token */ + if (total_count >= ctx->cfg->min_token_hits) { + /* Feature weight calculation */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) { + fw = 1.0; + } else { - token_type = "meta"; + fw = feature_weight[tok->window_idx % G_N_ELEMENTS(feature_weight)]; } - if (tok->t1 && tok->t2) { - msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, " - "total_count: %ud, " - "spam_count: %ud, ham_count: %ud," - "spam_prob: %.3f, ham_prob: %.3f, " - "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " - "current spam probability: %.3f, current ham probability: %.3f", - token_type, - tok->data, - (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, - (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, - fw, w, total_count, spam_count, ham_count, - spam_prob, ham_prob, - bayes_spam_prob, bayes_ham_prob, - cl->spam_prob, cl->ham_prob); + w = (fw * total_count) / (1.0 + fw * total_count); + + /* Apply multinomial model for each class */ + for (j = 0; j < cl->num_classes; j++) { + /* Skip classes with insufficient learns */ + if (ctx->cfg->min_learns > 0 && cl->class_learns[j] < ctx->cfg->min_learns) { + continue; + } + + double class_freq = (double) class_counts[j] / MAX(1.0, (double) cl->class_learns[j]); + double class_prob = PROB_COMBINE(class_freq, total_count, w, 1.0 / cl->num_classes); + + /* Ensure probability is properly bounded [0, 1] */ + class_prob = MAX(0.0, MIN(1.0, class_prob)); + + /* Skip probabilities too close to uniform (1/num_classes) */ + double uniform_prior = 1.0 / cl->num_classes; + if (fabs(class_prob - uniform_prior) < ctx->cfg->min_prob_strength) { + continue; + } + + cl->class_log_probs[j] += log(class_prob); + } + + cl->processed_tokens++; + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + cl->text_tokens++; + } + + /* Per-token debug logging removed to reduce verbosity */ + } +} + +/* + * Multinomial Bayes classification with Fisher confidence + */ +static gboolean +bayes_classify_multiclass(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task) +{ + struct bayes_multiclass_closure cl; + rspamd_token_t *tok; + unsigned int i, j, text_tokens = 0; + int id; + struct rspamd_statfile *st; + rspamd_multiclass_result_t *result; + double *normalized_probs; + double max_log_prob = -INFINITY; + unsigned int winning_class_idx = 0; + double confidence; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + + /* Initialize multi-class closure */ + memset(&cl, 0, sizeof(cl)); + cl.task = task; + cl.cfg = ctx->cfg; + + /* Get class information from classifier config */ + if (!ctx->cfg->class_names) { + msg_debug_bayes("no class_names array in classifier config"); + return TRUE; /* Fall back to binary mode */ + } + if (ctx->cfg->class_names->len < 2) { + msg_debug_bayes("insufficient classes: %ud < 2", (unsigned int) ctx->cfg->class_names->len); + return TRUE; /* Fall back to binary mode */ + } + if (!ctx->cfg->class_names->pdata) { + msg_debug_bayes("class_names->pdata is NULL"); + return TRUE; /* Fall back to binary mode */ + } + + cl.num_classes = ctx->cfg->class_names->len; + cl.class_names = (char **) ctx->cfg->class_names->pdata; + + /* Debug: verify class names are accessible */ + msg_debug_bayes("multiclass setup: ctx->cfg->class_names=%p, len=%ud, pdata=%p", + ctx->cfg->class_names, (unsigned int) ctx->cfg->class_names->len, ctx->cfg->class_names->pdata); + msg_debug_bayes("multiclass setup: cl.num_classes=%ud, cl.class_names=%p", + cl.num_classes, cl.class_names); + cl.class_log_probs = g_alloca(cl.num_classes * sizeof(double)); + cl.class_learns = g_alloca(cl.num_classes * sizeof(uint64_t)); + + /* Initialize probabilities and get learning counts */ + for (i = 0; i < cl.num_classes; i++) { + cl.class_log_probs[i] = 0.0; + cl.class_learns[i] = 0; + } + + /* Collect learning counts for each class */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + + for (j = 0; j < cl.num_classes; j++) { + if (st->stcf->class_name && + strcmp(st->stcf->class_name, cl.class_names[j]) == 0) { + cl.class_learns[j] += st->backend->total_learns(task, + g_ptr_array_index(task->stat_runtimes, id), ctx->ctx); + break; + } + } + } + + /* Check minimum learns requirement - count viable classes */ + unsigned int viable_classes = 0; + if (ctx->cfg->min_learns > 0) { + for (i = 0; i < cl.num_classes; i++) { + if (cl.class_learns[i] >= ctx->cfg->min_learns) { + viable_classes++; + } + else { + msg_info_task("class %s excluded from classification: %uL learns < %ud minimum", + cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns); + } + } + + if (viable_classes == 0) { + msg_info_task("no classes have sufficient training samples for classification"); + return TRUE; + } + + msg_info_bayes("multiclass classification: %ud/%ud classes have sufficient learns", + viable_classes, cl.num_classes); + } + + /* Count text tokens */ + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + text_tokens++; + } + } + + if (text_tokens == 0) { + msg_info_task("skipped classification as there are no text tokens. " + "Total tokens: %ud", + tokens->len); + return TRUE; + } + + /* Set meta token skip probability */ + if (text_tokens > tokens->len - text_tokens) { + cl.meta_skip_prob = 0.0; + } + else { + cl.meta_skip_prob = 1.0 - (double) text_tokens / tokens->len; + } + + /* Process all tokens */ + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + bayes_classify_token_multiclass(ctx, tok, &cl); + } + + if (cl.processed_tokens == 0) { + /* Debug: check why no tokens were processed */ + msg_debug_bayes("examining token values for debugging:"); + for (i = 0; i < MIN(tokens->len, 10); i++) { /* Check first 10 tokens */ + tok = g_ptr_array_index(tokens, i); + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index(ctx->statfiles_ids, int, j); + if (tok->values[id] > 0) { + struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + msg_debug_bayes("token %ud: values[%d] = %.2f (class=%s, symbol=%s)", + i, id, tok->values[id], + st->stcf->class_name ? st->stcf->class_name : "unknown", + st->stcf->symbol); + } + } + } + + msg_info_bayes("no tokens found in bayes database " + "(%ud total tokens, %ud text tokens), ignore stats", + tokens->len, text_tokens); + return TRUE; + } + + if (ctx->cfg->min_tokens > 0 && + cl.text_tokens < (int) (ctx->cfg->min_tokens * 0.1)) { + msg_info_bayes("ignore bayes probability since we have " + "found too few text tokens: %uL (of %ud checked), " + "at least %d required", + cl.text_tokens, text_tokens, + (int) (ctx->cfg->min_tokens * 0.1)); + return TRUE; + } + + /* Normalize probabilities using softmax */ + normalized_probs = g_alloca(cl.num_classes * sizeof(double)); + + /* Find maximum for numerical stability - only consider classes with sufficient training */ + for (i = 0; i < cl.num_classes; i++) { + msg_debug_bayes("class %s, log_prob: %.2f", cl.class_names[i], cl.class_log_probs[i]); + /* Only consider classes that have sufficient training data */ + if (ctx->cfg->min_learns > 0 && cl.class_learns[i] < ctx->cfg->min_learns) { + msg_debug_bayes("skipping class %s in winner selection: %uL learns < %ud minimum", + cl.class_names[i], cl.class_learns[i], ctx->cfg->min_learns); + continue; + } + if (cl.class_log_probs[i] > max_log_prob) { + max_log_prob = cl.class_log_probs[i]; + winning_class_idx = i; + } + } + + /* Apply softmax normalization */ + double sum_exp = 0.0; + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] = exp(cl.class_log_probs[i] - max_log_prob); + sum_exp += normalized_probs[i]; + } + + if (sum_exp > 0) { + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] /= sum_exp; + } + } + else { + /* Fallback to uniform distribution */ + for (i = 0; i < cl.num_classes; i++) { + normalized_probs[i] = 1.0 / cl.num_classes; + } + } + + /* Calculate confidence using Fisher method for the winning class */ + if (max_log_prob > -300) { + if (max_log_prob > 0) { + /* Positive log prob means very strong evidence - high confidence */ + confidence = 0.95; /* High confidence for positive log probabilities */ + msg_debug_bayes("positive log_prob (%g), setting high confidence", max_log_prob); } else { - msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, " - "total_count: %ud, " - "spam_count: %ud, ham_count: %ud," - "spam_prob: %.3f, ham_prob: %.3f, " - "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " - "current spam probability: %.3f, current ham probability: %.3f", - token_type, - tok->data, - fw, w, total_count, spam_count, ham_count, - spam_prob, ham_prob, - bayes_spam_prob, bayes_ham_prob, - cl->spam_prob, cl->ham_prob); + /* Negative log prob - use Fisher method as intended */ + double fisher_result = inv_chi_square(task, max_log_prob, cl.processed_tokens); + confidence = 1.0 - fisher_result; + + msg_debug_bayes("fisher_result: %g, max_log_prob: %g, condition check: fisher_result > 0.999 = %s, max_log_prob > -50 = %s", + fisher_result, max_log_prob, + fisher_result > 0.999 ? "true" : "false", + max_log_prob > -50 ? "true" : "false"); + + /* Handle case where Fisher method indicates extreme confidence */ + if (fisher_result > 0.999 && max_log_prob > -100) { + /* Large magnitude negative log prob means strong evidence */ + confidence = 0.90; + msg_debug_bayes("extreme negative log_prob (%g), setting high confidence", max_log_prob); + } } } + else { + confidence = normalized_probs[winning_class_idx]; + } + + /* Create and store multiclass result */ + result = g_new0(rspamd_multiclass_result_t, 1); + result->class_names = g_new(char *, cl.num_classes); + result->probabilities = g_new(double, cl.num_classes); + result->num_classes = cl.num_classes; + result->winning_class = cl.class_names[winning_class_idx]; /* Reference, not copy */ + result->confidence = confidence; + + for (i = 0; i < cl.num_classes; i++) { + result->class_names[i] = g_strdup(cl.class_names[i]); + result->probabilities[i] = normalized_probs[i]; + } + + rspamd_task_set_multiclass_result(task, result); + + msg_info_bayes("MULTICLASS_RESULT: winning_class='%s', confidence=%.3f, normalized_prob=%.3f, tokens=%uL", + cl.class_names[winning_class_idx], confidence, + normalized_probs[winning_class_idx], cl.processed_tokens); + + /* Insert symbol for winning class if confidence is significant */ + if (confidence > 0.05) { + char sumbuf[32]; + double final_prob = rspamd_normalize_probability(confidence, 0.5); + + rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", confidence * 100.0); + + /* Find the statfile for the winning class to get the symbol */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, int, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + + if (st->stcf->class_name && + strcmp(st->stcf->class_name, cl.class_names[winning_class_idx]) == 0) { + msg_info_bayes("SYMBOL_INSERT: symbol='%s', final_prob=%.3f, confidence_display='%s'", + st->stcf->symbol, final_prob, sumbuf); + rspamd_task_insert_result(task, st->stcf->symbol, final_prob, sumbuf); + break; + } + } + + msg_debug_bayes("multiclass classification: winning class '%s' with " + "probability %.3f, confidence %.3f, %uL tokens processed", + cl.class_names[winning_class_idx], + normalized_probs[winning_class_idx], + confidence, cl.processed_tokens); + } + else { + msg_info_bayes("SYMBOL_SKIPPED: confidence=%.3f <= 0.05, no symbol inserted", confidence); + } + + return TRUE; } @@ -279,6 +670,37 @@ bayes_classify(struct rspamd_classifier *ctx, g_assert(ctx != NULL); g_assert(tokens != NULL); + /* Check if this is a multi-class classifier */ + msg_debug_bayes("classification check: class_names=%p, len=%uz", + ctx->cfg->class_names, + ctx->cfg->class_names ? ctx->cfg->class_names->len : 0); + + if (ctx->cfg->class_names && ctx->cfg->class_names->len >= 2) { + /* Verify that at least one statfile has class_name set (indicating new multi-class config) */ + gboolean has_class_names = FALSE; + for (i = 0; i < ctx->statfiles_ids->len; i++) { + int id = g_array_index(ctx->statfiles_ids, int, i); + struct rspamd_statfile *st = g_ptr_array_index(ctx->ctx->statfiles, id); + msg_debug_bayes("checking statfile %s: class_name=%s, is_spam_converted=%s", + st->stcf->symbol, + st->stcf->class_name ? st->stcf->class_name : "NULL", + st->stcf->is_spam_converted ? "true" : "false"); + if (st->stcf->class_name) { + has_class_names = TRUE; + } + } + + msg_debug_bayes("has_class_names=%s", has_class_names ? "true" : "false"); + + if (has_class_names) { + msg_debug_bayes("using multiclass classification with %ud classes", + (unsigned int) ctx->cfg->class_names->len); + return bayes_classify_multiclass(ctx, tokens, task); + } + } + + /* Fall back to binary classification */ + msg_debug_bayes("using binary classification"); memset(&cl, 0, sizeof(cl)); cl.task = task; @@ -286,14 +708,14 @@ bayes_classify(struct rspamd_classifier *ctx, if (ctx->cfg->min_learns > 0) { if (ctx->ham_learns < ctx->cfg->min_learns) { msg_info_task("not classified as ham. The ham class needs more " - "training samples. Currently: %ul; minimum %ud required", + "training samples. Currently: %uL; minimum %ud required", ctx->ham_learns, ctx->cfg->min_learns); return TRUE; } if (ctx->spam_learns < ctx->cfg->min_learns) { msg_info_task("not classified as spam. The spam class needs more " - "training samples. Currently: %ul; minimum %ud required", + "training samples. Currently: %uL; minimum %ud required", ctx->spam_learns, ctx->cfg->min_learns); return TRUE; @@ -374,7 +796,7 @@ bayes_classify(struct rspamd_classifier *ctx, final_prob = (s + 1.0 - h) / 2.; msg_debug_bayes( "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f," - " %L tokens processed of %ud total tokens;" + " %uL tokens processed of %ud total tokens;" " %uL text tokens found of %ud text tokens)", cl.ham_prob, h, @@ -549,3 +971,155 @@ bayes_learn_spam(struct rspamd_classifier *ctx, return TRUE; } + +gboolean +bayes_learn_class(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err) +{ + unsigned int i, j, total_cnt; + int id; + struct rspamd_statfile *st; + rspamd_token_t *tok; + gboolean incrementing; + unsigned int *class_counts = NULL; + struct rspamd_statfile **class_statfiles = NULL; + unsigned int num_classes = 0; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + g_assert(class_name != NULL); + + msg_info_bayes("LEARN_CLASS: class='%s', unlearn=%s, tokens=%ud", + class_name, unlearn ? "true" : "false", tokens->len); + + incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + + /* Count classes and prepare arrays for multi-class learning */ + if (ctx->cfg->class_names && ctx->cfg->class_names->len > 0) { + num_classes = ctx->cfg->class_names->len; + class_counts = g_alloca(num_classes * sizeof(unsigned int)); + class_statfiles = g_alloca(num_classes * sizeof(struct rspamd_statfile *)); + memset(class_counts, 0, num_classes * sizeof(unsigned int)); + memset(class_statfiles, 0, num_classes * sizeof(struct rspamd_statfile *)); + } + + for (i = 0; i < tokens->len; i++) { + total_cnt = 0; + tok = g_ptr_array_index(tokens, i); + + /* Reset class counts for this token */ + if (num_classes > 0) { + memset(class_counts, 0, num_classes * sizeof(unsigned int)); + } + + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index(ctx->statfiles_ids, int, j); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + + /* Determine if this statfile matches our target class */ + gboolean is_target_class = FALSE; + if (st->stcf->class_name) { + /* Multi-class: exact class name match */ + is_target_class = (strcmp(st->stcf->class_name, class_name) == 0); + } + else { + /* Legacy binary: map class_name to spam/ham */ + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + is_target_class = st->stcf->is_spam; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + is_target_class = !st->stcf->is_spam; + } + } + + if (is_target_class) { + /* Learning: increment the target class */ + if (incrementing) { + tok->values[id] = 1; + } + else { + tok->values[id]++; + } + total_cnt += tok->values[id]; + + /* Track class counts for debugging */ + if (num_classes > 0) { + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) { + class_counts[k] += tok->values[id]; + class_statfiles[k] = st; + break; + } + } + } + } + else { + /* Unlearning: decrement other classes if unlearn flag is set */ + if (tok->values[id] > 0 && unlearn) { + if (incrementing) { + tok->values[id] = -1; + } + else { + tok->values[id]--; + } + total_cnt += tok->values[id]; + + /* Track class counts for debugging */ + if (num_classes > 0) { + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + if (st->stcf->class_name && strcmp(st->stcf->class_name, check_class) == 0) { + class_counts[k] += tok->values[id]; + class_statfiles[k] = st; + break; + } + } + } + } + else if (incrementing) { + tok->values[id] = 0; + } + } + } + + /* Debug logging */ + if (tok->t1 && tok->t2) { + if (num_classes > 0) { + GString *debug_str = g_string_new(""); + for (unsigned int k = 0; k < num_classes; k++) { + const char *check_class = (const char *) g_ptr_array_index(ctx->cfg->class_names, k); + g_string_append_printf(debug_str, "%s:%d ", check_class, class_counts[k]); + } + msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, " + "class_counts: %s", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + tok->window_idx, total_cnt, debug_str->str); + g_string_free(debug_str, TRUE); + } + else { + msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, " + "class: %s", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + tok->window_idx, total_cnt, class_name); + } + } + else { + msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, " + "class: %s", + tok->data, + tok->window_idx, total_cnt, class_name); + } + } + + return TRUE; +} diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index 22978e673..cab658146 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -54,6 +54,13 @@ struct rspamd_stat_classifier { gboolean unlearn, GError **err); + gboolean (*learn_class_func)(struct rspamd_classifier *ctx, + GPtrArray *input, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err); + void (*fin_func)(struct rspamd_classifier *cl); }; @@ -73,6 +80,13 @@ gboolean bayes_learn_spam(struct rspamd_classifier *ctx, gboolean unlearn, GError **err); +gboolean bayes_learn_class(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + const char *class_name, + gboolean unlearn, + GError **err); + void bayes_fin(struct rspamd_classifier *); /* Generic lua classifier */ diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx index 0de5cd094..afefeadcd 100644 --- a/src/libstat/learn_cache/redis_cache.cxx +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task, return (void *) ctx; } +/* Get class ID using rspamd_cryptobox_fast_hash */ +static uint64_t +rspamd_stat_cache_get_class_id(const char *class_name) +{ + if (!class_name) { + return 0; + } + + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + return 1; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + return 0; + } + else { + /* For other classes, use rspamd_cryptobox_fast_hash */ + uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0); + + /* Ensure we don't get 0 or 1 (reserved for ham/spam) */ + if (hash == 0 || hash == 1) { + hash += 2; + } + + return hash; + } +} + static int rspamd_stat_cache_checked(lua_State *L) { @@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L) if (res) { auto val = lua_tointeger(L, 3); - if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || - (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { - /* Already learned */ - msg_info_task("<%s> has been already " - "learned as %s, ignore it", - MESSAGE_FIELD(task, message_id), - (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); - task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + /* Get the class being learned */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flags for backward compatibility */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + autolearn_class = "spam"; + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + autolearn_class = "ham"; + } } - else { - /* Unlearn flag */ - task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + + if (autolearn_class) { + uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class); + + if ((uint64_t) val == expected_id) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else { + /* Different class learned, unlearn flag */ + msg_debug_task("<%s> cached value %ld != expected %lu for class %s, will unlearn", + MESSAGE_FIELD(task, message_id), + val, expected_id, autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } } } - /* Ignore errors for now, as we can do nothing about them at the moment */ - return 0; } @@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task, lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); rspamd_lua_task_push(L, task); lua_pushstring(L, h); - lua_pushboolean(L, is_spam); - if (lua_pcall(L, 3, 0, err_idx) != 0) { + /* Get the class being learned - prefer multiclass over binary */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flag for backward compatibility */ + autolearn_class = is_spam ? "spam" : "ham"; + } + + /* Push class name and class ID */ + lua_pushstring(L, autolearn_class); + uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class); + lua_pushinteger(L, class_id); + + if (lua_pcall(L, 4, 0, err_idx) != 0) { msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); return RSPAMD_LEARN_IGNORE; diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index f28922588..aa6111a8b 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -20,6 +20,7 @@ #include "task.h" #include "lua/lua_common.h" #include "contrib/libev/ev.h" +#include "libserver/word.h" #ifdef __cplusplus extern "C" { @@ -30,36 +31,14 @@ extern "C" { * High level statistics API */ -#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0) -#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1) -#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2) -#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3) -#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4) -#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5) -#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6) -#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7) -#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8) -#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9) -#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 10) -#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 11) -#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 12) -#define RSPAMD_STAT_TOKEN_FLAG_EMOJI (1u << 13) - -typedef struct rspamd_stat_token_s { - rspamd_ftok_t original; /* utf8 raw */ - rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ - rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ - rspamd_ftok_t stemmed; /* stemmed utf8 */ - unsigned int flags; -} rspamd_stat_token_t; #define RSPAMD_TOKEN_VALUE_TYPE float typedef struct token_node_s { uint64_t data; unsigned int window_idx; unsigned int flags; - rspamd_stat_token_t *t1; - rspamd_stat_token_t *t2; + rspamd_word_t *t1; + rspamd_word_t *t2; RSPAMD_TOKEN_VALUE_TYPE values[0]; } rspamd_token_t; @@ -129,6 +108,23 @@ rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task, GError **err); /** + * Learn task as a specific class, task must be processed prior to this call + * @param task task to learn + * @param class_name name of the class to learn (e.g., "spam", "ham", "transactional") + * @param L lua state + * @param classifier NULL to learn all classifiers, name to learn a specific one + * @param stage learning stage + * @param err error returned + * @return TRUE if task has been learned + */ +rspamd_stat_result_t rspamd_stat_learn_class(struct rspamd_task *task, + const char *class_name, + lua_State *L, + const char *classifier, + unsigned int stage, + GError **err); + +/** * Get the overall statistics for all statfile backends * @param cfg configuration * @param total_learns the total number of learns is stored here @@ -141,6 +137,43 @@ rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task, void rspamd_stat_unload(void); +/** + * Multi-class classification result structure + */ +typedef struct { + char **class_names; /**< Array of class names */ + double *probabilities; /**< Array of probabilities for each class */ + unsigned int num_classes; /**< Number of classes */ + const char *winning_class; /**< Name of the winning class (reference, not owned) */ + double confidence; /**< Confidence of the winning class */ +} rspamd_multiclass_result_t; + +/** + * Set multi-class classification result for a task + */ +void rspamd_task_set_multiclass_result(struct rspamd_task *task, + rspamd_multiclass_result_t *result); + +/** + * Get multi-class classification result from a task + */ +rspamd_multiclass_result_t *rspamd_task_get_multiclass_result(struct rspamd_task *task); + +/** + * Free multi-class result structure + */ +void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result); + +/** + * Set autolearn class for a task + */ +void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name); + +/** + * Get autolearn class from a task + */ +const char *rspamd_task_get_autolearn_class(struct rspamd_task *task); + #ifdef __cplusplus } #endif diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c index 8a5313df2..5ada7d468 100644 --- a/src/libstat/stat_config.c +++ b/src/libstat/stat_config.c @@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = { .init_func = lua_classifier_init, .classify_func = lua_classifier_classify, .learn_spam_func = lua_classifier_learn_spam, + .learn_class_func = NULL, /* TODO: implement lua multi-class learning */ .fin_func = NULL, }; @@ -37,6 +38,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = { .init_func = bayes_init, .classify_func = bayes_classify, .learn_spam_func = bayes_learn_spam, + .learn_class_func = bayes_learn_class, .fin_func = bayes_fin, }}; @@ -68,8 +70,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = { .dec_learns = rspamd_##eltn##_dec_learns, \ .get_stat = rspamd_##eltn##_get_stat, \ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ - .close = rspamd_##eltn##_close \ - } + .close = rspamd_##eltn##_close} #define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \ { \ .name = #nam, \ @@ -85,8 +86,7 @@ static struct rspamd_stat_tokenizer stat_tokenizers[] = { .dec_learns = NULL, \ .get_stat = rspamd_##eltn##_get_stat, \ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ - .close = rspamd_##eltn##_close \ - } + .close = rspamd_##eltn##_close} static struct rspamd_stat_backend stat_backends[] = { RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file), @@ -101,8 +101,7 @@ static struct rspamd_stat_backend stat_backends[] = { .runtime = rspamd_stat_cache_##eltn##_runtime, \ .check = rspamd_stat_cache_##eltn##_check, \ .learn = rspamd_stat_cache_##eltn##_learn, \ - .close = rspamd_stat_cache_##eltn##_close \ - } + .close = rspamd_stat_cache_##eltn##_close} static struct rspamd_stat_cache stat_caches[] = { RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3), diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 17caf4cc6..11b31decc 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,16 +32,89 @@ static const double similarity_threshold = 80.0; +void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result) +{ + g_assert(task != NULL); + g_assert(result != NULL); + + rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result", result, + (rspamd_mempool_destruct_t) rspamd_multiclass_result_free); +} + +rspamd_multiclass_result_t * +rspamd_task_get_multiclass_result(struct rspamd_task *task) +{ + g_assert(task != NULL); + + return (rspamd_multiclass_result_t *) rspamd_mempool_get_variable(task->task_pool, + "multiclass_bayes_result"); +} + +void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result) +{ + if (result == NULL) { + return; + } + + g_free(result->class_names); + g_free(result->probabilities); + /* winning_class is a reference, not owned - don't free */ + g_free(result); +} + +void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name) +{ + g_assert(task != NULL); + g_assert(class_name != NULL); + + /* Store the class name in the mempool */ + const char *class_name_copy = rspamd_mempool_strdup(task->task_pool, class_name); + rspamd_mempool_set_variable(task->task_pool, "autolearn_class", + (gpointer) class_name_copy, NULL); + + /* Set the appropriate flags */ + task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS; + + /* For backward compatibility, also set binary flags */ + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + } +} + +const char * +rspamd_task_get_autolearn_class(struct rspamd_task *task) +{ + g_assert(task != NULL); + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class"); + } + + /* Fallback to binary flags for backward compatibility */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + return "spam"; + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + return "ham"; + } + + return NULL; +} + static void rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task) { - GArray *ar; - rspamd_stat_token_t elt; + rspamd_words_t *words; + rspamd_word_t elt; unsigned int i; lua_State *L = task->cfg->lua_state; - ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16); + words = rspamd_mempool_alloc(task->task_pool, sizeof(*words)); + kv_init(*words); memset(&elt, 0, sizeof(elt)); elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; @@ -87,7 +160,7 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, elt.normalized.begin = elt.original.begin; elt.normalized.len = elt.original.len; - g_array_append_val(ar, elt); + kv_push_safe(rspamd_word_t, *words, elt, meta_words_error); } lua_pop(L, 1); @@ -99,17 +172,20 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, } - if (ar->len > 0) { + if (kv_size(*words) > 0) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - ar, + words, TRUE, "M", task->tokens); } - rspamd_mempool_add_destructor(task->task_pool, - rspamd_array_free_hard, ar); + return; +meta_words_error: + + msg_err("cannot process meta words for task" + "memory allocation error, skipping the remaining"); } /* @@ -134,8 +210,8 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { - reserved_len += part->utf_words->len; + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { + reserved_len += kv_size(part->utf_words); } /* XXX: normal window size */ reserved_len += 5; @@ -149,9 +225,9 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - part->utf_words, IS_TEXT_PART_UTF(part), + &part->utf_words, IS_TEXT_PART_UTF(part), NULL, task->tokens); } @@ -163,10 +239,10 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - task->meta_words, + &task->meta_words, TRUE, "SUBJECT", task->tokens); @@ -390,18 +466,9 @@ rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx, } /* - * Do not classify a message if some class is missing + * Multi-class approach: don't check for missing classes + * Missing tokens naturally result in 0 probability */ - if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) { - msg_info_task("skip statistics as SPAM class is missing"); - - return; - } - if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) { - msg_info_task("skip statistics as HAM class is missing"); - - return; - } for (i = 0; i < st_ctx->classifiers->len; i++) { cl = g_ptr_array_index(st_ctx->classifiers, i); @@ -561,7 +628,24 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx, if (sel->cache && sel->cachecf) { rt = cl->cache->runtime(task, sel->cachecf, FALSE); - learn_res = cl->cache->check(task, spam, rt); + + /* For multi-class learning, determine spam boolean from class name if available */ + gboolean cache_spam = spam; /* Default to original spam parameter */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) { + cache_spam = TRUE; + } + else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) { + cache_spam = FALSE; + } + else { + /* For other classes, use a heuristic or default to spam for cache purposes */ + cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */ + } + } + + learn_res = cl->cache->check(task, cache_spam, rt); } if (learn_res == RSPAMD_LEARN_IGNORE) { @@ -654,9 +738,63 @@ rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx, continue; } - if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, - task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { - learned = TRUE; + /* Check if classifier supports multi-class learning and if we should use it */ + if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) { + /* Multi-class learning: determine class name from task flags or autolearn result */ + const char *class_name = NULL; + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + /* Find spam class name */ + for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) { + const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k); + /* Look for statfile with this class that is spam */ + GList *cur = cl->cfg->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) { + class_name = check_class; + break; + } + cur = g_list_next(cur); + } + if (class_name) break; + } + if (!class_name) class_name = "spam"; /* fallback */ + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + /* Find ham class name */ + for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) { + const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k); + /* Look for statfile with this class that is ham */ + GList *cur = cl->cfg->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) { + class_name = check_class; + break; + } + cur = g_list_next(cur); + } + if (class_name) break; + } + if (!class_name) class_name = "ham"; /* fallback */ + } + else { + /* Fallback to spam/ham based on the spam parameter */ + class_name = spam ? "spam" : "ham"; + } + + if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + else { + /* Binary learning: use existing function */ + if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } } } @@ -755,9 +893,26 @@ rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx, backend_found = TRUE; if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) { - if (!!spam != !!st->stcf->is_spam) { - /* If we are not unlearning, then do not touch another class */ - continue; + /* For multiclass learning, check if this statfile has any tokens to learn */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + /* Multiclass learning: only process statfiles that have tokens set up by the classifier */ + gboolean has_tokens = FALSE; + for (unsigned int k = 0; k < task->tokens->len && !has_tokens; k++) { + rspamd_token_t *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, k); + if (tok->values[id] != 0) { + has_tokens = TRUE; + } + } + if (!has_tokens) { + continue; + } + } + else { + /* Binary learning: use traditional spam/ham check */ + if (!!spam != !!st->stcf->is_spam) { + /* If we are not unlearning, then do not touch another class */ + continue; + } } } @@ -866,7 +1021,24 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, if (cl->cache) { cache_run = cl->cache->runtime(task, cl->cachecf, TRUE); - cl->cache->learn(task, spam, cache_run); + + /* For multi-class learning, determine spam boolean from class name if available */ + gboolean cache_spam = spam; /* Default to original spam parameter */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) { + cache_spam = TRUE; + } + else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) { + cache_spam = FALSE; + } + else { + /* For other classes, use a heuristic or default to spam for cache purposes */ + cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */ + } + } + + cl->cache->learn(task, cache_spam, cache_run); } } @@ -875,6 +1047,218 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, return res; } +static gboolean +rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const char *classifier, + const char *class_name, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + unsigned int i; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; + + if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && + *err == NULL) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + class_name); + + return FALSE; + } + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + /* Now check max and min tokens */ + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_info_task( + "<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + too_small = TRUE; + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_info_task( + "<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + too_large = TRUE; + continue; + } + + /* Use the new multi-class learning function if available */ + if (cl->subrs->learn_class_func) { + if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + else { + /* Fallback to binary learning with class name mapping */ + gboolean is_spam; + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + is_spam = TRUE; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + is_spam = FALSE; + } + else { + /* For unknown classes with binary classifier, skip */ + msg_info_task("skipping class '%s' for binary classifier %s", + class_name, cl->cfg->name); + continue; + } + + if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + if (!learned && err && *err == NULL) { + if (too_large) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains more tokens than allowed for %s classifier: " + "%d > %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->max_tokens); + } + else if (too_small) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains less tokens than required for %s classifier: " + "%d < %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->min_tokens); + } + } + + return learned; +} + +rspamd_stat_result_t +rspamd_stat_learn_class(struct rspamd_task *task, + const char *class_name, + lua_State *L, + const char *classifier, + unsigned int stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + /* + * We assume now that a task has been already classified before + * coming to learn + */ + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name); + + if (st_ctx->classifiers->len == 0) { + msg_debug_bayes("no classifiers defined"); + task->processed_stages |= stage; + return ret; + } + + if (task->message == NULL) { + ret = RSPAMD_STAT_PROCESS_ERROR; + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Trying to learn an empty message"); + } + + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { + /* Process classifiers - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + rspamd_stat_preprocess(st_ctx, task, TRUE, spam); + + if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) { + msg_debug_bayes("cache check failed, skip learning"); + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN) { + /* Process classifiers */ + if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier, + class_name, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when learning classifiers;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + + /* Process backends - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when storing data on backend;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) { + /* Process backends - determine spam boolean for compatibility */ + gboolean spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0); + if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + task->processed_stages |= stage; + + return ret; +} + rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task, gboolean spam, lua_State *L, const char *classifier, unsigned int stage, @@ -1035,12 +1419,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) if (mres) { if (mres->score > rspamd_task_get_required_score(task, mres)) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - + rspamd_task_set_autolearn_class(task, "spam"); ret = TRUE; } else if (mres->score < 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } } @@ -1072,12 +1455,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) if (mres) { if (mres->score >= spam_score) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - + rspamd_task_set_autolearn_class(task, "spam"); ret = TRUE; } else if (mres->score <= ham_score) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } } @@ -1113,11 +1495,16 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) /* We can have immediate results */ if (lua_ret) { if (strcmp(lua_ret, "ham") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + rspamd_task_set_autolearn_class(task, "ham"); ret = TRUE; } else if (strcmp(lua_ret, "spam") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + rspamd_task_set_autolearn_class(task, "spam"); + ret = TRUE; + } + else { + /* Multi-class: any other class name */ + rspamd_task_set_autolearn_class(task, lua_ret); ret = TRUE; } } @@ -1135,79 +1522,138 @@ rspamd_stat_check_autolearn(struct rspamd_task *task) } } else if (ucl_object_type(obj) == UCL_OBJECT) { - /* Try to find autolearn callback */ - if (cl->autolearn_cbref == 0) { - /* We don't have preprocessed cb id, so try to get it */ - if (!rspamd_lua_require_function(L, "lua_bayes_learn", - "autolearn")) { - msg_err_task("cannot get autolearn library from " - "`lua_bayes_learn`"); - } - else { - cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + /* Check if this is a multi-class autolearn configuration */ + const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass"); + + if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) { + /* Multi-class threshold-based autolearn */ + const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds"); + + if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) { + /* Iterate through class thresholds */ + ucl_object_iter_t it = NULL; + const ucl_object_t *class_obj; + const char *class_name; + + while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) { + class_name = ucl_object_key(class_obj); + + if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) { + /* [min_score, max_score] for this class */ + const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0); + const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1); + + if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) && + (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) { + + double min_score = ucl_object_todouble(min_elt); + double max_score = ucl_object_todouble(max_elt); + + if (mres && mres->score >= min_score && mres->score <= max_score) { + rspamd_task_set_autolearn_class(task, class_name); + ret = TRUE; + msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]", + mres->score, class_name, min_score, max_score); + break; /* Stop at first matching class */ + } + } + } + } } } - - if (cl->autolearn_cbref != -1) { - lua_pushcfunction(L, &rspamd_lua_traceback); - err_idx = lua_gettop(L); - lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); - - ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); - *ptask = task; - rspamd_lua_setclass(L, rspamd_task_classname, -1); - /* Push the whole object as well */ - ucl_object_push_lua(L, obj, true); - - if (lua_pcall(L, 2, 1, err_idx) != 0) { - msg_err_task("call to autolearn script failed: " - "%s", - lua_tostring(L, -1)); + else { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function(L, "lua_bayes_learn", + "autolearn")) { + msg_err_task("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } } - else { - lua_ret = lua_tostring(L, -1); - if (lua_ret) { - if (strcmp(lua_ret, "ham") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; - ret = TRUE; - } - else if (strcmp(lua_ret, "spam") == 0) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - ret = TRUE; + if (cl->autolearn_cbref != -1) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, rspamd_task_classname, -1); + /* Push the whole object as well */ + ucl_object_push_lua(L, obj, true); + + if (lua_pcall(L, 2, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + rspamd_task_set_autolearn_class(task, "ham"); + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + rspamd_task_set_autolearn_class(task, "spam"); + ret = TRUE; + } + else { + /* Multi-class: any other class name */ + rspamd_task_set_autolearn_class(task, lua_ret); + ret = TRUE; + } } } - } - lua_settop(L, err_idx - 1); + lua_settop(L, err_idx - 1); + } } - } - if (ret) { - /* Do not autolearn if we have this symbol already */ - if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { - ret = FALSE; - task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | - RSPAMD_TASK_FLAG_LEARN_SPAM); - } - else if (mres != NULL) { - if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { - msg_info_task("<%s>: autolearn ham for classifier " - "'%s' as message's " - "score is negative: %.2f", - MESSAGE_FIELD(task, message_id), cl->cfg->name, - mres->score); - } - else { - msg_info_task("<%s>: autolearn spam for classifier " - "'%s' as message's " - "action is reject, score: %.2f", - MESSAGE_FIELD(task, message_id), cl->cfg->name, - mres->score); + if (ret) { + /* Do not autolearn if we have this symbol already */ + if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { + ret = FALSE; + task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | + RSPAMD_TASK_FLAG_LEARN_SPAM | + RSPAMD_TASK_FLAG_LEARN_CLASS); + /* Clear the autolearn class from mempool */ + rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL); } + else if (mres != NULL) { + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + + if (autolearn_class) { + if (strcmp(autolearn_class, "ham") == 0) { + msg_info_task("<%s>: autolearn ham for classifier " + "'%s' as message's " + "score is negative: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else if (strcmp(autolearn_class, "spam") == 0) { + msg_info_task("<%s>: autolearn spam for classifier " + "'%s' as message's " + "action is reject, score: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else { + msg_info_task("<%s>: autolearn class '%s' for classifier " + "'%s', score: %.2f", + MESSAGE_FIELD(task, message_id), autolearn_class, + cl->cfg->name, mres->score); + } + } - task->classifier = cl->cfg->name; - break; + task->classifier = cl->cfg->name; + break; + } } } } diff --git a/src/libstat/tokenizers/custom_tokenizer.h b/src/libstat/tokenizers/custom_tokenizer.h new file mode 100644 index 000000000..bc173a1da --- /dev/null +++ b/src/libstat/tokenizers/custom_tokenizer.h @@ -0,0 +1,177 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_CUSTOM_TOKENIZER_H +#define RSPAMD_CUSTOM_TOKENIZER_H + +/* Check if we're being included by internal Rspamd code or external plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL +/* Internal Rspamd usage - use the full headers */ +#include "config.h" +#include "ucl.h" +#include "libserver/word.h" +#else +/* External plugin usage - use standalone types */ +#include "rspamd_tokenizer_types.h" +/* Forward declaration for UCL object - plugins should include ucl.h if needed */ +typedef struct ucl_object_s ucl_object_t; +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define RSPAMD_CUSTOM_TOKENIZER_API_VERSION 1 + +/** + * Tokenization result - compatible with both internal and external usage + */ +typedef rspamd_words_t rspamd_tokenizer_result_t; + +/** + * Custom tokenizer API that must be implemented by language-specific tokenizer plugins + * All functions use only plain C types to ensure clean boundaries + */ +typedef struct rspamd_custom_tokenizer_api { + /* API version for compatibility checking */ + unsigned int api_version; + + /* Name of the tokenizer (e.g., "japanese_mecab") */ + const char *name; + + /** + * Global initialization function called once when the tokenizer is loaded + * @param config UCL configuration object for this tokenizer (may be NULL) + * @param error_buf Buffer for error message (at least 256 bytes) + * @return 0 on success, non-zero on failure + */ + int (*init)(const ucl_object_t *config, char *error_buf, size_t error_buf_size); + + /** + * Global cleanup function called when the tokenizer is unloaded + */ + void (*deinit)(void); + + /** + * Quick language detection to check if this tokenizer can handle the text + * @param text UTF-8 text to analyze + * @param len Length of the text in bytes + * @return Confidence score 0.0-1.0, or -1.0 if cannot handle + */ + double (*detect_language)(const char *text, size_t len); + + /** + * Main tokenization function + * @param text UTF-8 text to tokenize + * @param len Length of the text in bytes + * @param result Output kvec to fill with rspamd_word_t elements + * @return 0 on success, non-zero on failure + * + * The tokenizer should allocate result->a using its own allocator + * Rspamd will call cleanup_result() to free it after processing + */ + int (*tokenize)(const char *text, size_t len, + rspamd_tokenizer_result_t *result); + + /** + * Cleanup the result from tokenize() + * @param result Result kvec returned by tokenize() + * + * This function should free result->a using the same allocator + * that was used in tokenize() and reset the kvec fields. + * This ensures proper memory management across DLL boundaries. + * Note: This does NOT free the result structure itself, only its contents. + */ + void (*cleanup_result)(rspamd_tokenizer_result_t *result); + + /** + * Optional: Get language hint for better language detection + * @return Language code (e.g., "ja", "zh") or NULL + */ + const char *(*get_language_hint)(void); + + /** + * Optional: Get minimum confidence threshold for this tokenizer + * @return Minimum confidence (0.0-1.0) or -1.0 to use default + */ + double (*get_min_confidence)(void); + +} rspamd_custom_tokenizer_api_t; + +/** + * Entry point function that plugins must export + * Must be named "rspamd_tokenizer_get_api" + */ +typedef const rspamd_custom_tokenizer_api_t *(*rspamd_tokenizer_get_api_func)(void); + +/* Internal Rspamd structures - not exposed to plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL + +/** + * Custom tokenizer instance + */ +struct rspamd_custom_tokenizer { + char *name; /* Tokenizer name from config */ + char *path; /* Path to .so file */ + void *handle; /* dlopen handle */ + const rspamd_custom_tokenizer_api_t *api; /* API functions */ + double priority; /* Detection priority */ + double min_confidence; /* Minimum confidence threshold */ + gboolean enabled; /* Is tokenizer enabled */ + ucl_object_t *config; /* Tokenizer-specific config */ +}; + +/** + * Tokenizer manager structure + */ +struct rspamd_tokenizer_manager { + GHashTable *tokenizers; /* name -> rspamd_custom_tokenizer */ + GArray *detection_order; /* Ordered by priority */ + rspamd_mempool_t *pool; + double default_threshold; /* Default confidence threshold */ +}; + +/* Manager functions */ +struct rspamd_tokenizer_manager *rspamd_tokenizer_manager_new(rspamd_mempool_t *pool); +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr); + +gboolean rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err); + +struct rspamd_custom_tokenizer *rspamd_tokenizer_manager_detect( + struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint); + +/* Helper function to tokenize with exceptions handling */ +rspamd_tokenizer_result_t *rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool); + +#endif /* RSPAMD_TOKENIZER_INTERNAL */ + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_CUSTOM_TOKENIZER_H */ diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c index 0bc3414a5..360c71d36 100644 --- a/src/libstat/tokenizers/osb.c +++ b/src/libstat/tokenizers/osb.c @@ -21,6 +21,7 @@ #include "tokenizers.h" #include "stat_internal.h" #include "libmime/lang_detection.h" +#include "libserver/word.h" /* Size for features pipe */ #define DEFAULT_FEATURE_WINDOW_SIZE 2 @@ -268,7 +269,7 @@ struct token_pipe_entry { int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result) @@ -282,7 +283,7 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, gsize token_size; unsigned int processed = 0, i, w, window_size, token_flags = 0; - if (words == NULL) { + if (words == NULL || !words->a) { return FALSE; } @@ -306,8 +307,8 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, sizeof(RSPAMD_TOKEN_VALUE_TYPE) * ctx->statfiles->len; g_assert(token_size > 0); - for (w = 0; w < words->len; w++) { - token = &g_array_index(words, rspamd_stat_token_t, w); + for (w = 0; w < kv_size(*words); w++) { + token = &kv_A(*words, w); token_flags = token->flags; const char *begin; gsize len; diff --git a/src/libstat/tokenizers/rspamd_tokenizer_types.h b/src/libstat/tokenizers/rspamd_tokenizer_types.h new file mode 100644 index 000000000..eb8518290 --- /dev/null +++ b/src/libstat/tokenizers/rspamd_tokenizer_types.h @@ -0,0 +1,89 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_TOKENIZER_TYPES_H +#define RSPAMD_TOKENIZER_TYPES_H + +/* + * Standalone type definitions for custom tokenizers + * This header is completely self-contained and does not depend on any external libraries. + * Custom tokenizers should include only this header to get access to all necessary types. + */ + +#include <stdint.h> +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Basic string token structure + */ +typedef struct rspamd_ftok { + size_t len; + const char *begin; +} rspamd_ftok_t; + +/** + * Unicode string token structure + */ +typedef struct rspamd_ftok_unicode { + size_t len; + const uint32_t *begin; +} rspamd_ftok_unicode_t; + +/* Word flags */ +#define RSPAMD_WORD_FLAG_TEXT (1u << 0u) +#define RSPAMD_WORD_FLAG_META (1u << 1u) +#define RSPAMD_WORD_FLAG_LUA_META (1u << 2u) +#define RSPAMD_WORD_FLAG_EXCEPTION (1u << 3u) +#define RSPAMD_WORD_FLAG_HEADER (1u << 4u) +#define RSPAMD_WORD_FLAG_UNIGRAM (1u << 5u) +#define RSPAMD_WORD_FLAG_UTF (1u << 6u) +#define RSPAMD_WORD_FLAG_NORMALISED (1u << 7u) +#define RSPAMD_WORD_FLAG_STEMMED (1u << 8u) +#define RSPAMD_WORD_FLAG_BROKEN_UNICODE (1u << 9u) +#define RSPAMD_WORD_FLAG_STOP_WORD (1u << 10u) +#define RSPAMD_WORD_FLAG_SKIPPED (1u << 11u) +#define RSPAMD_WORD_FLAG_INVISIBLE_SPACES (1u << 12u) +#define RSPAMD_WORD_FLAG_EMOJI (1u << 13u) + +/** + * Word structure + */ +typedef struct rspamd_word { + rspamd_ftok_t original; + rspamd_ftok_unicode_t unicode; + rspamd_ftok_t normalized; + rspamd_ftok_t stemmed; + unsigned int flags; +} rspamd_word_t; + +/** + * Array of words + */ +typedef struct rspamd_words { + rspamd_word_t *a; + size_t n; + size_t m; +} rspamd_words_t; + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_TOKENIZER_TYPES_H */ diff --git a/src/libstat/tokenizers/tokenizer_manager.c b/src/libstat/tokenizers/tokenizer_manager.c new file mode 100644 index 000000000..e6fb5e8d8 --- /dev/null +++ b/src/libstat/tokenizers/tokenizer_manager.c @@ -0,0 +1,500 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "tokenizers.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" +#include "libutil/util.h" +#include "libserver/logger.h" +#include <dlfcn.h> + +#define msg_err_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_warn_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_info_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_debug_tokenizer(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_tokenizer_log_id, "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(tokenizer) + +static void +rspamd_custom_tokenizer_dtor(gpointer p) +{ + struct rspamd_custom_tokenizer *tok = p; + + if (tok) { + if (tok->api && tok->api->deinit) { + tok->api->deinit(); + } + + if (tok->handle) { + dlclose(tok->handle); + } + + if (tok->config) { + ucl_object_unref(tok->config); + } + + g_free(tok->name); + g_free(tok->path); + g_free(tok); + } +} + +static int +rspamd_custom_tokenizer_priority_cmp(gconstpointer a, gconstpointer b) +{ + const struct rspamd_custom_tokenizer *t1 = *(const struct rspamd_custom_tokenizer **) a; + const struct rspamd_custom_tokenizer *t2 = *(const struct rspamd_custom_tokenizer **) b; + + /* Higher priority first */ + if (t1->priority > t2->priority) { + return -1; + } + else if (t1->priority < t2->priority) { + return 1; + } + + return 0; +} + +struct rspamd_tokenizer_manager * +rspamd_tokenizer_manager_new(rspamd_mempool_t *pool) +{ + struct rspamd_tokenizer_manager *mgr; + + mgr = rspamd_mempool_alloc0(pool, sizeof(*mgr)); + mgr->pool = pool; + mgr->tokenizers = g_hash_table_new_full(rspamd_strcase_hash, + rspamd_strcase_equal, + NULL, + rspamd_custom_tokenizer_dtor); + mgr->detection_order = g_array_new(FALSE, FALSE, sizeof(struct rspamd_custom_tokenizer *)); + mgr->default_threshold = 0.7; /* Default confidence threshold */ + + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) g_hash_table_unref, + mgr->tokenizers); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) rspamd_array_free_hard, + mgr->detection_order); + + msg_info_tokenizer("created custom tokenizer manager with default confidence threshold %.3f", + mgr->default_threshold); + + return mgr; +} + +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr) +{ + /* Cleanup is handled by memory pool destructors */ +} + +gboolean +rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err) +{ + struct rspamd_custom_tokenizer *tok; + const ucl_object_t *elt; + rspamd_tokenizer_get_api_func get_api; + const rspamd_custom_tokenizer_api_t *api; + void *handle; + const char *path; + gboolean enabled = TRUE; + double priority = 50.0; + char error_buf[256]; + + g_assert(mgr != NULL); + g_assert(name != NULL); + g_assert(config != NULL); + + msg_info_tokenizer("starting to load custom tokenizer '%s'", name); + + /* Check if enabled */ + elt = ucl_object_lookup(config, "enabled"); + if (elt && ucl_object_type(elt) == UCL_BOOLEAN) { + enabled = ucl_object_toboolean(elt); + } + + if (!enabled) { + msg_info_tokenizer("custom tokenizer '%s' is disabled", name); + return TRUE; + } + + /* Get path */ + elt = ucl_object_lookup(config, "path"); + if (!elt || ucl_object_type(elt) != UCL_STRING) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "missing 'path' for tokenizer %s", name); + return FALSE; + } + path = ucl_object_tostring(elt); + msg_info_tokenizer("custom tokenizer '%s' will be loaded from path: %s", name, path); + + /* Get priority */ + elt = ucl_object_lookup(config, "priority"); + if (elt) { + priority = ucl_object_todouble(elt); + } + msg_info_tokenizer("custom tokenizer '%s' priority set to %.1f", name, priority); + + /* Load the shared library */ + msg_info_tokenizer("loading shared library for custom tokenizer '%s'", name); + handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot load tokenizer %s from %s: %s", + name, path, dlerror()); + return FALSE; + } + msg_info_tokenizer("successfully loaded shared library for custom tokenizer '%s'", name); + + /* Get the API entry point */ + msg_info_tokenizer("looking up API entry point for custom tokenizer '%s'", name); + get_api = (rspamd_tokenizer_get_api_func) dlsym(handle, "rspamd_tokenizer_get_api"); + if (!get_api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot find entry point in %s: %s", + path, dlerror()); + return FALSE; + } + + /* Get the API */ + msg_info_tokenizer("calling API entry point for custom tokenizer '%s'", name); + api = get_api(); + if (!api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s returned NULL API", name); + return FALSE; + } + msg_info_tokenizer("successfully obtained API from custom tokenizer '%s'", name); + + /* Check API version */ + msg_info_tokenizer("checking API version for custom tokenizer '%s' (got %u, expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + if (api->api_version != RSPAMD_CUSTOM_TOKENIZER_API_VERSION) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s has incompatible API version %u (expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + return FALSE; + } + + /* Create tokenizer instance */ + tok = g_malloc0(sizeof(*tok)); + tok->name = g_strdup(name); + tok->path = g_strdup(path); + tok->handle = handle; + tok->api = api; + tok->priority = priority; + tok->enabled = enabled; + + /* Get tokenizer config */ + elt = ucl_object_lookup(config, "config"); + if (elt) { + tok->config = ucl_object_ref(elt); + } + + /* Get minimum confidence */ + if (api->get_min_confidence) { + tok->min_confidence = api->get_min_confidence(); + msg_info_tokenizer("custom tokenizer '%s' provides minimum confidence threshold: %.3f", + name, tok->min_confidence); + } + else { + tok->min_confidence = mgr->default_threshold; + msg_info_tokenizer("custom tokenizer '%s' using default confidence threshold: %.3f", + name, tok->min_confidence); + } + + /* Initialize the tokenizer */ + if (api->init) { + msg_info_tokenizer("initializing custom tokenizer '%s'", name); + error_buf[0] = '\0'; + if (api->init(tok->config, error_buf, sizeof(error_buf)) != 0) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "failed to initialize tokenizer %s: %s", + name, error_buf[0] ? error_buf : "unknown error"); + rspamd_custom_tokenizer_dtor(tok); + return FALSE; + } + msg_info_tokenizer("successfully initialized custom tokenizer '%s'", name); + } + else { + msg_info_tokenizer("custom tokenizer '%s' does not require initialization", name); + } + + /* Add to manager */ + g_hash_table_insert(mgr->tokenizers, tok->name, tok); + g_array_append_val(mgr->detection_order, tok); + + /* Re-sort by priority */ + g_array_sort(mgr->detection_order, rspamd_custom_tokenizer_priority_cmp); + msg_info_tokenizer("custom tokenizer '%s' registered and sorted by priority (total tokenizers: %u)", + name, mgr->detection_order->len); + + msg_info_tokenizer("successfully loaded custom tokenizer '%s' (priority %.1f) from %s", + name, priority, path); + + return TRUE; +} + +struct rspamd_custom_tokenizer * +rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint) +{ + struct rspamd_custom_tokenizer *tok, *best_tok = NULL; + double conf, best_conf = 0.0; + unsigned int i; + + g_assert(mgr != NULL); + g_assert(text != NULL); + + msg_debug_tokenizer("starting tokenizer detection for text of length %zu", len); + + if (confidence) { + *confidence = 0.0; + } + + if (detected_lang_hint) { + *detected_lang_hint = NULL; + } + + /* If we have a language hint, try to find a tokenizer for that language first */ + if (lang_hint) { + msg_info_tokenizer("trying to find tokenizer for language hint: %s", lang_hint); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->get_language_hint) { + continue; + } + + /* Check if this tokenizer handles the hinted language */ + const char *tok_lang = tok->api->get_language_hint(); + if (tok_lang && g_ascii_strcasecmp(tok_lang, lang_hint) == 0) { + msg_info_tokenizer("found tokenizer '%s' for language hint '%s'", tok->name, lang_hint); + /* Found a tokenizer for this language, check if it actually detects it */ + if (tok->api->detect_language) { + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' confidence for hinted language: %.3f (threshold: %.3f)", + tok->name, conf, tok->min_confidence); + if (conf >= tok->min_confidence) { + /* Use this tokenizer */ + msg_info_tokenizer("using tokenizer '%s' for language hint '%s' with confidence %.3f", + tok->name, lang_hint, conf); + if (confidence) { + *confidence = conf; + } + if (detected_lang_hint) { + *detected_lang_hint = tok_lang; + } + return tok; + } + } + } + } + msg_info_tokenizer("no suitable tokenizer found for language hint '%s', falling back to general detection", lang_hint); + } + + /* Try each tokenizer in priority order */ + msg_info_tokenizer("trying %u tokenizers for general detection", mgr->detection_order->len); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->detect_language) { + msg_debug_tokenizer("skipping tokenizer '%s' (enabled: %s, has detect_language: %s)", + tok->name, tok->enabled ? "yes" : "no", + tok->api->detect_language ? "yes" : "no"); + continue; + } + + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' detection confidence: %.3f (threshold: %.3f, current best: %.3f)", + tok->name, conf, tok->min_confidence, best_conf); + + if (conf > best_conf && conf >= tok->min_confidence) { + best_conf = conf; + best_tok = tok; + msg_info_tokenizer("tokenizer '%s' is new best with confidence %.3f", tok->name, best_conf); + + /* Early exit if very confident */ + if (conf >= 0.95) { + msg_info_tokenizer("very high confidence (%.3f >= 0.95), using tokenizer '%s' immediately", + conf, tok->name); + break; + } + } + } + + if (best_tok) { + msg_info_tokenizer("selected tokenizer '%s' with confidence %.3f", best_tok->name, best_conf); + if (confidence) { + *confidence = best_conf; + } + + if (detected_lang_hint && best_tok->api->get_language_hint) { + *detected_lang_hint = best_tok->api->get_language_hint(); + msg_info_tokenizer("detected language hint: %s", *detected_lang_hint); + } + } + else { + msg_info_tokenizer("no suitable tokenizer found during detection"); + } + + return best_tok; +} + +/* Helper function to tokenize with a custom tokenizer handling exceptions */ +rspamd_tokenizer_result_t * +rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool) +{ + rspamd_tokenizer_result_t *words; + rspamd_tokenizer_result_t result; + struct rspamd_process_exception *ex; + GList *cur_ex = exceptions; + gsize pos = 0; + unsigned int i; + int ret; + + /* Allocate result kvec in pool */ + words = rspamd_mempool_alloc(pool, sizeof(*words)); + kv_init(*words); + + /* If no exceptions, tokenize the whole text */ + if (!exceptions) { + kv_init(result); + + ret = tokenizer->api->tokenize(text, len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result to output */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + kv_push(rspamd_word_t, *words, tok); + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + + return words; + } + + /* Process text with exceptions */ + while (pos < len && cur_ex) { + ex = (struct rspamd_process_exception *) cur_ex->data; + + /* Tokenize text before exception */ + if (ex->pos > pos) { + gsize segment_len = ex->pos - pos; + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, segment_len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < segment_len) { + tok.original.begin = text + pos + offset_in_segment; + /* Ensure we don't go past the exception boundary */ + if (tok.original.begin + tok.original.len <= text + ex->pos) { + kv_push(rspamd_word_t, *words, tok); + } + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + /* Add exception as a special token */ + rspamd_word_t ex_tok; + memset(&ex_tok, 0, sizeof(ex_tok)); + + if (ex->type == RSPAMD_EXCEPTION_URL) { + ex_tok.original.begin = "!!EX!!"; + ex_tok.original.len = 6; + } + else { + ex_tok.original.begin = text + ex->pos; + ex_tok.original.len = ex->len; + } + ex_tok.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + kv_push(rspamd_word_t, *words, ex_tok); + + /* Move past exception */ + pos = ex->pos + ex->len; + cur_ex = g_list_next(cur_ex); + } + + /* Process remaining text after last exception */ + if (pos < len) { + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, len - pos, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < (len - pos)) { + tok.original.begin = text + pos + offset_in_segment; + kv_push(rspamd_word_t, *words, tok); + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + return words; +} diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c index 0ea1bcfc6..8a9f42992 100644 --- a/src/libstat/tokenizers/tokenizers.c +++ b/src/libstat/tokenizers/tokenizers.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ #include "contrib/mumhash/mum.h" #include "libmime/lang_detection.h" #include "libstemmer.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" #include <unicode/utf8.h> #include <unicode/uchar.h> @@ -35,8 +37,8 @@ #include <math.h> -typedef gboolean (*token_get_function)(rspamd_stat_token_t *buf, char const **pos, - rspamd_stat_token_t *token, +typedef gboolean (*token_get_function)(rspamd_word_t *buf, char const **pos, + rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean check_signature); const char t_delimiters[256] = { @@ -69,8 +71,8 @@ const char t_delimiters[256] = { /* Get next word from specified f_str_t buf */ static gboolean -rspamd_tokenizer_get_word_raw(rspamd_stat_token_t *buf, - char const **cur, rspamd_stat_token_t *token, +rspamd_tokenizer_get_word_raw(rspamd_word_t *buf, + char const **cur, rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean unused) { gsize remain, pos; @@ -164,7 +166,7 @@ rspamd_tokenize_check_limit(gboolean decay, unsigned int nwords, uint64_t *hv, uint64_t *prob, - const rspamd_stat_token_t *token, + const rspamd_word_t *token, gssize remain, gssize total) { @@ -242,9 +244,9 @@ rspamd_utf_word_valid(const unsigned char *text, const unsigned char *end, } while (0) static inline void -rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) +rspamd_tokenize_exception(struct rspamd_process_exception *ex, rspamd_words_t *res) { - rspamd_stat_token_t token; + rspamd_word_t token; memset(&token, 0, sizeof(token)); @@ -253,7 +255,7 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) token.original.len = sizeof("!!EX!!") - 1; token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } else if (ex->type == RSPAMD_EXCEPTION_URL) { @@ -271,28 +273,33 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) } token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } + return; + +exception_error: + /* On error, just skip this exception token */ + return; } -GArray * +rspamd_words_t * rspamd_tokenize_text(const char *text, gsize len, const UText *utxt, enum rspamd_tokenize_type how, struct rspamd_config *cfg, GList *exceptions, uint64_t *hash, - GArray *cur_words, + rspamd_words_t *output_kvec, rspamd_mempool_t *pool) { - rspamd_stat_token_t token, buf; + rspamd_word_t token, buf; const char *pos = NULL; gsize l = 0; - GArray *res; + rspamd_words_t *res; GList *cur = exceptions; - unsigned int min_len = 0, max_len = 0, word_decay = 0, initial_size = 128; + unsigned int min_len = 0, max_len = 0, word_decay = 0; uint64_t hv = 0; gboolean decay = FALSE, long_text_mode = FALSE; uint64_t prob = 0; @@ -300,9 +307,12 @@ rspamd_tokenize_text(const char *text, gsize len, static const gsize long_text_limit = 1 * 1024 * 1024; static const ev_tstamp max_exec_time = 0.2; /* 200 ms */ ev_tstamp start; + struct rspamd_custom_tokenizer *custom_tok = NULL; + double custom_confidence = 0.0; + const char *detected_lang = NULL; if (text == NULL) { - return cur_words; + return output_kvec; } if (len > long_text_limit) { @@ -323,15 +333,59 @@ rspamd_tokenize_text(const char *text, gsize len, min_len = cfg->min_word_len; max_len = cfg->max_word_len; word_decay = cfg->words_decay; - initial_size = word_decay * 2; } - if (!cur_words) { - res = g_array_sized_new(FALSE, FALSE, sizeof(rspamd_stat_token_t), - initial_size); + if (!output_kvec) { + res = pool ? rspamd_mempool_alloc0(pool, sizeof(*res)) : g_malloc0(sizeof(*res)); + ; } else { - res = cur_words; + res = output_kvec; + } + + /* Try custom tokenizers first if we're in UTF mode */ + if (cfg && cfg->tokenizer_manager && how == RSPAMD_TOKENIZE_UTF && utxt != NULL) { + custom_tok = rspamd_tokenizer_manager_detect( + cfg->tokenizer_manager, + text, len, + &custom_confidence, + NULL, /* no input language hint */ + &detected_lang); + + if (custom_tok && custom_confidence >= custom_tok->min_confidence) { + /* Use custom tokenizer with exception handling */ + rspamd_tokenizer_result_t *custom_res = rspamd_custom_tokenizer_tokenize_with_exceptions( + custom_tok, text, len, exceptions, pool); + + if (custom_res) { + msg_debug_pool("using custom tokenizer %s (confidence: %.2f) for text tokenization", + custom_tok->name, custom_confidence); + + /* Copy custom tokenizer results to output kvec */ + for (unsigned int i = 0; i < kv_size(*custom_res); i++) { + kv_push_safe(rspamd_word_t, *res, kv_A(*custom_res, i), custom_tokenizer_error); + } + + /* Calculate hash if needed */ + if (hash && kv_size(*res) > 0) { + for (unsigned int i = 0; i < kv_size(*res); i++) { + rspamd_word_t *t = &kv_A(*res, i); + if (t->original.len >= sizeof(uint64_t)) { + uint64_t tmp; + memcpy(&tmp, t->original.begin, sizeof(tmp)); + hv = mum_hash_step(hv, tmp); + } + } + *hash = mum_hash_finish(hv); + } + + return res; + } + else { + msg_warn_pool("custom tokenizer %s failed to tokenize text, falling back to default", + custom_tok->name); + } + } } if (G_UNLIKELY(how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) { @@ -343,7 +397,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, pos - text, len)) { if (!decay) { decay = TRUE; @@ -355,28 +409,28 @@ rspamd_tokenize_text(const char *text, gsize len, } if (long_text_mode) { - if ((res->len + 1) % 16 == 0) { + if ((kv_size(*res) + 1) % 16 == 0) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } } } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! */ msg_err_pool_check( - "too many words found: %d, stop tokenization to avoid DoS", - res->len); + "too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } @@ -523,7 +577,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, p, len)) { if (!decay) { decay = TRUE; @@ -536,15 +590,15 @@ rspamd_tokenize_text(const char *text, gsize len, if (token.original.len > 0) { /* Additional check for number of words */ - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! */ - msg_err("too many words found: %d, stop tokenization to avoid DoS", - res->len); + msg_err("too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); } /* Also check for long text mode */ @@ -552,15 +606,15 @@ rspamd_tokenize_text(const char *text, gsize len, /* Check time each 128 words added */ const int words_check_mask = 0x7F; - if ((res->len & words_check_mask) == words_check_mask) { + if ((kv_size(*res) & words_check_mask) == words_check_mask) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } @@ -590,8 +644,14 @@ end: } return res; + +tokenize_error: +custom_tokenizer_error: + msg_err_pool("failed to allocate memory for tokenization"); + return res; } + #undef SHIFT_EX static void @@ -625,32 +685,38 @@ rspamd_add_metawords_from_str(const char *beg, gsize len, #endif } + /* Initialize meta_words kvec if not already done */ + if (!task->meta_words.a) { + kv_init(task->meta_words); + } + if (valid_utf) { utext_openUTF8(&utxt, beg, len, &uc_err); - task->meta_words = rspamd_tokenize_text(beg, len, - &utxt, RSPAMD_TOKENIZE_UTF, - task->cfg, NULL, NULL, - task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + &utxt, RSPAMD_TOKENIZE_UTF, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); utext_close(&utxt); } else { - task->meta_words = rspamd_tokenize_text(beg, len, - NULL, RSPAMD_TOKENIZE_RAW, - task->cfg, NULL, NULL, task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + NULL, RSPAMD_TOKENIZE_RAW, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); } } void rspamd_tokenize_meta_words(struct rspamd_task *task) { unsigned int i = 0; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; if (MESSAGE_FIELD(task, subject)) { rspamd_add_metawords_from_str(MESSAGE_FIELD(task, subject), @@ -667,7 +733,7 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { const char *language = NULL; if (MESSAGE_FIELD(task, text_parts) && @@ -680,12 +746,12 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - rspamd_normalize_words(task->meta_words, task->task_pool); - rspamd_stem_words(task->meta_words, task->task_pool, language, + rspamd_normalize_words(&task->meta_words, task->task_pool); + rspamd_stem_words(&task->meta_words, task->task_pool, language, task->lang_det); - for (i = 0; i < task->meta_words->len; i++) { - tok = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(task->meta_words); i++) { + tok = &kv_A(task->meta_words, i); tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER; } } @@ -759,7 +825,7 @@ rspamd_ucs32_to_normalised(rspamd_stat_token_t *tok, tok->normalized.begin = dest; } -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool) +void rspamd_normalize_single_word(rspamd_word_t *tok, rspamd_mempool_t *pool) { UErrorCode uc_err = U_ZERO_ERROR; UConverter *utf8_converter; @@ -858,25 +924,27 @@ void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *po } } -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool) + +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool) { - rspamd_stat_token_t *tok; + rspamd_word_t *tok; unsigned int i; - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); rspamd_normalize_single_word(tok, pool); } } -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, + +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, const char *language, struct rspamd_lang_detector *lang_detector) { static GHashTable *stemmers = NULL; struct sb_stemmer *stem = NULL; unsigned int i; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; char *dest; gsize dlen; @@ -909,8 +977,18 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, stem = NULL; } } - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); + + /* Skip stemming if token has already been stemmed by custom tokenizer */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_STEMMED) { + /* Already stemmed, just check for stop words */ + if (tok->stemmed.len > 0 && lang_detector != NULL && + rspamd_language_detector_is_stop_word(lang_detector, tok->stemmed.begin, tok->stemmed.len)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD; + } + continue; + } if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { if (stem) { @@ -952,4 +1030,4 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, } } } -}
\ No newline at end of file +} diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h index d4a8824a8..bb0bb54e2 100644 --- a/src/libstat/tokenizers/tokenizers.h +++ b/src/libstat/tokenizers/tokenizers.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include "fstring.h" #include "rspamd.h" #include "stat_api.h" +#include "libserver/word.h" #include <unicode/utext.h> @@ -43,7 +44,7 @@ struct rspamd_stat_tokenizer { int (*tokenize_func)(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -59,20 +60,20 @@ enum rspamd_tokenize_type { int token_node_compare_func(gconstpointer a, gconstpointer b); -/* Tokenize text into array of words (rspamd_stat_token_t type) */ -GArray *rspamd_tokenize_text(const char *text, gsize len, - const UText *utxt, - enum rspamd_tokenize_type how, - struct rspamd_config *cfg, - GList *exceptions, - uint64_t *hash, - GArray *cur_words, - rspamd_mempool_t *pool); +/* Tokenize text into kvec of words (rspamd_word_t type) */ +rspamd_words_t *rspamd_tokenize_text(const char *text, gsize len, + const UText *utxt, + enum rspamd_tokenize_type how, + struct rspamd_config *cfg, + GList *exceptions, + uint64_t *hash, + rspamd_words_t *output_kvec, + rspamd_mempool_t *pool); /* OSB tokenize function */ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -83,11 +84,11 @@ gpointer rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool, struct rspamd_lang_detector; -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool); +void rspamd_normalize_single_word(rspamd_word_t *tok, rspamd_mempool_t *pool); -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool); - -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, +/* Word processing functions */ +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool); +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, const char *language, struct rspamd_lang_detector *lang_detector); |