diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-12-06 14:46:45 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-12-06 14:46:45 +0000 |
commit | 7ff93147757a3c491a0dba20558fa54eb97b48b0 (patch) | |
tree | d0b31b7954160829dadfbedfd5fdb03a77bd9273 /src/libstat | |
parent | 18c4390ea090d7259255b0742eb1ddf4152de663 (diff) | |
download | rspamd-7ff93147757a3c491a0dba20558fa54eb97b48b0.tar.gz rspamd-7ff93147757a3c491a0dba20558fa54eb97b48b0.zip |
[Project] Rework stat runtime
Diffstat (limited to 'src/libstat')
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 177 |
1 files changed, 90 insertions, 87 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 973e60671..46b27cb15 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -19,6 +19,11 @@ #include "stat_internal.h" #include "upstream.h" #include "libserver/mempool_vars_internal.h" +#include "fmt/core.h" + +#include <string> +#include <cstdint> +#include <vector> #define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \ rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \ @@ -28,7 +33,7 @@ INIT_LOG_MODULE(stat_redis) #define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p)) -#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime *>(p)) +#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p)) #define REDIS_DEFAULT_OBJECT "%s%l" #define REDIS_DEFAULT_USERS_OBJECT "%s%l%r" #define REDIS_DEFAULT_TIMEOUT 0.5 @@ -38,31 +43,68 @@ INIT_LOG_MODULE(stat_redis) struct redis_stat_ctx { lua_State *L; struct rspamd_statfile_config *stcf; - gint conf_ref; struct rspamd_stat_async_elt *stat_elt; - const char *redis_object; - gboolean enable_users; - gboolean store_tokens; - gboolean new_schema; - gboolean enable_signatures; - guint expiry; - guint max_users; - gint cbref_user; - - gint cbref_classify; - gint cbref_learn; + const char *redis_object = REDIS_DEFAULT_OBJECT; + bool enable_users = false; + bool store_tokens = false; + bool enable_signatures = false; + unsigned expiry; + unsigned max_users = REDIS_MAX_USERS; + int cbref_user = -1; + + int cbref_classify = -1; + int cbref_learn = -1; + int conf_ref = -1; }; +template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true> struct redis_stat_runtime { struct redis_stat_ctx *ctx; struct rspamd_task *task; struct rspamd_statfile_config *stcf; GPtrArray *tokens; - gchar *redis_object_expanded; - guint64 learned; - gint id; - GError *err; + const char *redis_object_expanded; + std::uint64_t learned = 0; + int id; + std::vector<std::pair<int, T>> *results = nullptr; + + using result_type = std::vector<std::pair<int, T>>; + + explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded) + : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded) + { + } + + void init() + { + } + + void set_results(std::vector<std::pair<int, T>> *_results) + { + results = _results; + } + + ~redis_stat_runtime() + { + g_ptr_array_unref(tokens); + delete results; + } + + /* Propagate results from internal representation to the tokens array */ + auto process_tokens(GPtrArray *tokens) const -> bool + { + rspamd_token_t *tok; + + if (!results) { + return false; + } + + for (auto [idx, val]: *results) { + tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx); + tok->values[id] = val; + } + } }; /* Used to get statistics from redis */ @@ -217,14 +259,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern, /* Label miss is OK */ break; case 's': - if (ctx->new_schema) { - tlen += sizeof("RS") - 1; - } - else { - if (stcf->symbol) { - tlen += strlen(stcf->symbol); - } - } + tlen += sizeof("RS") - 1; break; default: state = just_char; @@ -306,14 +341,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern, } break; case 's': - if (ctx->new_schema) { - d += rspamd_strlcpy(d, "RS", end - d); - } - else { - if (stcf->symbol) { - d += rspamd_strlcpy(d, stcf->symbol, end - d); - } - } + d += rspamd_strlcpy(d, "RS", end - d); break; default: state = just_char; @@ -1071,15 +1099,9 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d) static void rspamd_redis_fin(gpointer data) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(data); - - if (rt->err) { - g_error_free(rt->err); - } + auto *rt = REDIS_RUNTIME(data); - if (rt->tokens) { - g_ptr_array_unref(rt->tokens); - } + delete rt; } @@ -1260,7 +1282,6 @@ rspamd_redis_runtime(struct rspamd_task *task, gboolean learn, gpointer c, gint _id) { struct redis_stat_ctx *ctx = REDIS_CTX(c); - struct redis_stat_runtime *rt; char *object_expanded = nullptr; g_assert(ctx != nullptr); @@ -1275,16 +1296,18 @@ rspamd_redis_runtime(struct rspamd_task *task, return nullptr; } - /* Look for the cached results */ - + auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt); - rt = (struct redis_stat_runtime *) rspamd_mempool_alloc0(task->task_pool, sizeof(*rt)); - rt->task = task; - rt->ctx = ctx; - rt->redis_object_expanded = object_expanded; - rt->stcf = stcf; + /* Look for the cached results */ + if (!learn) { + auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H"); + auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str()); - rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt); + if (res) { + rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res)); + } + } return rt; } @@ -1348,9 +1371,9 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize static gint rspamd_redis_classified(lua_State *L) { - const gchar *cookie = lua_tostring(L, lua_upvalueindex(1)); - struct rspamd_task *task = lua_check_task(L, 1); - struct redis_stat_runtime *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); /* TODO: write it */ if (rt == nullptr) { @@ -1374,8 +1397,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, gpointer p) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(p); - lua_State *L = rt->ctx->L; + auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; if (rspamd_session_blocked(task->s)) { return FALSE; @@ -1385,7 +1408,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task, return FALSE; } - /* TODO: check if we have tokens for that particular id for this class */ + if (rt->results) { + /* No need to do anything, we have results ready */ + rt->process_tokens(tokens); + + return TRUE; + } gsize tokens_len; gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len); @@ -1429,19 +1457,6 @@ gboolean rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); - - if (rt->err) { - msg_info_task("cannot retrieve stat tokens from Redis: %e", rt->err); - g_error_free(rt->err); - rt->err = nullptr; - rspamd_redis_fin(rt); - - return FALSE; - } - - rspamd_redis_fin(rt); - return TRUE; } @@ -1449,7 +1464,7 @@ gboolean rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, gpointer p) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(p); + auto *rt = REDIS_RUNTIME(p); /* TODO: write learn function */ @@ -1461,18 +1476,6 @@ gboolean rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime, gpointer ctx, GError **err) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); - - if (rt->err) { - g_propagate_error(err, rt->err); - rt->err = nullptr; - rspamd_redis_fin(rt); - - return FALSE; - } - - rspamd_redis_fin(rt); - return TRUE; } @@ -1480,7 +1483,7 @@ gulong rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); + auto *rt = REDIS_RUNTIME(runtime); return rt->learned; } @@ -1489,7 +1492,7 @@ gulong rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); + auto *rt = REDIS_RUNTIME(runtime); /* XXX: may cause races */ return rt->learned + 1; @@ -1499,7 +1502,7 @@ gulong rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); + auto *rt = REDIS_RUNTIME(runtime); /* XXX: may cause races */ return rt->learned + 1; @@ -1509,7 +1512,7 @@ gulong rspamd_redis_learns(struct rspamd_task *task, gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); + auto *rt = REDIS_RUNTIME(runtime); return rt->learned; } @@ -1518,7 +1521,7 @@ ucl_object_t * rspamd_redis_get_stat(gpointer runtime, gpointer ctx) { - struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime); + auto *rt = REDIS_RUNTIME(runtime); struct rspamd_redis_stat_elt *st; redisAsyncContext *redis; |