aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/backends/redis_backend.cxx
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/backends/redis_backend.cxx')
-rw-r--r--src/libstat/backends/redis_backend.cxx537
1 files changed, 455 insertions, 82 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 7137904e9..302778bcb 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -22,6 +22,7 @@
#include "contrib/fmt/include/fmt/base.h"
#include "libutil/cxx/error.hxx"
+#include <map>
#include <string>
#include <cstdint>
@@ -121,9 +122,9 @@ public:
}
static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
- bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+ const char *class_label) -> std::optional<redis_stat_runtime<T> *>
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, class_label);
auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
if (res) {
@@ -147,9 +148,15 @@ public:
rspamd_token_t *tok;
if (!results) {
+ msg_debug_bayes("process_tokens: no results available for statfile id=%d", id);
return false;
}
+ if (results->size() > 0) {
+ msg_debug_bayes("processing %uz tokens for statfile id=%d, class=%s",
+ results->size(), id, stcf->class_name ? stcf->class_name : "unknown");
+ }
+
for (auto [idx, val]: *results) {
tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
tok->values[id] = val;
@@ -158,12 +165,14 @@ public:
return true;
}
- auto save_in_mempool(bool is_spam) const
+ auto save_in_mempool(const char *class_label) const
{
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto var_name =
+ rspamd_mempool_strdup(task->task_pool,
+ fmt::format("{}_{}", redis_object_expanded, class_label).c_str());
/* We do not set destructor for the variable, as it should be already added on creation */
- rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
- msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
+ rspamd_mempool_set_variable(task->task_pool, var_name, (gpointer) this, nullptr);
+ msg_debug_bayes("saved runtime in mempool at %s", var_name);
}
};
@@ -178,6 +187,26 @@ rspamd_redis_stat_quark(void)
}
/*
+ * Get the class label for a statfile (for multi-class support)
+ */
+static const char *
+get_class_label(struct rspamd_statfile_config *stcf)
+{
+ /* Try to get the label from the classifier config first */
+ if (stcf->clcf && stcf->clcf->class_labels && stcf->class_name) {
+ const char *label = rspamd_config_get_class_label(stcf->clcf, stcf->class_name);
+ if (label) {
+ return label;
+ }
+ /* If no label mapping found, use class name directly */
+ return stcf->class_name;
+ }
+
+ /* Fallback to legacy binary classification */
+ return stcf->is_spam ? "S" : "H";
+}
+
+/*
* Non-static for lua unit testing
*/
gsize rspamd_redis_expand_object(const char *pattern,
@@ -235,6 +264,11 @@ gsize rspamd_redis_expand_object(const char *pattern,
if (rcpt) {
rspamd_mempool_set_variable(task->task_pool, "stat_user",
(gpointer) rcpt, nullptr);
+ msg_debug_bayes("redis expansion: found recipient '%s'", rcpt);
+ }
+ else {
+ msg_debug_bayes("redis expansion: no recipient found (deliver_to=%s)",
+ task->deliver_to ? task->deliver_to : "null");
}
}
@@ -448,6 +482,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
"users_enabled", nullptr);
+ msg_debug_bayes_cfg("per-user lookup: users_enabled=%p", users_enabled);
if (users_enabled != nullptr) {
if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
backend->enable_users = ucl_object_toboolean(users_enabled);
@@ -485,9 +520,16 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
/* Default non-users statistics */
if (backend->enable_users || backend->cbref_user != -1) {
backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
+ msg_debug_bayes_cfg("using per-user Redis pattern: %s (enable_users=%s, cbref_user=%d)",
+ backend->redis_object, backend->enable_users ? "true" : "false",
+ backend->cbref_user);
}
else {
backend->redis_object = REDIS_DEFAULT_OBJECT;
+ msg_debug_bayes_cfg("using default Redis pattern: %s (enable_users=%s, cbref_user=%d)",
+ backend->redis_object,
+ backend->enable_users ? "true" : "false",
+ backend->cbref_user);
}
}
else {
@@ -541,7 +583,7 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx,
ucl_object_push_lua(L, st->classifier->cfg->opts, false);
ucl_object_push_lua(L, st->stcf->opts, false);
lua_pushstring(L, backend->stcf->symbol);
- lua_pushboolean(L, backend->stcf->is_spam);
+ lua_pushstring(L, get_class_label(backend->stcf)); /* Pass class label instead of boolean */
/* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
if (ctx->event_loop) {
@@ -606,11 +648,20 @@ rspamd_redis_runtime(struct rspamd_task *task,
stcf->symbol);
return nullptr;
}
+ else {
+ msg_debug_bayes("redis object expanded: pattern='%s' -> expanded='%s' (learn=%s, symbol=%s)",
+ ctx->redis_object ? ctx->redis_object : "default",
+ object_expanded,
+ learn ? "true" : "false",
+ stcf->symbol);
+ }
+
+ const char *class_label = get_class_label(stcf);
/* Look for the cached results */
if (!learn) {
auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, stcf->is_spam);
+ object_expanded, class_label);
if (maybe_existing) {
auto *rt = maybe_existing.value();
@@ -624,24 +675,62 @@ rspamd_redis_runtime(struct rspamd_task *task,
/* No cached result (or learn), create new one */
auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- if (!learn) {
- /*
- * For check, we also need to create the opposite class runtime to avoid
- * double call for Redis scripts.
- * This runtime will be filled later.
- */
- auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded,
- !stcf->is_spam);
-
- if (!maybe_opposite_rt) {
- auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- opposite_rt->save_in_mempool(!stcf->is_spam);
- opposite_rt->need_redis_call = false;
+ /* Find the statfile ID for the main runtime */
+ int main_id = _id; /* Use the passed _id parameter */
+ rt->id = main_id;
+ rt->stcf = stcf;
+
+ /* For classification, create runtimes for all other statfiles to avoid multiple Redis calls */
+ if (!learn && stcf->clcf && stcf->clcf->statfiles) {
+ GList *cur = stcf->clcf->statfiles;
+
+ while (cur) {
+ auto *other_stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *other_label = get_class_label(other_stcf);
+
+ /* Find the statfile ID by searching through all statfiles */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ int other_id = -1;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *st = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (st->stcf == other_stcf) {
+ other_id = st->id;
+ msg_debug_bayes("found statfile mapping: %s (class=%s) → id=%d",
+ st->stcf->symbol, other_label, other_id);
+ break;
+ }
+ }
+
+ if (other_id == -1) {
+ msg_debug_bayes("statfile not found for class %s, skipping", other_label);
+ /* Skip if statfile not found */
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ if (other_stcf == stcf) {
+ /* This is the main statfile, use the main runtime */
+ rt->save_in_mempool(other_label);
+ msg_debug_bayes("main runtime: statfile %s (class=%s) → id=%d",
+ stcf->symbol, other_label, rt->id);
+ }
+ else {
+ /* Create additional runtime for other statfile */
+ auto *other_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ other_rt->id = other_id;
+ other_rt->stcf = other_stcf;
+ other_rt->need_redis_call = false;
+ other_rt->save_in_mempool(other_label);
+ msg_debug_bayes("additional runtime: statfile %s (class=%s) → id=%d",
+ other_stcf->symbol, other_label, other_id);
+ }
+
+ cur = g_list_next(cur);
}
}
-
- rt->save_in_mempool(stcf->is_spam);
+ else {
+ rt->save_in_mempool(class_label);
+ }
return rt;
}
@@ -816,77 +905,306 @@ rspamd_redis_classified(lua_State *L)
if (rt == nullptr) {
msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
return 0;
}
bool result = lua_toboolean(L, 2);
if (result) {
- /* Indexes:
- * 3 - learned_ham (int)
- * 4 - learned_spam (int)
- * 5 - ham_tokens (pair<int, int>)
- * 6 - spam_tokens (pair<int, int>)
- */
-
- /*
- * We need to fill our runtime AND the opposite runtime
- */
- auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
- rt->learned = learned;
- redis_stat_runtime<float>::result_type *res;
-
- res = new redis_stat_runtime<float>::result_type();
-
- for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
- lua_rawgeti(L, -1, 1);
- auto idx = lua_tointeger(L, -1);
- lua_pop(L, 1);
-
- lua_rawgeti(L, -1, 2);
- auto value = lua_tonumber(L, -1);
- lua_pop(L, 1);
-
- res->emplace_back(idx, value);
- }
-
- rt->set_results(res);
- };
-
- auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- !rt->stcf->is_spam);
+ /* Check we have enough arguments and the result data is a table */
+ if (lua_gettop(L) < 3 || !lua_istable(L, 3)) {
+ msg_err_task("internal error: expected table result from Redis script, got %s",
+ lua_typename(L, lua_type(L, 3)));
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
+ return 0;
+ }
- if (!opposite_rt_maybe) {
- msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+ /* Redis returns [learned_counts_array, token_results_array]
+ * Both ordered the same way as statfiles in classifier */
+ size_t result_len = rspamd_lua_table_size(L, 3);
+ msg_debug_bayes("Redis result array length: %uz", result_len);
+ if (result_len != 2) {
+ msg_err_task("internal error: expected 2-element result from Redis script, got %uz", result_len);
+ rt->err = rspamd::util::error("invalid Redis script result format", 500);
return 0;
}
- if (rt->stcf->is_spam) {
- filler_func(rt, L, lua_tointeger(L, 4), 6);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ /* Get learned_counts_array and token_results_array */
+ lua_rawgeti(L, 3, 1); /* learned_counts -> position 4 */
+ lua_rawgeti(L, 3, 2); /* token_results -> position 5 */
+
+ /* First, process learned_counts */
+ if (lua_istable(L, 4) && rt->stcf->clcf) {
+ if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: use class_names order */
+ for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+ /* Find statfile with this class name */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ const char *class_label = get_class_label(stcf);
+
+ /* Get the runtime for this statfile */
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+
+ /* Extract learned count using class index (1-based for Lua) */
+ lua_rawgeti(L, 4, class_idx + 1);
+ if (lua_isnumber(L, -1)) {
+ statfile_rt->learned = lua_tointeger(L, -1);
+ msg_debug_bayes("set learned count for class %s (label %s): %L",
+ class_name, class_label, statfile_rt->learned);
+ }
+ lua_pop(L, 1); /* Pop learned_counts[class_idx + 1] */
+ }
+ break; /* Found the statfile for this class */
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ }
+ else {
+ /* Binary classification: process statfiles in order */
+ GList *cur = rt->stcf->clcf->statfiles;
+ unsigned int statfile_idx = 0;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ /* Get the runtime for this statfile */
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(rt->task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+
+ /* Extract learned count using statfile index (1-based for Lua) */
+ lua_rawgeti(L, 4, statfile_idx + 1);
+ if (lua_isnumber(L, -1)) {
+ statfile_rt->learned = lua_tointeger(L, -1);
+ msg_debug_bayes("set learned count for statfile %s (label %s): %L",
+ stcf->symbol, class_label, statfile_rt->learned);
+ }
+ lua_pop(L, 1); /* Pop learned_counts[statfile_idx + 1] */
+ }
+ cur = g_list_next(cur);
+ statfile_idx++;
+ }
+ }
}
- else {
- filler_func(rt, L, lua_tointeger(L, 3), 5);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+
+ /* Process token results */
+ if (lua_istable(L, 5) && rt->stcf->clcf) {
+ if (rt->stcf->clcf->class_names && rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: use class_names order */
+ for (unsigned int class_idx = 0; class_idx < rt->stcf->clcf->class_names->len; class_idx++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, class_idx);
+
+ /* Find statfile with this class name */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ const char *class_label = get_class_label(stcf);
+
+ /* Find the statfile ID */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ struct rspamd_statfile *st = nullptr;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (candidate->stcf == stcf) {
+ st = candidate;
+ break;
+ }
+ }
+
+ if (!st) {
+ msg_debug_bayes("statfile not found for class %s, skipping", class_name);
+ break;
+ }
+
+ /* Get or create runtime for this statfile */
+ auto *statfile_rt = rt; /* Use current runtime if it matches */
+ if (stcf != rt->stcf) {
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ statfile_rt = maybe_rt.value();
+ }
+ else {
+ msg_debug_bayes("runtime not found for class %s, skipping", class_label);
+ break;
+ }
+ }
+
+ /* Ensure correct statfile ID assignment */
+ statfile_rt->id = st->id;
+
+ /* Process token results using class index (1-based for Lua) */
+ lua_rawgeti(L, 5, class_idx + 1); /* Get token_results[class_idx + 1] */
+ if (lua_istable(L, -1)) {
+ /* Parse token results into statfile runtime */
+ auto *res = new std::vector<std::pair<int, float>>();
+
+ lua_pushnil(L); /* First key for iteration */
+ while (lua_next(L, -2) != 0) {
+ if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+ lua_rawgeti(L, -1, 1); /* token_index */
+ lua_rawgeti(L, -2, 2); /* token_count */
+
+ if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+ int token_idx = lua_tointeger(L, -2);
+ float token_count = lua_tonumber(L, -1);
+ res->emplace_back(token_idx, token_count);
+ }
+
+ lua_pop(L, 2); /* Pop token_index and token_count */
+ }
+ lua_pop(L, 1); /* Pop value, keep key for next iteration */
+ }
+
+ statfile_rt->set_results(res);
+ }
+ lua_pop(L, 1); /* Pop token_results[class_idx + 1] */
+ break; /* Found the statfile for this class */
+ }
+ cur = g_list_next(cur);
+ }
+ }
+ }
+ else {
+ /* Binary classification: process statfiles in order */
+ GList *cur = rt->stcf->clcf->statfiles;
+ unsigned int statfile_idx = 0;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ /* Find the statfile ID */
+ struct rspamd_stat_ctx *st_ctx = rspamd_stat_get_ctx();
+ struct rspamd_statfile *st = nullptr;
+ for (unsigned int i = 0; i < st_ctx->statfiles->len; i++) {
+ struct rspamd_statfile *candidate = (struct rspamd_statfile *) g_ptr_array_index(st_ctx->statfiles, i);
+ if (candidate->stcf == stcf) {
+ st = candidate;
+ break;
+ }
+ }
+
+ if (!st) {
+ msg_debug_bayes("statfile not found for %s, skipping", stcf->symbol);
+ cur = g_list_next(cur);
+ statfile_idx++;
+ continue;
+ }
+
+ /* Get or create runtime for this statfile */
+ auto *statfile_rt = rt; /* Use current runtime if it matches */
+ if (stcf != rt->stcf) {
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ statfile_rt = maybe_rt.value();
+ }
+ else {
+ msg_debug_bayes("runtime not found for %s, skipping", class_label);
+ cur = g_list_next(cur);
+ statfile_idx++;
+ continue;
+ }
+ }
+
+ /* Ensure correct statfile ID assignment */
+ statfile_rt->id = st->id;
+
+ /* Process token results using statfile index (1-based for Lua) */
+ lua_rawgeti(L, 5, statfile_idx + 1); /* Get token_results[statfile_idx + 1] */
+ if (lua_istable(L, -1)) {
+ /* Parse token results into statfile runtime */
+ auto *res = new std::vector<std::pair<int, float>>();
+
+ lua_pushnil(L); /* First key for iteration */
+ while (lua_next(L, -2) != 0) {
+ if (lua_istable(L, -1) && lua_objlen(L, -1) == 2) {
+ lua_rawgeti(L, -1, 1); /* token_index */
+ lua_rawgeti(L, -2, 2); /* token_count */
+
+ if (lua_isnumber(L, -2) && lua_isnumber(L, -1)) {
+ int token_idx = lua_tointeger(L, -2);
+ float token_count = lua_tonumber(L, -1);
+ res->emplace_back(token_idx, token_count);
+ }
+
+ lua_pop(L, 2); /* Pop token_index and token_count */
+ }
+ lua_pop(L, 1); /* Pop value, keep key for next iteration */
+ }
+
+ statfile_rt->set_results(res);
+ msg_debug_bayes("set %uz token results for statfile %s (label %s, id=%d)",
+ res->size(), stcf->symbol, class_label, st->id);
+ }
+ lua_pop(L, 1); /* Pop token_results[statfile_idx + 1] */
+
+ cur = g_list_next(cur);
+ statfile_idx++;
+ }
+ }
}
- /* Mark task as being processed */
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ /* Clean up stack */
+ lua_pop(L, 2); /* Pop learned_counts and token_results */
- /* Process all tokens */
+ /* Process tokens for all runtimes */
g_assert(rt->tokens != nullptr);
- rt->process_tokens(rt->tokens);
- opposite_rt_maybe.value()->process_tokens(rt->tokens);
+
+ /* Process tokens for all statfiles */
+ if (rt->stcf->clcf && rt->stcf->clcf->statfiles) {
+ GList *cur = rt->stcf->clcf->statfiles;
+
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ const char *class_label = get_class_label(stcf);
+
+ auto maybe_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ class_label);
+ if (maybe_rt) {
+ auto *statfile_rt = maybe_rt.value();
+ statfile_rt->process_tokens(rt->tokens);
+ }
+
+ cur = g_list_next(cur);
+ }
+ }
+ else {
+ /* Fallback: just process the main runtime */
+ rt->process_tokens(rt->tokens);
+ }
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot classify task: %s",
- err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot classify task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot classify task: unknown Redis script error");
+ }
}
return 0;
@@ -929,7 +1247,42 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+
+ /* Send all class labels for multi-class support */
+ if (rt->stcf->clcf && rt->stcf->clcf->class_names &&
+ rt->stcf->clcf->class_names->len > 0) {
+ /* Multi-class: send array of class labels in deterministic order */
+ lua_createtable(L, rt->stcf->clcf->class_names->len, 0);
+ for (unsigned int i = 0; i < rt->stcf->clcf->class_names->len; i++) {
+ const char *class_name = (const char *) g_ptr_array_index(rt->stcf->clcf->class_names, i);
+ const char *class_label = nullptr;
+
+ /* Find the class label for this class name from any statfile with this class */
+ GList *cur = rt->stcf->clcf->statfiles;
+ while (cur) {
+ auto *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, class_name) == 0) {
+ class_label = get_class_label(stcf);
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+
+ if (class_label) {
+ lua_pushstring(L, class_label);
+ lua_rawseti(L, -2, i + 1); /* Lua arrays are 1-indexed */
+ }
+ }
+ }
+ else {
+ /* Binary classification: send both spam and ham labels for optimization */
+ lua_createtable(L, 2, 0);
+ lua_pushstring(L, "H"); /* ham */
+ lua_rawseti(L, -2, 1);
+ lua_pushstring(L, "S"); /* spam */
+ lua_rawseti(L, -2, 2);
+ }
+
lua_new_text(L, tokens_buf, tokens_len, false);
/* Store rt in random cookie */
@@ -979,13 +1332,31 @@ rspamd_redis_learned(lua_State *L)
bool result = lua_toboolean(L, 2);
if (result) {
- /* TODO: write it */
+ /* Learning successful - no complex data to process like in classification */
+ msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s",
+ rt->stcf->symbol, get_class_label(rt->stcf));
+
+ /* Clear any previous error state */
+ rt->err = std::nullopt;
+
+ /* Learning operations don't return data structures to process,
+ * they just update Redis state. Success means the Redis script
+ * completed without errors. */
}
else {
/* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot learn task: %s", err_msg);
+ const char *err_msg = nullptr;
+ if (lua_gettop(L) >= 3 && lua_isstring(L, 3)) {
+ err_msg = lua_tostring(L, 3);
+ }
+ if (err_msg) {
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot learn task: %s", err_msg);
+ }
+ else {
+ rt->err = rspamd::util::error("unknown Redis script error", 500);
+ msg_err_task("cannot learn task: unknown Redis script error");
+ }
}
return 0;
@@ -1028,7 +1399,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
rspamd_lua_task_push(L, task);
lua_pushstring(L, rt->redis_object_expanded);
lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
+ lua_pushstring(L, get_class_label(rt->stcf)); /* Pass class label instead of boolean */
lua_pushstring(L, rt->stcf->symbol);
/* Detect unlearn */
@@ -1056,6 +1427,8 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
lua_new_text(L, text_tokens_buf, text_tokens_len, false);
}
+ msg_debug_bayes("called lua learn script for %s (cookie=%s)", rt->stcf->symbol, cookie);
+
if (lua_pcall(L, nargs, 0, err_idx) != 0) {
msg_err_task("call to script failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);