diff options
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 4 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.cxx | 60 |
2 files changed, 59 insertions, 5 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 94576224d..375fa6c9b 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2024 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1032,7 +1032,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, lua_pushcclosure(L, &rspamd_redis_learned, 1); if (lua_pcall(L, nargs, 0, err_idx) != 0) { - msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + msg_err_task("call to script failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); return FALSE; } diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx index b774e626e..e53bd4df0 100644 --- a/src/libstat/learn_cache/redis_cache.cxx +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -205,14 +205,37 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task, return NULL; } - if (!learn) { + /* On check, we produce words_hash variable, on learn it is guaranteed to be set */ rspamd_stat_cache_redis_generate_id(task); } return (void *) ctx; } +static gint +rspamd_stat_cache_checked(lua_State *L) +{ + auto *task = lua_check_task(L, 1); + auto val = lua_tointeger(L, 2); + + if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || + (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else if (val != 0) { + /* Unlearn flag */ + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } + + return 0; +} + gint rspamd_stat_cache_redis_check(struct rspamd_task *task, gboolean is_spam, gpointer runtime) @@ -224,6 +247,24 @@ gint rspamd_stat_cache_redis_check(struct rspamd_task *task, return RSPAMD_LEARN_IGNORE; } + auto *L = ctx->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + + lua_pushcclosure(L, &rspamd_stat_cache_checked, 0); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } + /* We need to return OK every time */ return RSPAMD_LEARN_OK; } @@ -233,7 +274,6 @@ gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, gpointer runtime) { auto *ctx = (struct rspamd_redis_cache_ctx *) runtime; - gint flag; if (rspamd_session_blocked(task->s)) { return RSPAMD_LEARN_IGNORE; @@ -241,8 +281,22 @@ gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash"); g_assert(h != NULL); + auto *L = ctx->L; - flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1; + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + lua_pushboolean(L, is_spam); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } /* We need to return OK every time */ return RSPAMD_LEARN_OK; |