diff options
Diffstat (limited to 'src/libstat/backends')
-rw-r--r-- | src/libstat/backends/cdb_backend.cxx | 13 | ||||
-rw-r--r-- | src/libstat/backends/mmaped_file.c | 10 | ||||
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 537 | ||||
-rw-r--r-- | src/libstat/backends/sqlite3_backend.c | 7 |
4 files changed, 460 insertions, 107 deletions
diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx index 0f55a725c..f6ca9c12d 100644 --- a/src/libstat/backends/cdb_backend.cxx +++ b/src/libstat/backends/cdb_backend.cxx @@ -393,7 +393,6 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, gpointer runtime) { auto *cdbp = CDB_FROM_RAW(runtime); - bool seen_values = false; for (auto i = 0u; i < tokens->len; i++) { rspamd_token_t *tok; @@ -403,21 +402,13 @@ rspamd_cdb_process_tokens(struct rspamd_task *task, if (res) { tok->values[id] = res.value(); - seen_values = true; } else { tok->values[id] = 0; } } - if (seen_values) { - if (cdbp->is_spam()) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } - } + /* No longer need to set flags - multi-class handles missing data naturally */ return true; } @@ -488,4 +479,4 @@ void rspamd_cdb_close(gpointer ctx) { auto *cdbp = CDB_FROM_RAW(ctx); delete cdbp; -}
\ No newline at end of file +} diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c index 4430bb9a4..a6423a1e6 100644 --- a/src/libstat/backends/mmaped_file.c +++ b/src/libstat/backends/mmaped_file.c @@ -85,8 +85,7 @@ typedef struct { #define RSPAMD_STATFILE_VERSION \ { \ - '1', '2' \ - } + '1', '2'} #define BACKUP_SUFFIX ".old" static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool, @@ -958,12 +957,7 @@ rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens, tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2); } - if (mf->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ return TRUE; } diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 7137904e9..302778bcb 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -22,6 +22,7 @@ #include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/error.hxx" +#include <map> #include <string> #include <cstdint> @@ -121,9 +122,9 @@ public: } static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, - bool is_spam) -> std::optional<redis_stat_runtime<T> *> + const char *class_label) -> std::optional<redis_stat_runtime<T> *> { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label); auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str()); if (res) { @@ -147,9 +148,15 @@ public: rspamd_token_t *tok; if (!results) { + msg_debug_bayes("process_tokens: no results available for statfile id=%d", id); return false; } + if (results->size() > 0) { + msg_debug_bayes("processing %uz tokens for statfile id=%d, class=%s", + results->size(), id, stcf->class_name ? stcf->class_name : "unknown"); + } + for (auto [idx, val]: *results) { tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1); tok->values[id] = val; @@ -158,12 +165,14 @@ public: return true; } - auto save_in_mempool(bool is_spam) const + auto save_in_mempool(const char *class_label) const { - auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto var_name = + rspamd_mempool_strdup(task->task_pool, + fmt::format("{}_{}", redis_object_expanded, class_label).c_str()); /* We do not set destructor for the variable, as it should be already added on creation */ - rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); - msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str()); + rspamd_mempool_set_variable(task->task_pool, var_name, (gpointer) this, nullptr); + msg_debug_bayes("saved runtime in mempool at %s", var_name); } }; @@ -178,6 +187,26 @@ rspamd_redis_stat_quark(void) } /* + * Get the class label for a statfile (for multi-class support) + */ +static const char * +get_class_label(struct rspamd_statfile_config *stcf) +{ + /* Try to get the label from the classifier config first */ + if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) { + const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name); + if (label) { + return label; + } + /* If no label mapping found, use class name directly */ + return stcf->class_name; + } + + /* Fallback to legacy binary classification */ + return stcf->is_spam ? "S" : "H"; +} + +/* * Non-static for lua unit testing */ gsize rspamd_redis_expand_object(const char *pattern, @@ -235,6 +264,11 @@ gsize rspamd_redis_expand_object(const char *pattern, if (rcpt) { rspamd_mempool_set_variable(task->task_pool, "stat_user", (gpointer) rcpt, nullptr); + msg_debug_bayes("redis expansion: found recipient '%s'", rcpt); + } + else { + msg_debug_bayes("redis expansion: no recipient found (deliver_to=%s)", + task->deliver_to ? task->deliver_to : "null"); } } @@ -448,6 +482,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, users_enabled = ucl_object_lookup_any(classifier_obj, "per_user", "users_enabled", nullptr); + msg_debug_bayes_cfg("per-user lookup: users_enabled=%p", users_enabled); if (users_enabled != nullptr) { if (ucl_object_type(users_enabled) == UCL_BOOLEAN) { backend->enable_users = ucl_object_toboolean(users_enabled); @@ -485,9 +520,16 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, /* Default non-users statistics */ if (backend->enable_users || backend->cbref_user != -1) { backend->redis_object = REDIS_DEFAULT_USERS_OBJECT; + msg_debug_bayes_cfg("using per-user Redis pattern: %s (enable_users=%s, cbref_user=%d)", + backend->redis_object, backend->enable_users ? "true" : "false", + backend->cbref_user); } else { backend->redis_object = REDIS_DEFAULT_OBJECT; + msg_debug_bayes_cfg("using default Redis pattern: %s (enable_users=%s, cbref_user=%d)", + backend->redis_object, + backend->enable_users ? "true" : "false", + backend->cbref_user); } } else { @@ -541,7 +583,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx, ucl_object_push_lua(L, st->classifier->cfg->opts, false); ucl_object_push_lua(L, st->stcf->opts, false); lua_pushstring(L, backend->stcf->symbol); - lua_pushboolean(L, backend->stcf->is_spam); + lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */ /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */ if (ctx->event_loop) { @@ -606,11 +648,20 @@ rspamd_redis_runtime(struct rspamd_task *task, stcf->symbol); return nullptr; } + else { + msg_debug_bayes("redis object expanded: pattern='%s' -> expanded='%s' (learn=%s, symbol=%s)", + ctx->redis_object ? ctx->redis_object : "default", + object_expanded, + learn ? "true" : "false", + stcf->symbol); + } + + const char *class_label = get_class_label(stcf); /* Look for the cached results */ if (!learn) { auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - object_expanded, stcf->is_spam); + object_expanded, class_label); if (maybe_existing) { auto *rt = maybe_existing.value(); @@ -624,24 +675,62 @@ rspamd_redis_runtime(struct rspamd_task *task, /* No cached result (or learn), create new one */ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); - if (!learn) { - /* - * For check, we also need to create the opposite class runtime to avoid - * double call for Redis scripts. - * This runtime will be filled later. - */ - auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - object_expanded, - !stcf->is_spam); - - if (!maybe_opposite_rt) { - auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); - opposite_rt->save_in_mempool(!stcf->is_spam); - opposite_rt->need_redis_call = false; + /* Find the statfile ID for the main runtime */ + int main_id = _id; /* Use the passed _id parameter */ + rt->id = main_id; + rt->stcf = stcf; + + /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */ + if (!learn && stcf->clcf && stcf->clcf->statfiles) { + GList *cur = stcf->clcf->statfiles; + + while (cur) { + auto *other_stcf = (struct rspamd_statfile_config *) cur->data; + const char *other_label = get_class_label(other_stcf); + + /* Find the statfile ID by searching through all statfiles */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + int other_id = -1; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (st->stcf == other_stcf) { + other_id = st->id; + msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d", + st->stcf->symbol, other_label, other_id); + break; + } + } + + if (other_id == -1) { + msg_debug_bayes("statfile not found for class %s, skipping", other_label); + /* Skip if statfile not found */ + cur = g_list_next(cur); + continue; + } + + if (other_stcf == stcf) { + /* This is the main statfile, use the main runtime */ + rt->save_in_mempool(other_label); + msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d", + stcf->symbol, other_label, rt->id); + } + else { + /* Create additional runtime for other statfile */ + auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + other_rt->id = other_id; + other_rt->stcf = other_stcf; + other_rt->need_redis_call = false; + other_rt->save_in_mempool(other_label); + msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d", + other_stcf->symbol, other_label, other_id); + } + + cur = g_list_next(cur); } } - - rt->save_in_mempool(stcf->is_spam); + else { + rt->save_in_mempool(class_label); + } return rt; } @@ -816,77 +905,306 @@ rspamd_redis_classified(lua_State *L) if (rt == nullptr) { msg_err_task("internal error: cannot find runtime for cookie %s", cookie); - return 0; } bool result = lua_toboolean(L, 2); if (result) { - /* Indexes: - * 3 - learned_ham (int) - * 4 - learned_spam (int) - * 5 - ham_tokens (pair<int, int>) - * 6 - spam_tokens (pair<int, int>) - */ - - /* - * We need to fill our runtime AND the opposite runtime - */ - auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) { - rt->learned = learned; - redis_stat_runtime<float>::result_type *res; - - res = new redis_stat_runtime<float>::result_type(); - - for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { - lua_rawgeti(L, -1, 1); - auto idx = lua_tointeger(L, -1); - lua_pop(L, 1); - - lua_rawgeti(L, -1, 2); - auto value = lua_tonumber(L, -1); - lua_pop(L, 1); - - res->emplace_back(idx, value); - } - - rt->set_results(res); - }; - - auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task, - rt->redis_object_expanded, - !rt->stcf->is_spam); + /* Check we have enough arguments and the result data is a table */ + if (lua_gettop(L) < 3 || !lua_istable(L, 3)) { + msg_err_task("internal error: expected table result from Redis script, got %s", + lua_typename(L, lua_type(L, 3))); + rt->err = rspamd::util::error("invalid Redis script result format", 500); + return 0; + } - if (!opposite_rt_maybe) { - msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + /* Redis returns [learned_counts_array, token_results_array] + * Both ordered the same way as statfiles in classifier */ + size_t result_len = rspamd_lua_table_size(L, 3); + msg_debug_bayes("Redis result array length: %uz", result_len); + if (result_len != 2) { + msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len); + rt->err = rspamd::util::error("invalid Redis script result format", 500); return 0; } - if (rt->stcf->is_spam) { - filler_func(rt, L, lua_tointeger(L, 4), 6); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + /* Get learned_counts_array and token_results_array */ + lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */ + lua_rawgeti(L, 3, 2); /* token_results -> position 5 */ + + /* First, process learned_counts */ + if (lua_istable(L, 4) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using class index (1-based for Lua) */ + lua_rawgeti(L, 4, class_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for class %s (label %s): %L", + class_name, class_label, statfile_rt->learned); + } + lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */ + } + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ + GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + /* Get the runtime for this statfile */ + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + + /* Extract learned count using statfile index (1-based for Lua) */ + lua_rawgeti(L, 4, statfile_idx + 1); + if (lua_isnumber(L, -1)) { + statfile_rt->learned = lua_tointeger(L, -1); + msg_debug_bayes("set learned count for statfile %s (label %s): %L", + stcf->symbol, class_label, statfile_rt->learned); + } + lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */ + } + cur = g_list_next(cur); + statfile_idx++; + } + } } - else { - filler_func(rt, L, lua_tointeger(L, 3), 5); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + + /* Process token results */ + if (lua_istable(L, 5) && rt->stcf->clcf) { + if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: use class_names order */ + for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx); + + /* Find statfile with this class name */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } + + if (!st) { + msg_debug_bayes("statfile not found for class %s, skipping", class_name); + break; + } + + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for class %s, skipping", class_label); + break; + } + } + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using class index (1-based for Lua) */ + lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector<std::pair<int, float>>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } + + statfile_rt->set_results(res); + } + lua_pop(L, 1); /* Pop token_results[class_idx + 1] */ + break; /* Found the statfile for this class */ + } + cur = g_list_next(cur); + } + } + } + else { + /* Binary classification: process statfiles in order */ + GList *cur = rt->stcf->clcf->statfiles; + unsigned int statfile_idx = 0; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + /* Find the statfile ID */ + struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx(); + struct rspamd_statfile *st = nullptr; + for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) { + struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i); + if (candidate->stcf == stcf) { + st = candidate; + break; + } + } + + if (!st) { + msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol); + cur = g_list_next(cur); + statfile_idx++; + continue; + } + + /* Get or create runtime for this statfile */ + auto *statfile_rt = rt; /* Use current runtime if it matches */ + if (stcf != rt->stcf) { + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + statfile_rt = maybe_rt.value(); + } + else { + msg_debug_bayes("runtime not found for %s, skipping", class_label); + cur = g_list_next(cur); + statfile_idx++; + continue; + } + } + + /* Ensure correct statfile ID assignment */ + statfile_rt->id = st->id; + + /* Process token results using statfile index (1-based for Lua) */ + lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */ + if (lua_istable(L, -1)) { + /* Parse token results into statfile runtime */ + auto *res = new std::vector<std::pair<int, float>>(); + + lua_pushnil(L); /* First key for iteration */ + while (lua_next(L, -2) != 0) { + if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) { + lua_rawgeti(L, -1, 1); /* token_index */ + lua_rawgeti(L, -2, 2); /* token_count */ + + if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) { + int token_idx = lua_tointeger(L, -2); + float token_count = lua_tonumber(L, -1); + res->emplace_back(token_idx, token_count); + } + + lua_pop(L, 2); /* Pop token_index and token_count */ + } + lua_pop(L, 1); /* Pop value, keep key for next iteration */ + } + + statfile_rt->set_results(res); + msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)", + res->size(), stcf->symbol, class_label, st->id); + } + lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */ + + cur = g_list_next(cur); + statfile_idx++; + } + } } - /* Mark task as being processed */ - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + /* Clean up stack */ + lua_pop(L, 2); /* Pop learned_counts and token_results */ - /* Process all tokens */ + /* Process tokens for all runtimes */ g_assert(rt->tokens != nullptr); - rt->process_tokens(rt->tokens); - opposite_rt_maybe.value()->process_tokens(rt->tokens); + + /* Process tokens for all statfiles */ + if (rt->stcf->clcf && rt->stcf->clcf->statfiles) { + GList *cur = rt->stcf->clcf->statfiles; + + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + const char *class_label = get_class_label(stcf); + + auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + class_label); + if (maybe_rt) { + auto *statfile_rt = maybe_rt.value(); + statfile_rt->process_tokens(rt->tokens); + } + + cur = g_list_next(cur); + } + } + else { + /* Fallback: just process the main runtime */ + rt->process_tokens(rt->tokens); + } } else { /* Error message is on index 3 */ - const auto *err_msg = lua_tostring(L, 3); - rt->err = rspamd::util::error(err_msg, 500); - msg_err_task("cannot classify task: %s", - err_msg); + const char *err_msg = nullptr; + if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) { + err_msg = lua_tostring(L, 3); + } + if (err_msg) { + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot classify task: %s", err_msg); + } + else { + rt->err = rspamd::util::error("unknown Redis script error", 500); + msg_err_task("cannot classify task: unknown Redis script error"); + } } return 0; @@ -929,7 +1247,42 @@ rspamd_redis_process_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + + /* Send all class labels for multi-class support */ + if (rt->stcf->clcf && rt->stcf->clcf->class_names && + rt->stcf->clcf->class_names->len > 0) { + /* Multi-class: send array of class labels in deterministic order */ + lua_createtable(L, rt->stcf->clcf->class_names->len, 0); + for (unsigned int i = 0; i < rt->stcf->clcf->class_names->len; i++) { + const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, i); + const char *class_label = nullptr; + + /* Find the class label for this class name from any statfile with this class */ + GList *cur = rt->stcf->clcf->statfiles; + while (cur) { + auto *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) { + class_label = get_class_label(stcf); + break; + } + cur = g_list_next(cur); + } + + if (class_label) { + lua_pushstring(L, class_label); + lua_rawseti(L, -2, i + 1); /* Lua arrays are 1-indexed */ + } + } + } + else { + /* Binary classification: send both spam and ham labels for optimization */ + lua_createtable(L, 2, 0); + lua_pushstring(L, "H"); /* ham */ + lua_rawseti(L, -2, 1); + lua_pushstring(L, "S"); /* spam */ + lua_rawseti(L, -2, 2); + } + lua_new_text(L, tokens_buf, tokens_len, false); /* Store rt in random cookie */ @@ -979,13 +1332,31 @@ rspamd_redis_learned(lua_State *L) bool result = lua_toboolean(L, 2); if (result) { - /* TODO: write it */ + /* Learning successful - no complex data to process like in classification */ + msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s", + rt->stcf->symbol, get_class_label(rt->stcf)); + + /* Clear any previous error state */ + rt->err = std::nullopt; + + /* Learning operations don't return data structures to process, + * they just update Redis state. Success means the Redis script + * completed without errors. */ } else { /* Error message is on index 3 */ - const auto *err_msg = lua_tostring(L, 3); - rt->err = rspamd::util::error(err_msg, 500); - msg_err_task("cannot learn task: %s", err_msg); + const char *err_msg = nullptr; + if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) { + err_msg = lua_tostring(L, 3); + } + if (err_msg) { + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot learn task: %s", err_msg); + } + else { + rt->err = rspamd::util::error("unknown Redis script error", 500); + msg_err_task("cannot learn task: unknown Redis script error"); + } } return 0; @@ -1028,7 +1399,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, rspamd_lua_task_push(L, task); lua_pushstring(L, rt->redis_object_expanded); lua_pushinteger(L, id); - lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */ lua_pushstring(L, rt->stcf->symbol); /* Detect unlearn */ @@ -1056,6 +1427,8 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, lua_new_text(L, text_tokens_buf, text_tokens_len, false); } + msg_debug_bayes("called lua learn script for %s (cookie=%s)", rt->stcf->symbol, cookie); + if (lua_pcall(L, nargs, 0, err_idx) != 0) { msg_err_task("call to script failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c index 973dc30a7..8f29a3b4e 100644 --- a/src/libstat/backends/sqlite3_backend.c +++ b/src/libstat/backends/sqlite3_backend.c @@ -589,12 +589,7 @@ rspamd_sqlite3_process_tokens(struct rspamd_task *task, } } - if (rt->cf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } + /* No longer need to set flags - multi-class handles missing data naturally */ } |