diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-01-17 21:27:24 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-17 21:27:24 +0600 |
commit | 01d182c86318c9776dcaa7a6c64a19c8c0783b9d (patch) | |
tree | cc0a8376c8793ff424fb134b30bf7ede36bc159c | |
parent | 2e1d00595961a2b81358164899ffee75a6cbea8a (diff) | |
parent | a00c667b4ce0ce8d2f2454787854d191f2d5931b (diff) | |
download | rspamd-01d182c86318c9776dcaa7a6c64a19c8c0783b9d.tar.gz rspamd-01d182c86318c9776dcaa7a6c64a19c8c0783b9d.zip |
Merge pull request #4774 from rspamd/vstakhov-redis-cache-rework
Rewrite redis_cache logic in statistics
-rw-r--r-- | lualib/lua_bayes_redis.lua | 98 | ||||
-rw-r--r-- | lualib/redis_scripts/bayes_cache_check.lua | 20 | ||||
-rw-r--r-- | lualib/redis_scripts/bayes_cache_learn.lua | 61 | ||||
-rw-r--r-- | src/libstat/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 4 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.c | 535 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.cxx | 315 |
7 files changed, 488 insertions, 547 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 5ad5c3514..753399705 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -20,6 +20,7 @@ local exports = {} local lua_redis = require "lua_redis" local logger = require "rspamd_logger" local lua_util = require "lua_util" +local ucl = require "ucl" local N = "bayes" @@ -54,24 +55,19 @@ local function gen_learn_functor(redis_params, learn_script_id) if maybe_text_tokens then lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = false, key = expanded_key }, + { task = task, is_write = true, key = expanded_key }, learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) else lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = false, key = expanded_key }, + { task = task, is_write = true, key = expanded_key }, learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) end end end ---- ---- Init bayes classifier ---- @param classifier_ucl ucl of the classifier config ---- @param statfile_ucl ucl of the statfile config ---- @return a pair of (classify_functor, learn_functor) or `nil` in case of error -exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) +local function load_redis_params(classifier_ucl, statfile_ucl) local redis_params -- Try load from statfile options @@ -108,6 +104,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, return nil end + return redis_params +end + +--- +--- Init bayes classifier +--- @param classifier_ucl ucl of the classifier config +--- @param statfile_ucl ucl of the statfile config +--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error +exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) + + local redis_params = load_redis_params(classifier_ucl, statfile_ucl) + + if not redis_params then + return nil + end + local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params) local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params) local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params) @@ -125,7 +137,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) local function stat_redis_cb(err, data) - -- TODO: write this function lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) if err then @@ -161,4 +172,73 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id) end +local function gen_cache_check_functor(redis_params, check_script_id, conf) + local packed_conf = ucl.to_format(conf, 'msgpack') + return function(task, cache_id, callback) + + local function classify_redis_cb(err, data) + lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data)) + if err then + callback(task, false, err) + else + if type(data) == 'number' then + callback(task, true, data) + else + callback(task, false, 'not found') + end + end + end + + lua_util.debugm(N, task, 'checking cache: %s', cache_id) + lua_redis.exec_redis_script(check_script_id, + { task = task, is_write = false, key = cache_id }, + classify_redis_cb, { cache_id, packed_conf }) + end +end + +local function gen_cache_learn_functor(redis_params, learn_script_id, conf) + local packed_conf = ucl.to_format(conf, 'msgpack') + return function(task, cache_id, is_spam) + local function learn_redis_cb(err, data) + lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) + end + + lua_util.debugm(N, task, 'try to learn cache: %s', cache_id) + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = cache_id }, + learn_redis_cb, + { cache_id, is_spam and "1" or "0", packed_conf }) + + end +end + +exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl) + local redis_params = load_redis_params(classifier_ucl, statfile_ucl) + + if not redis_params then + return nil + end + + local default_conf = { + cache_prefix = "learned_ids", + cache_max_elt = 10000, -- Maximum number of elements in the cache key + cache_max_keys = 5, -- Maximum number of keys in the cache + cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value) + } + + local conf = lua_util.override_defaults(default_conf, classifier_ucl) + -- Clean all not known configurations + for k, _ in pairs(conf) do + if default_conf[k] == nil then + conf[k] = nil + end + end + + local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params) + local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params) + + return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params, + learn_script_id, conf) +end + return exports diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua new file mode 100644 index 000000000..f1ffc2b84 --- /dev/null +++ b/lualib/redis_scripts/bayes_cache_check.lua @@ -0,0 +1,20 @@ +-- Lua script to perform cache checking for bayes classification +-- This script accepts the following parameters: +-- key1 - cache id +-- key2 - configuration table in message pack + +local cache_id = KEYS[1] +local conf = cmsgpack.unpack(KEYS[2]) +cache_id = string.sub(cache_id, 1, conf.cache_elt_len) + +-- Try each prefix that is in Redis +for i = 0, conf.cache_max_keys do + local prefix = conf.cache_prefix .. string.rep("X", i) + local have = redis.call('HGET', prefix, cache_id) + + if have then + return tonumber(have) + end +end + +return nil diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua new file mode 100644 index 000000000..8811f3c33 --- /dev/null +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -0,0 +1,61 @@ +-- Lua script to perform cache checking for bayes classification +-- This script accepts the following parameters: +-- key1 - cache id +-- key3 - is spam (1 or 0) +-- key3 - configuration table in message pack + +local cache_id = KEYS[1] +local is_spam = KEYS[2] +local conf = cmsgpack.unpack(KEYS[3]) +cache_id = string.sub(cache_id, 1, conf.cache_elt_len) + +-- Try each prefix that is in Redis (as some other instance might have set it) +for i = 0, conf.cache_max_keys do + local prefix = conf.cache_prefix .. string.rep("X", i) + local have = redis.call('HGET', prefix, cache_id) + + if have then + -- Already in cache + return false + end +end + +local added = false +local lim = conf.cache_max_elt +for i = 0, conf.cache_max_keys do + if not added then + local prefix = conf.cache_prefix .. string.rep("X", i) + local count = redis.call('HLEN', prefix) + + if count < lim then + -- We can add it to this prefix + redis.call('HSET', prefix, cache_id, is_spam) + added = true + end + end +end + +if not added then + -- Need to expire some keys + local expired = false + for i = 0, conf.cache_max_keys do + local prefix = conf.cache_prefix .. string.rep("X", i) + local exists = redis.call('EXISTS', prefix) + + if exists then + if expired then + redis.call('DEL', prefix) + redis.call('HSET', prefix, cache_id, is_spam) + + -- Do not expire anything else + expired = true + elseif i > 0 then + -- Move key to a shorter prefix, so we will rotate them eventually from lower to upper + local new_prefix = conf.cache_prefix .. string.rep("X", i - 1) + redis.call('RENAME', prefix, new_prefix) + end + end + end +end + +return true
\ No newline at end of file diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt index 4866d2433..64d572a57 100644 --- a/src/libstat/CMakeLists.txt +++ b/src/libstat/CMakeLists.txt @@ -15,7 +15,7 @@ SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c - ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.c) + ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) SET(RSPAMD_STAT ${LIBSTATSRC} ${TOKENIZERSSRC} diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 94576224d..375fa6c9b 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2024 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1032,7 +1032,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, lua_pushcclosure(L, &rspamd_redis_learned, 1); if (lua_pcall(L, nargs, 0, err_idx) != 0) { - msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + msg_err_task("call to script failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); return FALSE; } diff --git a/src/libstat/learn_cache/redis_cache.c b/src/libstat/learn_cache/redis_cache.c deleted file mode 100644 index 0a378c8a3..000000000 --- a/src/libstat/learn_cache/redis_cache.c +++ /dev/null @@ -1,535 +0,0 @@ -/* - * Copyright 2023 Vsevolod Stakhov - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "config.h" -#include "learn_cache.h" -#include "rspamd.h" -#include "stat_api.h" -#include "stat_internal.h" -#include "cryptobox.h" -#include "ucl.h" -#include "hiredis.h" -#include "adapters/libev.h" -#include "lua/lua_common.h" -#include "libmime/message.h" - -#define REDIS_DEFAULT_TIMEOUT 0.5 -#define REDIS_STAT_TIMEOUT 30 -#define REDIS_DEFAULT_PORT 6379 -#define DEFAULT_REDIS_KEY "learned_ids" - -static const gchar *M = "redis learn cache"; - -struct rspamd_redis_cache_ctx { - lua_State *L; - struct rspamd_statfile_config *stcf; - const gchar *username; - const gchar *password; - const gchar *dbname; - const gchar *redis_object; - gdouble timeout; - gint conf_ref; -}; - -struct rspamd_redis_cache_runtime { - struct rspamd_redis_cache_ctx *ctx; - struct rspamd_task *task; - struct upstream *selected; - ev_timer timer_ev; - redisAsyncContext *redis; - gboolean has_event; -}; - -static GQuark -rspamd_stat_cache_redis_quark(void) -{ - return g_quark_from_static_string(M); -} - -static inline struct upstream_list * -rspamd_redis_get_servers(struct rspamd_redis_cache_ctx *ctx, - const gchar *what) -{ - lua_State *L = ctx->L; - struct upstream_list *res; - - lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->conf_ref); - lua_pushstring(L, what); - lua_gettable(L, -2); - res = *((struct upstream_list **) lua_touserdata(L, -1)); - lua_settop(L, 0); - - return res; -} - -static void -rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx *ctx, - redisAsyncContext *redis) -{ - if (ctx->username) { - if (ctx->password) { - redisAsyncCommand(redis, NULL, NULL, "AUTH %s %s", ctx->username, ctx->password); - } - else { - msg_warn("Redis requires a password when username is supplied"); - } - } - else if (ctx->password) { - redisAsyncCommand(redis, NULL, NULL, "AUTH %s", ctx->password); - } - if (ctx->dbname) { - redisAsyncCommand(redis, NULL, NULL, "SELECT %s", ctx->dbname); - } -} - -/* Called on connection termination */ -static void -rspamd_redis_cache_fin(gpointer data) -{ - struct rspamd_redis_cache_runtime *rt = data; - redisAsyncContext *redis; - - rt->has_event = FALSE; - ev_timer_stop(rt->task->event_loop, &rt->timer_ev); - - if (rt->redis) { - redis = rt->redis; - rt->redis = NULL; - /* This calls for all callbacks pending */ - redisAsyncFree(redis); - } -} - -static void -rspamd_redis_cache_timeout(EV_P_ ev_timer *w, int revents) -{ - struct rspamd_redis_cache_runtime *rt = - (struct rspamd_redis_cache_runtime *) w->data; - struct rspamd_task *task; - - task = rt->task; - - msg_err_task("connection to redis server %s timed out", - rspamd_upstream_name(rt->selected)); - rspamd_upstream_fail(rt->selected, FALSE, "timeout"); - - if (rt->has_event) { - rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); - } -} - -/* Called when we have checked the specified message id */ -static void -rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv) -{ - struct rspamd_redis_cache_runtime *rt = priv; - redisReply *reply = r; - struct rspamd_task *task; - glong val = 0; - - task = rt->task; - - if (c->err == 0) { - if (reply) { - if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) { - val = reply->integer; - } - else if (reply->type == REDIS_REPLY_STRING) { - rspamd_strtol(reply->str, reply->len, &val); - } - else { - if (reply->type == REDIS_REPLY_ERROR) { - msg_err_task("cannot learn %s: redis error: \"%s\"", - rt->ctx->stcf->symbol, reply->str); - } - else if (reply->type != REDIS_REPLY_NIL) { - msg_err_task("bad learned type for %s: %d", - rt->ctx->stcf->symbol, reply->type); - } - - val = 0; - } - } - - if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || - (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { - /* Already learned */ - msg_info_task("<%s> has been already " - "learned as %s, ignore it", - MESSAGE_FIELD(task, message_id), - (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); - task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; - } - else if (val != 0) { - /* Unlearn flag */ - task->flags |= RSPAMD_TASK_FLAG_UNLEARN; - } - - rspamd_upstream_ok(rt->selected); - } - else { - rspamd_upstream_fail(rt->selected, FALSE, c->errstr); - } - - if (rt->has_event) { - rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); - } -} - -/* Called when we have learned the specified message id */ -static void -rspamd_stat_cache_redis_set(redisAsyncContext *c, gpointer r, gpointer priv) -{ - struct rspamd_redis_cache_runtime *rt = priv; - struct rspamd_task *task; - - task = rt->task; - - if (c->err == 0) { - /* XXX: we ignore results here */ - rspamd_upstream_ok(rt->selected); - } - else { - rspamd_upstream_fail(rt->selected, FALSE, c->errstr); - } - - if (rt->has_event) { - rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); - } -} - -static void -rspamd_stat_cache_redis_generate_id(struct rspamd_task *task) -{ - rspamd_cryptobox_hash_state_t st; - rspamd_token_t *tok; - guint i; - guchar out[rspamd_cryptobox_HASHBYTES]; - gchar *b32out; - gchar *user = NULL; - - rspamd_cryptobox_hash_init(&st, NULL, 0); - - user = rspamd_mempool_get_variable(task->task_pool, "stat_user"); - /* Use dedicated hash space for per users cache */ - if (user != NULL) { - rspamd_cryptobox_hash_update(&st, user, strlen(user)); - } - - for (i = 0; i < task->tokens->len; i++) { - tok = g_ptr_array_index(task->tokens, i); - rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data, - sizeof(tok->data)); - } - - rspamd_cryptobox_hash_final(&st, out); - - b32out = rspamd_mempool_alloc(task->task_pool, - sizeof(out) * 8 / 5 + 3); - i = rspamd_encode_base32_buf(out, sizeof(out), b32out, - sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT); - - if (i > 0) { - /* Zero terminate */ - b32out[i] = '\0'; - } - - rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL); -} - -gpointer -rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx, - struct rspamd_config *cfg, - struct rspamd_statfile *st, - const ucl_object_t *cf) -{ - struct rspamd_redis_cache_ctx *cache_ctx; - struct rspamd_statfile_config *stf = st->stcf; - const ucl_object_t *obj; - gboolean ret = FALSE; - lua_State *L = (lua_State *) cfg->lua_state; - gint conf_ref = -1; - - cache_ctx = g_malloc0(sizeof(*cache_ctx)); - cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT; - cache_ctx->L = L; - - /* First search in backend configuration */ - obj = ucl_object_lookup(st->classifier->cfg->opts, "backend"); - if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) { - ret = rspamd_lua_try_load_redis(L, obj, cfg, &conf_ref); - } - - /* Now try statfiles config */ - if (!ret && stf->opts) { - ret = rspamd_lua_try_load_redis(L, stf->opts, cfg, &conf_ref); - } - - /* Now try classifier config */ - if (!ret && st->classifier->cfg->opts) { - ret = rspamd_lua_try_load_redis(L, st->classifier->cfg->opts, cfg, &conf_ref); - } - - /* Now try global redis settings */ - if (!ret) { - obj = ucl_object_lookup(cfg->cfg_ucl_obj, "redis"); - - if (obj) { - const ucl_object_t *specific_obj; - - specific_obj = ucl_object_lookup(obj, "statistics"); - - if (specific_obj) { - ret = rspamd_lua_try_load_redis(L, - specific_obj, cfg, &conf_ref); - } - else { - ret = rspamd_lua_try_load_redis(L, - obj, cfg, &conf_ref); - } - } - } - - if (!ret) { - msg_err_config("cannot init redis cache for %s", stf->symbol); - g_free(cache_ctx); - return NULL; - } - - obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key"); - - if (obj) { - cache_ctx->redis_object = ucl_object_tostring(obj); - } - else { - cache_ctx->redis_object = DEFAULT_REDIS_KEY; - } - - cache_ctx->conf_ref = conf_ref; - - /* Check some common table values */ - lua_rawgeti(L, LUA_REGISTRYINDEX, conf_ref); - - lua_pushstring(L, "timeout"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TNUMBER) { - cache_ctx->timeout = lua_tonumber(L, -1); - } - lua_pop(L, 1); - - lua_pushstring(L, "db"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - cache_ctx->dbname = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_pushstring(L, "username"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - cache_ctx->username = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_pushstring(L, "password"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - cache_ctx->password = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_settop(L, 0); - - cache_ctx->stcf = stf; - - return (gpointer) cache_ctx; -} - -gpointer -rspamd_stat_cache_redis_runtime(struct rspamd_task *task, - gpointer c, gboolean learn) -{ - struct rspamd_redis_cache_ctx *ctx = c; - struct rspamd_redis_cache_runtime *rt; - struct upstream *up; - struct upstream_list *ups; - rspamd_inet_addr_t *addr; - - g_assert(ctx != NULL); - - if (task->tokens == NULL || task->tokens->len == 0) { - return NULL; - } - - if (learn) { - ups = rspamd_redis_get_servers(ctx, "write_servers"); - - if (!ups) { - msg_err_task("no write servers defined for %s, cannot learn", - ctx->stcf->symbol); - return NULL; - } - - up = rspamd_upstream_get(ups, - RSPAMD_UPSTREAM_MASTER_SLAVE, - NULL, - 0); - } - else { - ups = rspamd_redis_get_servers(ctx, "read_servers"); - - if (!ups) { - msg_err_task("no read servers defined for %s, cannot check", - ctx->stcf->symbol); - return NULL; - } - - up = rspamd_upstream_get(ups, - RSPAMD_UPSTREAM_ROUND_ROBIN, - NULL, - 0); - } - - if (up == NULL) { - msg_err_task("no upstreams reachable"); - return NULL; - } - - rt = rspamd_mempool_alloc0(task->task_pool, sizeof(*rt)); - rt->selected = up; - rt->task = task; - rt->ctx = ctx; - - addr = rspamd_upstream_addr_next(up); - g_assert(addr != NULL); - - if (rspamd_inet_address_get_af(addr) == AF_UNIX) { - rt->redis = redisAsyncConnectUnix(rspamd_inet_address_to_string(addr)); - } - else { - rt->redis = redisAsyncConnect(rspamd_inet_address_to_string(addr), - rspamd_inet_address_get_port(addr)); - } - - if (rt->redis == NULL) { - msg_warn_task("cannot connect to redis server %s: %s", - rspamd_inet_address_to_string_pretty(addr), - strerror(errno)); - - return NULL; - } - else if (rt->redis->err != REDIS_OK) { - msg_warn_task("cannot connect to redis server %s: %s", - rspamd_inet_address_to_string_pretty(addr), - rt->redis->errstr); - redisAsyncFree(rt->redis); - rt->redis = NULL; - - return NULL; - } - - redisLibevAttach(task->event_loop, rt->redis); - - /* Now check stats */ - rt->timer_ev.data = rt; - ev_timer_init(&rt->timer_ev, rspamd_redis_cache_timeout, - rt->ctx->timeout, 0.0); - rspamd_redis_cache_maybe_auth(ctx, rt->redis); - - if (!learn) { - rspamd_stat_cache_redis_generate_id(task); - } - - return rt; -} - -gint rspamd_stat_cache_redis_check(struct rspamd_task *task, - gboolean is_spam, - gpointer runtime) -{ - struct rspamd_redis_cache_runtime *rt = runtime; - gchar *h; - - if (rspamd_session_blocked(task->s)) { - return RSPAMD_LEARN_IGNORE; - } - - h = rspamd_mempool_get_variable(task->task_pool, "words_hash"); - - if (h == NULL) { - return RSPAMD_LEARN_IGNORE; - } - - if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_get, rt, - "HGET %s %s", - rt->ctx->redis_object, h) == REDIS_OK) { - rspamd_session_add_event(task->s, - rspamd_redis_cache_fin, - rt, - M); - ev_timer_start(rt->task->event_loop, &rt->timer_ev); - rt->has_event = TRUE; - } - - /* We need to return OK every time */ - return RSPAMD_LEARN_OK; -} - -gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, - gboolean is_spam, - gpointer runtime) -{ - struct rspamd_redis_cache_runtime *rt = runtime; - gchar *h; - gint flag; - - if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked(task->s)) { - return RSPAMD_LEARN_IGNORE; - } - - h = rspamd_mempool_get_variable(task->task_pool, "words_hash"); - g_assert(h != NULL); - - flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1; - - if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_set, rt, - "HSET %s %s %d", - rt->ctx->redis_object, h, flag) == REDIS_OK) { - rspamd_session_add_event(task->s, - rspamd_redis_cache_fin, rt, M); - ev_timer_start(rt->task->event_loop, &rt->timer_ev); - rt->has_event = TRUE; - } - - /* We need to return OK every time */ - return RSPAMD_LEARN_OK; -} - -void rspamd_stat_cache_redis_close(gpointer c) -{ - struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *) c; - lua_State *L; - - L = ctx->L; - - if (ctx->conf_ref) { - luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref); - } - - g_free(ctx); -} diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx new file mode 100644 index 000000000..b67df70a2 --- /dev/null +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -0,0 +1,315 @@ +/* + * Copyright 2024 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +// Include early to avoid `extern "C"` issues +#include "lua/lua_common.h" +#include "learn_cache.h" +#include "rspamd.h" +#include "stat_api.h" +#include "stat_internal.h" +#include "cryptobox.h" +#include "ucl.h" +#include "libmime/message.h" + +#define DEFAULT_REDIS_KEY "learned_ids" + +struct rspamd_redis_cache_ctx { + lua_State *L; + struct rspamd_statfile_config *stcf; + std::string redis_object = DEFAULT_REDIS_KEY; + int check_ref = -1; + int learn_ref = -1; + + rspamd_redis_cache_ctx() = delete; + explicit rspamd_redis_cache_ctx(lua_State *L) + : L(L) + { + } + + ~rspamd_redis_cache_ctx() + { + if (check_ref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, check_ref); + } + + if (learn_ref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, learn_ref); + } + } +}; + +#if 0 +/* Called when we have checked the specified message id */ +static void +rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv) +{ + struct rspamd_redis_cache_runtime *rt = priv; + redisReply *reply = r; + struct rspamd_task *task; + glong val = 0; + + task = rt->task; + + if (c->err == 0) { + if (reply) { + if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) { + val = reply->integer; + } + else if (reply->type == REDIS_REPLY_STRING) { + rspamd_strtol(reply->str, reply->len, &val); + } + else { + if (reply->type == REDIS_REPLY_ERROR) { + msg_err_task("cannot learn %s: redis error: \"%s\"", + rt->ctx->stcf->symbol, reply->str); + } + else if (reply->type != REDIS_REPLY_NIL) { + msg_err_task("bad learned type for %s: %d", + rt->ctx->stcf->symbol, reply->type); + } + + val = 0; + } + } + + if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || + (val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else if (val != 0) { + /* Unlearn flag */ + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } + + rspamd_upstream_ok(rt->selected); + } + else { + rspamd_upstream_fail(rt->selected, FALSE, c->errstr); + } + + if (rt->has_event) { + rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); + } +} +#endif + +static void +rspamd_stat_cache_redis_generate_id(struct rspamd_task *task) +{ + rspamd_cryptobox_hash_state_t st; + rspamd_cryptobox_hash_init(&st, NULL, 0); + + const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user"); + /* Use dedicated hash space for per users cache */ + if (user != NULL) { + rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user)); + } + + for (auto i = 0; i < task->tokens->len; i++) { + const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i); + rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data, + sizeof(tok->data)); + } + + guchar out[rspamd_cryptobox_HASHBYTES]; + rspamd_cryptobox_hash_final(&st, out); + + auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool, + sizeof(out) * 8 / 5 + 3, char); + auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out, + sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT); + + if (out_sz > 0) { + /* Zero terminate */ + b32out[out_sz] = '\0'; + rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL); + } +} + +gpointer +rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st, + const ucl_object_t *cf) +{ + std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg)); + + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + lua_settop(L, 0); + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache"); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Push arguments */ + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); + + if (lua_pcall(L, 2, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_cache " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* + * Results are in the stack: + * top - 1 - check function (idx = -2) + * top - learn function (idx = -1) + */ + lua_pushvalue(L, -2); + cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); + + return (gpointer) cache_ctx.release(); +} + +gpointer +rspamd_stat_cache_redis_runtime(struct rspamd_task *task, + gpointer c, gboolean learn) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) c; + + if (task->tokens == NULL || task->tokens->len == 0) { + return NULL; + } + + if (!learn) { + /* On check, we produce words_hash variable, on learn it is guaranteed to be set */ + rspamd_stat_cache_redis_generate_id(task); + } + + return (void *) ctx; +} + +static gint +rspamd_stat_cache_checked(lua_State *L) +{ + auto *task = lua_check_task(L, 1); + auto res = lua_toboolean(L, 2); + + if (res) { + auto val = lua_tointeger(L, 3); + + if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || + (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else if (val != 0) { + /* Unlearn flag */ + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } + } + + /* Ignore errors for now, as we can do nothing about them at the moment */ + + return 0; +} + +gint rspamd_stat_cache_redis_check(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) runtime; + auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash"); + + if (h == NULL) { + return RSPAMD_LEARN_IGNORE; + } + + auto *L = ctx->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + + lua_pushcclosure(L, &rspamd_stat_cache_checked, 0); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } + + /* We need to return OK every time */ + return RSPAMD_LEARN_OK; +} + +gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) runtime; + + if (rspamd_session_blocked(task->s)) { + return RSPAMD_LEARN_IGNORE; + } + + auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash"); + g_assert(h != NULL); + auto *L = ctx->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + lua_pushboolean(L, is_spam); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } + + /* We need to return OK every time */ + return RSPAMD_LEARN_OK; +} + +void rspamd_stat_cache_redis_close(gpointer c) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) c; + delete ctx; +} |