aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/learn_cache/redis_cache.cxx
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/learn_cache/redis_cache.cxx')
-rw-r--r--src/libstat/learn_cache/redis_cache.cxx84
1 files changed, 69 insertions, 15 deletions
diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx
index 0de5cd094..afefeadcd 100644
--- a/src/libstat/learn_cache/redis_cache.cxx
+++ b/src/libstat/learn_cache/redis_cache.cxx
@@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
return (void *) ctx;
}
+/* Get class ID using rspamd_cryptobox_fast_hash */
+static uint64_t
+rspamd_stat_cache_get_class_id(const char *class_name)
+{
+ if (!class_name) {
+ return 0;
+ }
+
+ if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+ return 1;
+ }
+ else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+ return 0;
+ }
+ else {
+ /* For other classes, use rspamd_cryptobox_fast_hash */
+ uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0);
+
+ /* Ensure we don't get 0 or 1 (reserved for ham/spam) */
+ if (hash == 0 || hash == 1) {
+ hash += 2;
+ }
+
+ return hash;
+ }
+}
+
static int
rspamd_stat_cache_checked(lua_State *L)
{
@@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L)
if (res) {
auto val = lua_tointeger(L, 3);
- if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
- (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
- /* Already learned */
- msg_info_task("<%s> has been already "
- "learned as %s, ignore it",
- MESSAGE_FIELD(task, message_id),
- (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
- task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ /* Get the class being learned */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flags for backward compatibility */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ autolearn_class = "spam";
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ autolearn_class = "ham";
+ }
}
- else {
- /* Unlearn flag */
- task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+
+ if (autolearn_class) {
+ uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class);
+
+ if ((uint64_t) val == expected_id) {
+ /* Already learned */
+ msg_info_task("<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ }
+ else {
+ /* Different class learned, unlearn flag */
+ msg_debug_task("<%s> cached value %ld != expected %lu for class %s, will unlearn",
+ MESSAGE_FIELD(task, message_id),
+ val, expected_id, autolearn_class);
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ }
}
}
- /* Ignore errors for now, as we can do nothing about them at the moment */
-
return 0;
}
@@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
rspamd_lua_task_push(L, task);
lua_pushstring(L, h);
- lua_pushboolean(L, is_spam);
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ /* Get the class being learned - prefer multiclass over binary */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (!autolearn_class) {
+ /* Fallback to binary flag for backward compatibility */
+ autolearn_class = is_spam ? "spam" : "ham";
+ }
+
+ /* Push class name and class ID */
+ lua_pushstring(L, autolearn_class);
+ uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class);
+ lua_pushinteger(L, class_id);
+
+ if (lua_pcall(L, 4, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return RSPAMD_LEARN_IGNORE;