summaryrefslogtreecommitdiffstats
path: root/src/libstat
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2023-12-06 14:46:45 +0000
committerVsevolod Stakhov <vsevolod@rspamd.com>2023-12-06 14:46:45 +0000
commit7ff93147757a3c491a0dba20558fa54eb97b48b0 (patch)
treed0b31b7954160829dadfbedfd5fdb03a77bd9273 /src/libstat
parent18c4390ea090d7259255b0742eb1ddf4152de663 (diff)
downloadrspamd-7ff93147757a3c491a0dba20558fa54eb97b48b0.tar.gz
rspamd-7ff93147757a3c491a0dba20558fa54eb97b48b0.zip
[Project] Rework stat runtime
Diffstat (limited to 'src/libstat')
-rw-r--r--src/libstat/backends/redis_backend.cxx177
1 files changed, 90 insertions, 87 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 973e60671..46b27cb15 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -19,6 +19,11 @@
#include "stat_internal.h"
#include "upstream.h"
#include "libserver/mempool_vars_internal.h"
+#include "fmt/core.h"
+
+#include <string>
+#include <cstdint>
+#include <vector>
#define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
@@ -28,7 +33,7 @@
INIT_LOG_MODULE(stat_redis)
#define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
-#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime *>(p))
+#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
#define REDIS_DEFAULT_OBJECT "%s%l"
#define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
#define REDIS_DEFAULT_TIMEOUT 0.5
@@ -38,31 +43,68 @@ INIT_LOG_MODULE(stat_redis)
struct redis_stat_ctx {
lua_State *L;
struct rspamd_statfile_config *stcf;
- gint conf_ref;
struct rspamd_stat_async_elt *stat_elt;
- const char *redis_object;
- gboolean enable_users;
- gboolean store_tokens;
- gboolean new_schema;
- gboolean enable_signatures;
- guint expiry;
- guint max_users;
- gint cbref_user;
-
- gint cbref_classify;
- gint cbref_learn;
+ const char *redis_object = REDIS_DEFAULT_OBJECT;
+ bool enable_users = false;
+ bool store_tokens = false;
+ bool enable_signatures = false;
+ unsigned expiry;
+ unsigned max_users = REDIS_MAX_USERS;
+ int cbref_user = -1;
+
+ int cbref_classify = -1;
+ int cbref_learn = -1;
+ int conf_ref = -1;
};
+template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
struct redis_stat_runtime {
struct redis_stat_ctx *ctx;
struct rspamd_task *task;
struct rspamd_statfile_config *stcf;
GPtrArray *tokens;
- gchar *redis_object_expanded;
- guint64 learned;
- gint id;
- GError *err;
+ const char *redis_object_expanded;
+ std::uint64_t learned = 0;
+ int id;
+ std::vector<std::pair<int, T>> *results = nullptr;
+
+ using result_type = std::vector<std::pair<int, T>>;
+
+ explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+ : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+ {
+ }
+
+ void init()
+ {
+ }
+
+ void set_results(std::vector<std::pair<int, T>> *_results)
+ {
+ results = _results;
+ }
+
+ ~redis_stat_runtime()
+ {
+ g_ptr_array_unref(tokens);
+ delete results;
+ }
+
+ /* Propagate results from internal representation to the tokens array */
+ auto process_tokens(GPtrArray *tokens) const -> bool
+ {
+ rspamd_token_t *tok;
+
+ if (!results) {
+ return false;
+ }
+
+ for (auto [idx, val]: *results) {
+ tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx);
+ tok->values[id] = val;
+ }
+ }
};
/* Used to get statistics from redis */
@@ -217,14 +259,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
/* Label miss is OK */
break;
case 's':
- if (ctx->new_schema) {
- tlen += sizeof("RS") - 1;
- }
- else {
- if (stcf->symbol) {
- tlen += strlen(stcf->symbol);
- }
- }
+ tlen += sizeof("RS") - 1;
break;
default:
state = just_char;
@@ -306,14 +341,7 @@ gsize rspamd_redis_expand_object(const gchar *pattern,
}
break;
case 's':
- if (ctx->new_schema) {
- d += rspamd_strlcpy(d, "RS", end - d);
- }
- else {
- if (stcf->symbol) {
- d += rspamd_strlcpy(d, stcf->symbol, end - d);
- }
- }
+ d += rspamd_strlcpy(d, "RS", end - d);
break;
default:
state = just_char;
@@ -1071,15 +1099,9 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d)
static void
rspamd_redis_fin(gpointer data)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(data);
-
- if (rt->err) {
- g_error_free(rt->err);
- }
+ auto *rt = REDIS_RUNTIME(data);
- if (rt->tokens) {
- g_ptr_array_unref(rt->tokens);
- }
+ delete rt;
}
@@ -1260,7 +1282,6 @@ rspamd_redis_runtime(struct rspamd_task *task,
gboolean learn, gpointer c, gint _id)
{
struct redis_stat_ctx *ctx = REDIS_CTX(c);
- struct redis_stat_runtime *rt;
char *object_expanded = nullptr;
g_assert(ctx != nullptr);
@@ -1275,16 +1296,18 @@ rspamd_redis_runtime(struct rspamd_task *task,
return nullptr;
}
- /* Look for the cached results */
-
+ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
- rt = (struct redis_stat_runtime *) rspamd_mempool_alloc0(task->task_pool, sizeof(*rt));
- rt->task = task;
- rt->ctx = ctx;
- rt->redis_object_expanded = object_expanded;
- rt->stcf = stcf;
+ /* Look for the cached results */
+ if (!learn) {
+ auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
+ auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
- rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
+ if (res) {
+ rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+ }
+ }
return rt;
}
@@ -1348,9 +1371,9 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize
static gint
rspamd_redis_classified(lua_State *L)
{
- const gchar *cookie = lua_tostring(L, lua_upvalueindex(1));
- struct rspamd_task *task = lua_check_task(L, 1);
- struct redis_stat_runtime *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+ const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+ auto *task = lua_check_task(L, 1);
+ auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
/* TODO: write it */
if (rt == nullptr) {
@@ -1374,8 +1397,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
GPtrArray *tokens,
gint id, gpointer p)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
- lua_State *L = rt->ctx->L;
+ auto *rt = REDIS_RUNTIME(p);
+ auto *L = rt->ctx->L;
if (rspamd_session_blocked(task->s)) {
return FALSE;
@@ -1385,7 +1408,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
return FALSE;
}
- /* TODO: check if we have tokens for that particular id for this class */
+ if (rt->results) {
+ /* No need to do anything, we have results ready */
+ rt->process_tokens(tokens);
+
+ return TRUE;
+ }
gsize tokens_len;
gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len);
@@ -1429,19 +1457,6 @@ gboolean
rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
- if (rt->err) {
- msg_info_task("cannot retrieve stat tokens from Redis: %e", rt->err);
- g_error_free(rt->err);
- rt->err = nullptr;
- rspamd_redis_fin(rt);
-
- return FALSE;
- }
-
- rspamd_redis_fin(rt);
-
return TRUE;
}
@@ -1449,7 +1464,7 @@ gboolean
rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
gint id, gpointer p)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(p);
+ auto *rt = REDIS_RUNTIME(p);
/* TODO: write learn function */
@@ -1461,18 +1476,6 @@ gboolean
rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
gpointer ctx, GError **err)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
-
- if (rt->err) {
- g_propagate_error(err, rt->err);
- rt->err = nullptr;
- rspamd_redis_fin(rt);
-
- return FALSE;
- }
-
- rspamd_redis_fin(rt);
-
return TRUE;
}
@@ -1480,7 +1483,7 @@ gulong
rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+ auto *rt = REDIS_RUNTIME(runtime);
return rt->learned;
}
@@ -1489,7 +1492,7 @@ gulong
rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+ auto *rt = REDIS_RUNTIME(runtime);
/* XXX: may cause races */
return rt->learned + 1;
@@ -1499,7 +1502,7 @@ gulong
rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+ auto *rt = REDIS_RUNTIME(runtime);
/* XXX: may cause races */
return rt->learned + 1;
@@ -1509,7 +1512,7 @@ gulong
rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+ auto *rt = REDIS_RUNTIME(runtime);
return rt->learned;
}
@@ -1518,7 +1521,7 @@ ucl_object_t *
rspamd_redis_get_stat(gpointer runtime,
gpointer ctx)
{
- struct redis_stat_runtime *rt = REDIS_RUNTIME(runtime);
+ auto *rt = REDIS_RUNTIME(runtime);
struct rspamd_redis_stat_elt *st;
redisAsyncContext *redis;