aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-01-17 21:27:24 +0600
committerGitHub <noreply@github.com>2024-01-17 21:27:24 +0600
commit01d182c86318c9776dcaa7a6c64a19c8c0783b9d (patch)
treecc0a8376c8793ff424fb134b30bf7ede36bc159c
parent2e1d00595961a2b81358164899ffee75a6cbea8a (diff)
parenta00c667b4ce0ce8d2f2454787854d191f2d5931b (diff)
downloadrspamd-01d182c86318c9776dcaa7a6c64a19c8c0783b9d.tar.gz
rspamd-01d182c86318c9776dcaa7a6c64a19c8c0783b9d.zip
Merge pull request #4774 from rspamd/vstakhov-redis-cache-rework
Rewrite redis_cache logic in statistics
-rw-r--r--lualib/lua_bayes_redis.lua98
-rw-r--r--lualib/redis_scripts/bayes_cache_check.lua20
-rw-r--r--lualib/redis_scripts/bayes_cache_learn.lua61
-rw-r--r--src/libstat/CMakeLists.txt2
-rw-r--r--src/libstat/backends/redis_backend.cxx4
-rw-r--r--src/libstat/learn_cache/redis_cache.c535
-rw-r--r--src/libstat/learn_cache/redis_cache.cxx315
7 files changed, 488 insertions, 547 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index 5ad5c3514..753399705 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua
@@ -20,6 +20,7 @@ local exports = {}
local lua_redis = require "lua_redis"
local logger = require "rspamd_logger"
local lua_util = require "lua_util"
+local ucl = require "ucl"
local N = "bayes"
@@ -54,24 +55,19 @@ local function gen_learn_functor(redis_params, learn_script_id)
if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = false, key = expanded_key },
+ { task = task, is_write = true, key = expanded_key },
learn_redis_cb,
{ expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = false, key = expanded_key },
+ { task = task, is_write = true, key = expanded_key },
learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
end
end
end
----
---- Init bayes classifier
---- @param classifier_ucl ucl of the classifier config
---- @param statfile_ucl ucl of the statfile config
---- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
+local function load_redis_params(classifier_ucl, statfile_ucl)
local redis_params
-- Try load from statfile options
@@ -108,6 +104,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
return nil
end
+ return redis_params
+end
+
+---
+--- Init bayes classifier
+--- @param classifier_ucl ucl of the classifier config
+--- @param statfile_ucl ucl of the statfile config
+--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
+
+ local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
+
+ if not redis_params then
+ return nil
+ end
+
local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params)
local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params)
local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params)
@@ -125,7 +137,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
local function stat_redis_cb(err, data)
- -- TODO: write this function
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
if err then
@@ -161,4 +172,73 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
end
+--- Build the functor that looks a cache id up in the Redis learn cache.
+-- @param redis_params redis settings (not used by this functor)
+-- @param check_script_id id of the loaded cache check script
+-- @param conf cache configuration table; serialised to msgpack once, at creation time
+-- @return function(task, cache_id, callback) where `callback` receives
+--   (task, true, number) on a hit and (task, false, reason) otherwise
+local function gen_cache_check_functor(redis_params, check_script_id, conf)
+  local conf_msgpack = ucl.to_format(conf, 'msgpack')
+
+  return function(task, cache_id, callback)
+    local function on_reply(err, data)
+      lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
+
+      if err then
+        callback(task, false, err)
+        return
+      end
+
+      -- Only a numeric reply counts as a cache hit
+      if type(data) ~= 'number' then
+        callback(task, false, 'not found')
+        return
+      end
+
+      callback(task, true, data)
+    end
+
+    lua_util.debugm(N, task, 'checking cache: %s', cache_id)
+
+    local script_opts = { task = task, is_write = false, key = cache_id }
+    local script_args = { cache_id, conf_msgpack }
+    lua_redis.exec_redis_script(check_script_id, script_opts, on_reply, script_args)
+  end
+end
+
+--- Build the functor that records a learned cache id in the Redis learn cache.
+-- @param redis_params redis settings (not used by this functor)
+-- @param learn_script_id id of the loaded cache learn script
+-- @param conf cache configuration table; serialised to msgpack once, at creation time
+-- @return function(task, cache_id, is_spam) that fires the learn script (write request)
+local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
+  local packed_conf = ucl.to_format(conf, 'msgpack')
+  return function(task, cache_id, is_spam)
+    local function learn_redis_cb(err, data)
+      if err then
+        -- A failed cache update must be visible without enabling debug logging
+        logger.errx(task, 'cannot update learn cache for %s: %s', cache_id, err)
+      end
+      lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
+    end
+
+    lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
+    lua_redis.exec_redis_script(learn_script_id,
+        { task = task, is_write = true, key = cache_id },
+        learn_redis_cb,
+        -- The class is passed to the script as the string "1" (spam) or "0" (ham)
+        { cache_id, is_spam and "1" or "0", packed_conf })
+
+  end
+end
+
+---
+--- Init the redis learn cache
+--- @param classifier_ucl ucl of the classifier config
+--- @param statfile_ucl ucl of the statfile config
+--- @return a pair of (check_functor, learn_functor) or `nil` if redis params cannot be loaded
+exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
+  local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
+  if not redis_params then
+    return nil
+  end
+
+  -- The only settings that are forwarded to the redis scripts
+  local known_settings = {
+    cache_prefix = "learned_ids",
+    cache_max_elt = 10000, -- Maximum number of elements in the cache key
+    cache_max_keys = 5, -- Maximum number of keys in the cache
+    cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
+  }
+
+  local settings = lua_util.override_defaults(known_settings, classifier_ucl)
+  -- Drop everything that is not a known cache setting
+  for name in pairs(settings) do
+    if known_settings[name] == nil then
+      settings[name] = nil
+    end
+  end
+
+  local check_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
+  local learn_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
+
+  local check_functor = gen_cache_check_functor(redis_params, check_id, settings)
+  local learn_functor = gen_cache_learn_functor(redis_params, learn_id, settings)
+
+  return check_functor, learn_functor
+end
+
return exports
diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua
new file mode 100644
index 000000000..f1ffc2b84
--- /dev/null
+++ b/lualib/redis_scripts/bayes_cache_check.lua
@@ -0,0 +1,20 @@
+-- Redis-side lookup of an id in the bayes learn cache
+-- Parameters (all passed through KEYS):
+-- key1 - cache id
+-- key2 - configuration table in message pack
+-- Returns the stored value converted to a number, or nil when no cache hash has the id
+
+local conf = cmsgpack.unpack(KEYS[2])
+-- Ids are trimmed to the configured element length before lookup
+local cache_id = string.sub(KEYS[1], 1, conf.cache_elt_len)
+
+-- Name of the i-th cache hash: the prefix followed by i "X" characters
+local function hash_name(i)
+  return conf.cache_prefix .. string.rep("X", i)
+end
+
+-- Probe every hash name from 0 to cache_max_keys (inclusive), first hit wins
+for i = 0, conf.cache_max_keys do
+  local stored = redis.call('HGET', hash_name(i), cache_id)
+
+  if stored then
+    return tonumber(stored)
+  end
+end
+
+return nil
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
new file mode 100644
index 000000000..8811f3c33
--- /dev/null
+++ b/lualib/redis_scripts/bayes_cache_learn.lua
@@ -0,0 +1,61 @@
+-- Lua script to record a learned id in the bayes learn cache
+-- This script accepts the following parameters:
+-- key1 - cache id
+-- key2 - is spam (1 or 0)
+-- key3 - configuration table in message pack
+--
+-- Returns false when the id is already cached; otherwise the id is stored
+
+local cache_id = KEYS[1]
+local is_spam = KEYS[2]
+local conf = cmsgpack.unpack(KEYS[3])
+cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
+
+-- Name of the i-th cache hash: the prefix followed by i "X" characters
+local function cache_key(i)
+  return conf.cache_prefix .. string.rep("X", i)
+end
+
+-- Try each prefix that is in Redis (as some other instance might have set it)
+for i = 0, conf.cache_max_keys do
+  if redis.call('HGET', cache_key(i), cache_id) then
+    -- Already in cache
+    return false
+  end
+end
+
+-- Store the id in the first hash that still has room
+local added = false
+for i = 0, conf.cache_max_keys do
+  local key = cache_key(i)
+
+  if redis.call('HLEN', key) < conf.cache_max_elt then
+    redis.call('HSET', key, cache_id, is_spam)
+    added = true
+    break
+  end
+end
+
+if not added then
+  -- Every hash is full. Hashes are filled starting from the shortest name, so the
+  -- shortest one is expired, the rest are shifted one name down and the id goes
+  -- into the longest name, which the shift has just freed.
+  redis.call('DEL', cache_key(0))
+
+  for i = 1, conf.cache_max_keys do
+    -- EXISTS replies with an integer and 0 is truthy in Lua, hence the explicit
+    -- comparison; RENAME raises an error when the source key is missing
+    if redis.call('EXISTS', cache_key(i)) == 1 then
+      redis.call('RENAME', cache_key(i), cache_key(i - 1))
+    end
+  end
+
+  redis.call('HSET', cache_key(conf.cache_max_keys), cache_id, is_spam)
+end
+
+return true \ No newline at end of file
diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt
index 4866d2433..64d572a57 100644
--- a/src/libstat/CMakeLists.txt
+++ b/src/libstat/CMakeLists.txt
@@ -15,7 +15,7 @@ SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c
${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx)
SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c
- ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.c)
+ ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx)
SET(RSPAMD_STAT ${LIBSTATSRC}
${TOKENIZERSSRC}
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 94576224d..375fa6c9b 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -1,5 +1,5 @@
/*
- * Copyright 2023 Vsevolod Stakhov
+ * Copyright 2024 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1032,7 +1032,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
lua_pushcclosure(L, &rspamd_redis_learned, 1);
if (lua_pcall(L, nargs, 0, err_idx) != 0) {
- msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+ msg_err_task("call to script failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return FALSE;
}
diff --git a/src/libstat/learn_cache/redis_cache.c b/src/libstat/learn_cache/redis_cache.c
deleted file mode 100644
index 0a378c8a3..000000000
--- a/src/libstat/learn_cache/redis_cache.c
+++ /dev/null
@@ -1,535 +0,0 @@
-/*
- * Copyright 2023 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "config.h"
-#include "learn_cache.h"
-#include "rspamd.h"
-#include "stat_api.h"
-#include "stat_internal.h"
-#include "cryptobox.h"
-#include "ucl.h"
-#include "hiredis.h"
-#include "adapters/libev.h"
-#include "lua/lua_common.h"
-#include "libmime/message.h"
-
-#define REDIS_DEFAULT_TIMEOUT 0.5
-#define REDIS_STAT_TIMEOUT 30
-#define REDIS_DEFAULT_PORT 6379
-#define DEFAULT_REDIS_KEY "learned_ids"
-
-static const gchar *M = "redis learn cache";
-
-struct rspamd_redis_cache_ctx {
- lua_State *L;
- struct rspamd_statfile_config *stcf;
- const gchar *username;
- const gchar *password;
- const gchar *dbname;
- const gchar *redis_object;
- gdouble timeout;
- gint conf_ref;
-};
-
-struct rspamd_redis_cache_runtime {
- struct rspamd_redis_cache_ctx *ctx;
- struct rspamd_task *task;
- struct upstream *selected;
- ev_timer timer_ev;
- redisAsyncContext *redis;
- gboolean has_event;
-};
-
-static GQuark
-rspamd_stat_cache_redis_quark(void)
-{
- return g_quark_from_static_string(M);
-}
-
-static inline struct upstream_list *
-rspamd_redis_get_servers(struct rspamd_redis_cache_ctx *ctx,
- const gchar *what)
-{
- lua_State *L = ctx->L;
- struct upstream_list *res;
-
- lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->conf_ref);
- lua_pushstring(L, what);
- lua_gettable(L, -2);
- res = *((struct upstream_list **) lua_touserdata(L, -1));
- lua_settop(L, 0);
-
- return res;
-}
-
-static void
-rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx *ctx,
- redisAsyncContext *redis)
-{
- if (ctx->username) {
- if (ctx->password) {
- redisAsyncCommand(redis, NULL, NULL, "AUTH %s %s", ctx->username, ctx->password);
- }
- else {
- msg_warn("Redis requires a password when username is supplied");
- }
- }
- else if (ctx->password) {
- redisAsyncCommand(redis, NULL, NULL, "AUTH %s", ctx->password);
- }
- if (ctx->dbname) {
- redisAsyncCommand(redis, NULL, NULL, "SELECT %s", ctx->dbname);
- }
-}
-
-/* Called on connection termination */
-static void
-rspamd_redis_cache_fin(gpointer data)
-{
- struct rspamd_redis_cache_runtime *rt = data;
- redisAsyncContext *redis;
-
- rt->has_event = FALSE;
- ev_timer_stop(rt->task->event_loop, &rt->timer_ev);
-
- if (rt->redis) {
- redis = rt->redis;
- rt->redis = NULL;
- /* This calls for all callbacks pending */
- redisAsyncFree(redis);
- }
-}
-
-static void
-rspamd_redis_cache_timeout(EV_P_ ev_timer *w, int revents)
-{
- struct rspamd_redis_cache_runtime *rt =
- (struct rspamd_redis_cache_runtime *) w->data;
- struct rspamd_task *task;
-
- task = rt->task;
-
- msg_err_task("connection to redis server %s timed out",
- rspamd_upstream_name(rt->selected));
- rspamd_upstream_fail(rt->selected, FALSE, "timeout");
-
- if (rt->has_event) {
- rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
- }
-}
-
-/* Called when we have checked the specified message id */
-static void
-rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv)
-{
- struct rspamd_redis_cache_runtime *rt = priv;
- redisReply *reply = r;
- struct rspamd_task *task;
- glong val = 0;
-
- task = rt->task;
-
- if (c->err == 0) {
- if (reply) {
- if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) {
- val = reply->integer;
- }
- else if (reply->type == REDIS_REPLY_STRING) {
- rspamd_strtol(reply->str, reply->len, &val);
- }
- else {
- if (reply->type == REDIS_REPLY_ERROR) {
- msg_err_task("cannot learn %s: redis error: \"%s\"",
- rt->ctx->stcf->symbol, reply->str);
- }
- else if (reply->type != REDIS_REPLY_NIL) {
- msg_err_task("bad learned type for %s: %d",
- rt->ctx->stcf->symbol, reply->type);
- }
-
- val = 0;
- }
- }
-
- if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
- (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
- /* Already learned */
- msg_info_task("<%s> has been already "
- "learned as %s, ignore it",
- MESSAGE_FIELD(task, message_id),
- (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
- task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
- }
- else if (val != 0) {
- /* Unlearn flag */
- task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
- }
-
- rspamd_upstream_ok(rt->selected);
- }
- else {
- rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
- }
-
- if (rt->has_event) {
- rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
- }
-}
-
-/* Called when we have learned the specified message id */
-static void
-rspamd_stat_cache_redis_set(redisAsyncContext *c, gpointer r, gpointer priv)
-{
- struct rspamd_redis_cache_runtime *rt = priv;
- struct rspamd_task *task;
-
- task = rt->task;
-
- if (c->err == 0) {
- /* XXX: we ignore results here */
- rspamd_upstream_ok(rt->selected);
- }
- else {
- rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
- }
-
- if (rt->has_event) {
- rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
- }
-}
-
-static void
-rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
-{
- rspamd_cryptobox_hash_state_t st;
- rspamd_token_t *tok;
- guint i;
- guchar out[rspamd_cryptobox_HASHBYTES];
- gchar *b32out;
- gchar *user = NULL;
-
- rspamd_cryptobox_hash_init(&st, NULL, 0);
-
- user = rspamd_mempool_get_variable(task->task_pool, "stat_user");
- /* Use dedicated hash space for per users cache */
- if (user != NULL) {
- rspamd_cryptobox_hash_update(&st, user, strlen(user));
- }
-
- for (i = 0; i < task->tokens->len; i++) {
- tok = g_ptr_array_index(task->tokens, i);
- rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data,
- sizeof(tok->data));
- }
-
- rspamd_cryptobox_hash_final(&st, out);
-
- b32out = rspamd_mempool_alloc(task->task_pool,
- sizeof(out) * 8 / 5 + 3);
- i = rspamd_encode_base32_buf(out, sizeof(out), b32out,
- sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
-
- if (i > 0) {
- /* Zero terminate */
- b32out[i] = '\0';
- }
-
- rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL);
-}
-
-gpointer
-rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
- struct rspamd_config *cfg,
- struct rspamd_statfile *st,
- const ucl_object_t *cf)
-{
- struct rspamd_redis_cache_ctx *cache_ctx;
- struct rspamd_statfile_config *stf = st->stcf;
- const ucl_object_t *obj;
- gboolean ret = FALSE;
- lua_State *L = (lua_State *) cfg->lua_state;
- gint conf_ref = -1;
-
- cache_ctx = g_malloc0(sizeof(*cache_ctx));
- cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
- cache_ctx->L = L;
-
- /* First search in backend configuration */
- obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
- if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) {
- ret = rspamd_lua_try_load_redis(L, obj, cfg, &conf_ref);
- }
-
- /* Now try statfiles config */
- if (!ret && stf->opts) {
- ret = rspamd_lua_try_load_redis(L, stf->opts, cfg, &conf_ref);
- }
-
- /* Now try classifier config */
- if (!ret && st->classifier->cfg->opts) {
- ret = rspamd_lua_try_load_redis(L, st->classifier->cfg->opts, cfg, &conf_ref);
- }
-
- /* Now try global redis settings */
- if (!ret) {
- obj = ucl_object_lookup(cfg->cfg_ucl_obj, "redis");
-
- if (obj) {
- const ucl_object_t *specific_obj;
-
- specific_obj = ucl_object_lookup(obj, "statistics");
-
- if (specific_obj) {
- ret = rspamd_lua_try_load_redis(L,
- specific_obj, cfg, &conf_ref);
- }
- else {
- ret = rspamd_lua_try_load_redis(L,
- obj, cfg, &conf_ref);
- }
- }
- }
-
- if (!ret) {
- msg_err_config("cannot init redis cache for %s", stf->symbol);
- g_free(cache_ctx);
- return NULL;
- }
-
- obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key");
-
- if (obj) {
- cache_ctx->redis_object = ucl_object_tostring(obj);
- }
- else {
- cache_ctx->redis_object = DEFAULT_REDIS_KEY;
- }
-
- cache_ctx->conf_ref = conf_ref;
-
- /* Check some common table values */
- lua_rawgeti(L, LUA_REGISTRYINDEX, conf_ref);
-
- lua_pushstring(L, "timeout");
- lua_gettable(L, -2);
- if (lua_type(L, -1) == LUA_TNUMBER) {
- cache_ctx->timeout = lua_tonumber(L, -1);
- }
- lua_pop(L, 1);
-
- lua_pushstring(L, "db");
- lua_gettable(L, -2);
- if (lua_type(L, -1) == LUA_TSTRING) {
- cache_ctx->dbname = rspamd_mempool_strdup(cfg->cfg_pool,
- lua_tostring(L, -1));
- }
- lua_pop(L, 1);
-
- lua_pushstring(L, "username");
- lua_gettable(L, -2);
- if (lua_type(L, -1) == LUA_TSTRING) {
- cache_ctx->username = rspamd_mempool_strdup(cfg->cfg_pool,
- lua_tostring(L, -1));
- }
- lua_pop(L, 1);
-
- lua_pushstring(L, "password");
- lua_gettable(L, -2);
- if (lua_type(L, -1) == LUA_TSTRING) {
- cache_ctx->password = rspamd_mempool_strdup(cfg->cfg_pool,
- lua_tostring(L, -1));
- }
- lua_pop(L, 1);
-
- lua_settop(L, 0);
-
- cache_ctx->stcf = stf;
-
- return (gpointer) cache_ctx;
-}
-
-gpointer
-rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
- gpointer c, gboolean learn)
-{
- struct rspamd_redis_cache_ctx *ctx = c;
- struct rspamd_redis_cache_runtime *rt;
- struct upstream *up;
- struct upstream_list *ups;
- rspamd_inet_addr_t *addr;
-
- g_assert(ctx != NULL);
-
- if (task->tokens == NULL || task->tokens->len == 0) {
- return NULL;
- }
-
- if (learn) {
- ups = rspamd_redis_get_servers(ctx, "write_servers");
-
- if (!ups) {
- msg_err_task("no write servers defined for %s, cannot learn",
- ctx->stcf->symbol);
- return NULL;
- }
-
- up = rspamd_upstream_get(ups,
- RSPAMD_UPSTREAM_MASTER_SLAVE,
- NULL,
- 0);
- }
- else {
- ups = rspamd_redis_get_servers(ctx, "read_servers");
-
- if (!ups) {
- msg_err_task("no read servers defined for %s, cannot check",
- ctx->stcf->symbol);
- return NULL;
- }
-
- up = rspamd_upstream_get(ups,
- RSPAMD_UPSTREAM_ROUND_ROBIN,
- NULL,
- 0);
- }
-
- if (up == NULL) {
- msg_err_task("no upstreams reachable");
- return NULL;
- }
-
- rt = rspamd_mempool_alloc0(task->task_pool, sizeof(*rt));
- rt->selected = up;
- rt->task = task;
- rt->ctx = ctx;
-
- addr = rspamd_upstream_addr_next(up);
- g_assert(addr != NULL);
-
- if (rspamd_inet_address_get_af(addr) == AF_UNIX) {
- rt->redis = redisAsyncConnectUnix(rspamd_inet_address_to_string(addr));
- }
- else {
- rt->redis = redisAsyncConnect(rspamd_inet_address_to_string(addr),
- rspamd_inet_address_get_port(addr));
- }
-
- if (rt->redis == NULL) {
- msg_warn_task("cannot connect to redis server %s: %s",
- rspamd_inet_address_to_string_pretty(addr),
- strerror(errno));
-
- return NULL;
- }
- else if (rt->redis->err != REDIS_OK) {
- msg_warn_task("cannot connect to redis server %s: %s",
- rspamd_inet_address_to_string_pretty(addr),
- rt->redis->errstr);
- redisAsyncFree(rt->redis);
- rt->redis = NULL;
-
- return NULL;
- }
-
- redisLibevAttach(task->event_loop, rt->redis);
-
- /* Now check stats */
- rt->timer_ev.data = rt;
- ev_timer_init(&rt->timer_ev, rspamd_redis_cache_timeout,
- rt->ctx->timeout, 0.0);
- rspamd_redis_cache_maybe_auth(ctx, rt->redis);
-
- if (!learn) {
- rspamd_stat_cache_redis_generate_id(task);
- }
-
- return rt;
-}
-
-gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
- gboolean is_spam,
- gpointer runtime)
-{
- struct rspamd_redis_cache_runtime *rt = runtime;
- gchar *h;
-
- if (rspamd_session_blocked(task->s)) {
- return RSPAMD_LEARN_IGNORE;
- }
-
- h = rspamd_mempool_get_variable(task->task_pool, "words_hash");
-
- if (h == NULL) {
- return RSPAMD_LEARN_IGNORE;
- }
-
- if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_get, rt,
- "HGET %s %s",
- rt->ctx->redis_object, h) == REDIS_OK) {
- rspamd_session_add_event(task->s,
- rspamd_redis_cache_fin,
- rt,
- M);
- ev_timer_start(rt->task->event_loop, &rt->timer_ev);
- rt->has_event = TRUE;
- }
-
- /* We need to return OK every time */
- return RSPAMD_LEARN_OK;
-}
-
-gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
- gboolean is_spam,
- gpointer runtime)
-{
- struct rspamd_redis_cache_runtime *rt = runtime;
- gchar *h;
- gint flag;
-
- if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked(task->s)) {
- return RSPAMD_LEARN_IGNORE;
- }
-
- h = rspamd_mempool_get_variable(task->task_pool, "words_hash");
- g_assert(h != NULL);
-
- flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1;
-
- if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_set, rt,
- "HSET %s %s %d",
- rt->ctx->redis_object, h, flag) == REDIS_OK) {
- rspamd_session_add_event(task->s,
- rspamd_redis_cache_fin, rt, M);
- ev_timer_start(rt->task->event_loop, &rt->timer_ev);
- rt->has_event = TRUE;
- }
-
- /* We need to return OK every time */
- return RSPAMD_LEARN_OK;
-}
-
-void rspamd_stat_cache_redis_close(gpointer c)
-{
- struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *) c;
- lua_State *L;
-
- L = ctx->L;
-
- if (ctx->conf_ref) {
- luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref);
- }
-
- g_free(ctx);
-}
diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx
new file mode 100644
index 000000000..b67df70a2
--- /dev/null
+++ b/src/libstat/learn_cache/redis_cache.cxx
@@ -0,0 +1,315 @@
+/*
+ * Copyright 2024 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+// Include early to avoid `extern "C"` issues
+#include "lua/lua_common.h"
+#include "learn_cache.h"
+#include "rspamd.h"
+#include "stat_api.h"
+#include "stat_internal.h"
+#include "cryptobox.h"
+#include "ucl.h"
+#include "libmime/message.h"
+
+#define DEFAULT_REDIS_KEY "learned_ids"
+
+/*
+ * Learn cache context: the Lua state plus registry references to the two Lua
+ * functors (check and learn). The references are released on destruction.
+ */
+struct rspamd_redis_cache_ctx {
+    lua_State *L;
+    /* Initialised so that an unset statfile config is a null pointer, not garbage */
+    struct rspamd_statfile_config *stcf = nullptr;
+    std::string redis_object = DEFAULT_REDIS_KEY;
+    int check_ref = -1;
+    int learn_ref = -1;
+
+    rspamd_redis_cache_ctx() = delete;
+    explicit rspamd_redis_cache_ctx(lua_State *L)
+        : L(L)
+    {
+    }
+
+    /* The object owns registry references: a copy would unref them twice */
+    rspamd_redis_cache_ctx(const rspamd_redis_cache_ctx &) = delete;
+    rspamd_redis_cache_ctx &operator=(const rspamd_redis_cache_ctx &) = delete;
+
+    ~rspamd_redis_cache_ctx()
+    {
+        if (check_ref != -1) {
+            luaL_unref(L, LUA_REGISTRYINDEX, check_ref);
+        }
+
+        if (learn_ref != -1) {
+            luaL_unref(L, LUA_REGISTRYINDEX, learn_ref);
+        }
+    }
+};
+
+/*
+ * NOTE(review): everything up to the matching #endif is compiled out by `#if 0`.
+ * The handler refers to `struct rspamd_redis_cache_runtime`, `rspamd_redis_cache_fin`
+ * and hiredis reply types, none of which are declared or included in this file, so it
+ * presumably survives only as a reference for the reply handling — confirm before
+ * re-enabling or removing it.
+ */
+#if 0
+/* Called when we have checked the specified message id */
+static void
+rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv)
+{
+ struct rspamd_redis_cache_runtime *rt = priv;
+ redisReply *reply = r;
+ struct rspamd_task *task;
+ glong val = 0;
+
+ task = rt->task;
+
+ if (c->err == 0) {
+ if (reply) {
+ if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) {
+ val = reply->integer;
+ }
+ else if (reply->type == REDIS_REPLY_STRING) {
+ rspamd_strtol(reply->str, reply->len, &val);
+ }
+ else {
+ if (reply->type == REDIS_REPLY_ERROR) {
+ msg_err_task("cannot learn %s: redis error: \"%s\"",
+ rt->ctx->stcf->symbol, reply->str);
+ }
+ else if (reply->type != REDIS_REPLY_NIL) {
+ msg_err_task("bad learned type for %s: %d",
+ rt->ctx->stcf->symbol, reply->type);
+ }
+
+ val = 0;
+ }
+ }
+
+ if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
+ (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
+ /* Already learned */
+ msg_info_task("<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ }
+ else if (val != 0) {
+ /* Unlearn flag */
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ }
+
+ rspamd_upstream_ok(rt->selected);
+ }
+ else {
+ rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
+ }
+
+ if (rt->has_event) {
+ rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
+ }
+}
+#endif
+
+/*
+ * Hash the optional "stat_user" pool variable together with the data of every
+ * task token, base32-encode the digest and publish it as the "words_hash" pool
+ * variable. Nothing is published if the encoder reports no output.
+ */
+static void
+rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
+{
+    rspamd_cryptobox_hash_state_t st;
+    rspamd_cryptobox_hash_init(&st, nullptr, 0);
+
+    const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user");
+    /* Use dedicated hash space for per users cache */
+    if (user != nullptr) {
+        rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user));
+    }
+
+    /* Unsigned index: `len` is unsigned, a plain `auto i = 0` would be a signed int */
+    for (guint i = 0; i < task->tokens->len; i++) {
+        const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i);
+        rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data,
+                                     sizeof(tok->data));
+    }
+
+    guchar out[rspamd_cryptobox_HASHBYTES];
+    rspamd_cryptobox_hash_final(&st, out);
+
+    /* One byte more than the encoder is allowed to write, reserved for the terminator */
+    auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool,
+                                                   sizeof(out) * 8 / 5 + 3, char);
+    auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out,
+                                           sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
+
+    if (out_sz > 0) {
+        /* Zero terminate */
+        b32out[out_sz] = '\0';
+        rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, nullptr);
+    }
+}
+
+/*
+ * Create the learn cache context: call lua_bayes_redis.lua_bayes_init_cache with the
+ * classifier and statfile options and keep registry references to the two functors
+ * it returns (check, learn).
+ *
+ * Returns the context as an opaque pointer, or nullptr if the Lua function cannot be
+ * required, raises an error, or does not return two functions.
+ */
+gpointer
+rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
+                             struct rspamd_config *cfg,
+                             struct rspamd_statfile *st,
+                             const ucl_object_t *cf)
+{
+    auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+    auto cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(L);
+    cache_ctx->stcf = st->stcf;
+
+    lua_settop(L, 0);
+
+    lua_pushcfunction(L, &rspamd_lua_traceback);
+    auto err_idx = lua_gettop(L);
+
+    /* Obtain function */
+    if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
+        msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
+        lua_settop(L, err_idx - 1);
+
+        return nullptr;
+    }
+
+    /* Push arguments */
+    ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+    ucl_object_push_lua(L, st->stcf->opts, false);
+
+    if (lua_pcall(L, 2, 2, err_idx) != 0) {
+        msg_err("call to lua_bayes_init_cache "
+                "script failed: %s",
+                lua_tostring(L, -1));
+        lua_settop(L, err_idx - 1);
+
+        return nullptr;
+    }
+
+    /*
+     * Results are in the stack:
+     * top - 1 - check function (idx = -2)
+     * top - learn function (idx = -1)
+     *
+     * The Lua side returns nil when redis parameters cannot be loaded; referencing
+     * nil here would make every later check/learn call fail, so refuse to init.
+     */
+    if (!lua_isfunction(L, -2) || !lua_isfunction(L, -1)) {
+        msg_err_config("cannot init redis cache for %s", st->stcf->symbol);
+        lua_settop(L, err_idx - 1);
+
+        return nullptr;
+    }
+
+    lua_pushvalue(L, -2);
+    cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+    lua_pushvalue(L, -1);
+    cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+    lua_settop(L, err_idx - 1);
+
+    return (gpointer) cache_ctx.release();
+}
+
+/*
+ * Per-task runtime: the cache context itself is handed back as the runtime, or
+ * NULL when the task has no tokens. For a check (learn is false) the "words_hash"
+ * id is generated here.
+ */
+gpointer
+rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
+                                gpointer c, gboolean learn)
+{
+    const bool has_tokens = task->tokens != NULL && task->tokens->len > 0;
+
+    if (!has_tokens) {
+        return NULL;
+    }
+
+    /* On check, we produce words_hash variable, on learn it is guaranteed to be set */
+    if (!learn) {
+        rspamd_stat_cache_redis_generate_id(task);
+    }
+
+    return c;
+}
+
+/*
+ * Lua C function passed as the callback of the cache check functor.
+ * Arguments: (task, found, value). When `found` is true, `value` is the cached
+ * class: positive for spam, zero or negative for ham.
+ * Sets ALREADY_LEARNED when the cached class matches the requested one, and
+ * UNLEARN when the message was previously learned as the other class.
+ */
+static gint
+rspamd_stat_cache_checked(lua_State *L)
+{
+    auto *task = lua_check_task(L, 1);
+    auto res = lua_toboolean(L, 2);
+
+    if (res) {
+        auto val = lua_tointeger(L, 3);
+
+        if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
+            (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
+            /* Already learned */
+            msg_info_task("<%s> has been already "
+                          "learned as %s, ignore it",
+                          MESSAGE_FIELD(task, message_id),
+                          (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
+            task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+        }
+        else if (val != 0 || (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) {
+            /*
+             * Unlearn flag. Ham is cached as 0, so a cached ham that is now being
+             * learned as spam must be covered explicitly: `val != 0` alone misses it.
+             */
+            task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+        }
+    }
+
+    /* Ignore errors for now, as we can do nothing about them at the moment */
+
+    return 0;
+}
+
+/*
+ * Start a cache lookup for the task's "words_hash" id by calling the Lua check
+ * functor with (task, id, rspamd_stat_cache_checked).
+ *
+ * Returns RSPAMD_LEARN_IGNORE when there is no runtime, no id, or the Lua call
+ * fails, and RSPAMD_LEARN_OK otherwise.
+ */
+gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
+                                   gboolean is_spam,
+                                   gpointer runtime)
+{
+    auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
+    auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
+
+    /* The runtime is NULL for tasks without tokens */
+    if (ctx == nullptr || h == nullptr) {
+        return RSPAMD_LEARN_IGNORE;
+    }
+
+    auto *L = ctx->L;
+
+    lua_pushcfunction(L, &rspamd_lua_traceback);
+    gint err_idx = lua_gettop(L);
+
+    /* Function arguments */
+    lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
+    rspamd_lua_task_push(L, task);
+    lua_pushstring(L, h);
+
+    lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);
+
+    if (lua_pcall(L, 3, 0, err_idx) != 0) {
+        msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+        lua_settop(L, err_idx - 1);
+        return RSPAMD_LEARN_IGNORE;
+    }
+
+    /* Drop the traceback handler, otherwise every call leaks one stack slot */
+    lua_settop(L, err_idx - 1);
+
+    /* We need to return OK every time */
+    return RSPAMD_LEARN_OK;
+}
+
+/*
+ * Record the task's "words_hash" id in the cache by calling the Lua learn functor
+ * with (task, id, is_spam).
+ *
+ * Returns RSPAMD_LEARN_IGNORE when there is no runtime, the session is blocked, or
+ * the Lua call fails, and RSPAMD_LEARN_OK otherwise.
+ */
+gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
+                                   gboolean is_spam,
+                                   gpointer runtime)
+{
+    auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
+
+    /* The runtime is NULL for tasks without tokens */
+    if (ctx == nullptr || rspamd_session_blocked(task->s)) {
+        return RSPAMD_LEARN_IGNORE;
+    }
+
+    auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
+    g_assert(h != NULL);
+    auto *L = ctx->L;
+
+    lua_pushcfunction(L, &rspamd_lua_traceback);
+    gint err_idx = lua_gettop(L);
+
+    /* Function arguments */
+    lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
+    rspamd_lua_task_push(L, task);
+    lua_pushstring(L, h);
+    lua_pushboolean(L, is_spam);
+
+    if (lua_pcall(L, 3, 0, err_idx) != 0) {
+        msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+        lua_settop(L, err_idx - 1);
+        return RSPAMD_LEARN_IGNORE;
+    }
+
+    /* Drop the traceback handler, otherwise every call leaks one stack slot */
+    lua_settop(L, err_idx - 1);
+
+    /* We need to return OK every time */
+    return RSPAMD_LEARN_OK;
+}
+
+/* Destroy the cache context; its destructor releases the Lua registry references */
+void rspamd_stat_cache_redis_close(gpointer c)
+{
+    delete static_cast<struct rspamd_redis_cache_ctx *>(c);
+}