diff options
Diffstat (limited to 'src/libstat/learn_cache/redis_cache.cxx')
-rw-r--r-- | src/libstat/learn_cache/redis_cache.cxx | 84 |
1 files changed, 69 insertions, 15 deletions
diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx index 0de5cd094..afefeadcd 100644 --- a/src/libstat/learn_cache/redis_cache.cxx +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task, return (void *) ctx; } +/* Get class ID using rspamd_cryptobox_fast_hash */ +static uint64_t +rspamd_stat_cache_get_class_id(const char *class_name) +{ + if (!class_name) { + return 0; + } + + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + return 1; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + return 0; + } + else { + /* For other classes, use rspamd_cryptobox_fast_hash */ + uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0); + + /* Ensure we don't get 0 or 1 (reserved for ham/spam) */ + if (hash == 0 || hash == 1) { + hash += 2; + } + + return hash; + } +} + static int rspamd_stat_cache_checked(lua_State *L) { @@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L) if (res) { auto val = lua_tointeger(L, 3); - if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || - (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { - /* Already learned */ - msg_info_task("<%s> has been already " - "learned as %s, ignore it", - MESSAGE_FIELD(task, message_id), - (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); - task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + /* Get the class being learned */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flags for backward compatibility */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + autolearn_class = "spam"; + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + autolearn_class = "ham"; + } } - else { - /* Unlearn flag */ - task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + + if (autolearn_class) { + uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class); + + if ((uint64_t) val == expected_id) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else { + /* Different class learned, unlearn flag */ + msg_debug_task("<%s> cached value %ld != expected %lu for class %s, will unlearn", + MESSAGE_FIELD(task, message_id), + val, expected_id, autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } } } - /* Ignore errors for now, as we can do nothing about them at the moment */ - return 0; } @@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task, lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); rspamd_lua_task_push(L, task); lua_pushstring(L, h); - lua_pushboolean(L, is_spam); - if (lua_pcall(L, 3, 0, err_idx) != 0) { + /* Get the class being learned - prefer multiclass over binary */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flag for backward compatibility */ + autolearn_class = is_spam ? "spam" : "ham"; + } + + /* Push class name and class ID */ + lua_pushstring(L, autolearn_class); + uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class); + lua_pushinteger(L, class_id); + + if (lua_pcall(L, 4, 0, err_idx) != 0) { msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); return RSPAMD_LEARN_IGNORE; |