From be9845618c5456f930913114ff0ad4a54fb47d0a Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 4 Dec 2023 14:34:00 +0000 Subject: [PATCH] [Project] Start to rework C part --- src/libstat/backends/redis_backend.c | 880 +++++---------------------- 1 file changed, 149 insertions(+), 731 deletions(-) diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c index 86af51f57..9263f479e 100644 --- a/src/libstat/backends/redis_backend.c +++ b/src/libstat/backends/redis_backend.c @@ -45,11 +45,7 @@ struct redis_stat_ctx { struct rspamd_statfile_config *stcf; gint conf_ref; struct rspamd_stat_async_elt *stat_elt; - const gchar *redis_object; - const gchar *username; - const gchar *password; - const gchar *dbname; - gdouble timeout; + const char *redis_object; gboolean enable_users; gboolean store_tokens; gboolean new_schema; @@ -57,29 +53,20 @@ struct redis_stat_ctx { guint expiry; guint max_users; gint cbref_user; -}; -enum rspamd_redis_connection_state { - RSPAMD_REDIS_DISCONNECTED = 0, - RSPAMD_REDIS_CONNECTED, - RSPAMD_REDIS_REQUEST_SENT, - RSPAMD_REDIS_TIMEDOUT, - RSPAMD_REDIS_TERMINATED + gint cbref_classify; + gint cbref_learn; }; + struct redis_stat_runtime { struct redis_stat_ctx *ctx; struct rspamd_task *task; - struct upstream *selected; - ev_timer timeout_event; - GArray *results; - GPtrArray *tokens; struct rspamd_statfile_config *stcf; + GArray *results; gchar *redis_object_expanded; - redisAsyncContext *redis; guint64 learned; gint id; - gboolean has_event; GError *err; }; @@ -363,31 +350,9 @@ gsize rspamd_redis_expand_object(const gchar *pattern, return tlen; } -static void -rspamd_redis_maybe_auth(struct redis_stat_ctx *ctx, redisAsyncContext *redis) -{ - if (ctx->username) { - if (ctx->password) { - redisAsyncCommand(redis, NULL, NULL, "AUTH %s %s", ctx->username, ctx->password); - } - else { - msg_warn("Redis requires a password when username is supplied"); - } - } - else if (ctx->password) { - 
redisAsyncCommand(redis, NULL, NULL, "AUTH %s", ctx->password); - } - if (ctx->dbname) { - redisAsyncCommand(redis, NULL, NULL, "SELECT %s", ctx->dbname); - } -} -// the `b` conversion type character is unknown to gcc -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat" -#pragma GCC diagnostic ignored "-Wformat-extra-args" -#endif +#if 0 +// Leave it until the conversion is done, to use as a reference static rspamd_fstring_t * rspamd_redis_tokens_to_query(struct rspamd_task *task, struct redis_stat_runtime *rt, @@ -645,9 +610,7 @@ rspamd_redis_tokens_to_query(struct rspamd_task *task, return out; } -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif + static void rspamd_redis_store_stat_signature(struct rspamd_task *task, @@ -728,6 +691,8 @@ rspamd_redis_store_stat_signature(struct rspamd_task *task, rspamd_fstring_free(out); } +#endif + static void rspamd_redis_async_cbdata_cleanup(struct rspamd_redis_stat_cbdata *cbdata) { @@ -1081,7 +1046,6 @@ rspamd_redis_async_stat_cb(struct rspamd_stat_async_elt *elt, gpointer d) /* XXX: deal with timeouts maybe */ /* Get keys in redis that match our symbol */ - rspamd_redis_maybe_auth(ctx, cbdata->redis); redisAsyncCommand(cbdata->redis, rspamd_redis_stat_keys, redis_elt, "SSCAN %s_keys 0 COUNT %d", ctx->stcf->symbol, @@ -1112,346 +1076,23 @@ static void rspamd_redis_fin(gpointer data) { struct redis_stat_runtime *rt = REDIS_RUNTIME(data); - redisAsyncContext *redis; - - if (rt->has_event) { - /* Should not happen ! 
*/ - msg_err("FIXME: this code path should not be reached!"); - rspamd_session_remove_event(rt->task->s, NULL, rt); - rt->has_event = FALSE; - } - /* Stop timeout */ - if (ev_can_stop(&rt->timeout_event)) { - ev_timer_stop(rt->task->event_loop, &rt->timeout_event); - } - - if (rt->tokens) { - g_ptr_array_unref(rt->tokens); - rt->tokens = NULL; - } - - if (rt->redis) { - redis = rt->redis; - rt->redis = NULL; - /* This calls for all callbacks pending */ - redisAsyncFree(redis); - } if (rt->err) { g_error_free(rt->err); } } -static void -rspamd_redis_timeout(EV_P_ ev_timer *w, int revents) -{ - struct redis_stat_runtime *rt = REDIS_RUNTIME(w->data); - struct rspamd_task *task; - redisAsyncContext *redis; - - task = rt->task; - - msg_err_task_check("connection to redis server %s timed out", - rspamd_upstream_name(rt->selected)); - - rspamd_upstream_fail(rt->selected, FALSE, "timeout"); - - if (rt->redis) { - redis = rt->redis; - rt->redis = NULL; - /* This calls for all callbacks pending */ - redisAsyncFree(redis); - } - - if (rt->tokens) { - g_ptr_array_unref(rt->tokens); - rt->tokens = NULL; - } - - if (!rt->err) { - g_set_error(&rt->err, rspamd_redis_stat_quark(), ETIMEDOUT, - "error getting reply from redis server %s: timeout", - rspamd_upstream_name(rt->selected)); - } - if (rt->has_event) { - rt->has_event = FALSE; - rspamd_session_remove_event(task->s, NULL, rt); - } -} - -/* Called when we have received tokens values from redis */ -static void -rspamd_redis_processed(redisAsyncContext *c, gpointer r, gpointer priv) -{ - struct redis_stat_runtime *rt = REDIS_RUNTIME(priv); - redisReply *reply = r, *elt; - struct rspamd_task *task; - rspamd_token_t *tok; - guint i, processed = 0, found = 0; - gulong val; - gdouble float_val; - - task = rt->task; - - if (c->err == 0 && rt->has_event) { - if (r != NULL) { - if (reply->type == REDIS_REPLY_ARRAY) { - - if (reply->elements == task->tokens->len) { - for (i = 0; i < reply->elements; i++) { - tok = 
g_ptr_array_index(task->tokens, i); - elt = reply->element[i]; - - if (G_UNLIKELY(elt->type == REDIS_REPLY_INTEGER)) { - tok->values[rt->id] = elt->integer; - found++; - } - else if (elt->type == REDIS_REPLY_STRING) { - if (rt->stcf->clcf->flags & - RSPAMD_FLAG_CLASSIFIER_INTEGER) { - rspamd_strtoul(elt->str, elt->len, &val); - tok->values[rt->id] = val; - } - else { - float_val = strtof(elt->str, NULL); - tok->values[rt->id] = float_val; - } - - found++; - } - else { - tok->values[rt->id] = 0; - } - - processed++; - } - - if (rt->stcf->is_spam) { - task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; - } - else { - task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; - } - } - else { - msg_err_task_check("got invalid length of reply vector from redis: " - "%d, expected: %d", - (gint) reply->elements, - (gint) task->tokens->len); - } - } - else { - if (reply->type == REDIS_REPLY_ERROR) { - msg_err_task_check("cannot learn %s: redis error: \"%s\"", - rt->stcf->symbol, reply->str); - } - else { - msg_err_task_check("got invalid reply from redis: %s, array expected", - rspamd_redis_type_to_string(reply->type)); - } - } - - msg_debug_stat_redis("received tokens for %s: %d processed, %d found", - rt->redis_object_expanded, processed, found); - rspamd_upstream_ok(rt->selected); - } - } - else { - msg_err_task("error getting reply from redis server %s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - - if (rt->redis) { - rspamd_upstream_fail(rt->selected, FALSE, c->errstr); - } - - if (!rt->err) { - g_set_error(&rt->err, rspamd_redis_stat_quark(), c->err, - "cannot get values: error getting reply from redis server %s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - } - } - if (rt->has_event) { - rt->has_event = FALSE; - rspamd_session_remove_event(task->s, NULL, rt); - } -} - -/* Called when we have connected to the redis server and got stats */ -static void -rspamd_redis_connected(redisAsyncContext *c, gpointer r, gpointer priv) -{ - struct redis_stat_runtime *rt 
= REDIS_RUNTIME(priv); - redisReply *reply = r; - struct rspamd_task *task; - glong val = 0; - gboolean final = TRUE; - - task = rt->task; - - if (c->err == 0 && rt->has_event) { - if (r != NULL) { - if (G_UNLIKELY(reply->type == REDIS_REPLY_INTEGER)) { - val = reply->integer; - } - else if (reply->type == REDIS_REPLY_STRING) { - rspamd_strtol(reply->str, reply->len, &val); - } - else { - if (reply->type != REDIS_REPLY_NIL) { - if (reply->type == REDIS_REPLY_ERROR) { - msg_err_task("cannot learn %s: redis error: \"%s\"", - rt->stcf->symbol, reply->str); - } - else { - msg_err_task("bad learned type for %s: %s, nil expected", - rt->stcf->symbol, - rspamd_redis_type_to_string(reply->type)); - } - } - - val = 0; - } - - if (val < 0) { - msg_warn_task("invalid number of learns for %s: %L", - rt->stcf->symbol, val); - val = 0; - } - - rt->learned = val; - msg_debug_stat_redis("connected to redis server, tokens learned for %s: %uL", - rt->redis_object_expanded, rt->learned); - rspamd_upstream_ok(rt->selected); - - /* Save learn count in mempool variable */ - gint64 *learns_cnt; - const gchar *var_name; - - if (rt->stcf->is_spam) { - var_name = RSPAMD_MEMPOOL_SPAM_LEARNS; - } - else { - var_name = RSPAMD_MEMPOOL_HAM_LEARNS; - } - - learns_cnt = rspamd_mempool_get_variable(task->task_pool, - var_name); - - if (learns_cnt) { - (*learns_cnt) += rt->learned; - } - else { - learns_cnt = rspamd_mempool_alloc(task->task_pool, - sizeof(*learns_cnt)); - *learns_cnt = rt->learned; - rspamd_mempool_set_variable(task->task_pool, - var_name, - learns_cnt, NULL); - } - - if (rt->learned >= rt->stcf->clcf->min_learns && rt->learned > 0) { - rspamd_fstring_t *query = rspamd_redis_tokens_to_query( - task, - rt, - rt->tokens, - rt->ctx->new_schema ? 
"HGET" : "HMGET", - rt->redis_object_expanded, FALSE, -1, - rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER); - g_assert(query != NULL); - rspamd_mempool_add_destructor(task->task_pool, - (rspamd_mempool_destruct_t) rspamd_fstring_free, query); - - int ret = redisAsyncFormattedCommand(rt->redis, - rspamd_redis_processed, rt, - query->str, query->len); - - if (ret != REDIS_OK) { - msg_err_task("call to redis failed: %s", rt->redis->errstr); - } - else { - /* Further is handled by rspamd_redis_processed */ - final = FALSE; - /* Restart timeout */ - if (ev_can_stop(&rt->timeout_event)) { - rt->timeout_event.repeat = rt->ctx->timeout; - ev_timer_again(task->event_loop, &rt->timeout_event); - } - else { - rt->timeout_event.data = rt; - ev_timer_init(&rt->timeout_event, rspamd_redis_timeout, - rt->ctx->timeout, 0.); - ev_timer_start(task->event_loop, &rt->timeout_event); - } - } - } - else { - msg_warn_task("skip obtaining bayes tokens for %s of classifier " - "%s: not enough learns %d; %d required", - rt->stcf->symbol, rt->stcf->clcf->name, - (int) rt->learned, rt->stcf->clcf->min_learns); - } - } - } - else if (rt->has_event) { - msg_err_task("error getting reply from redis server %s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - rspamd_upstream_fail(rt->selected, FALSE, c->errstr); - - if (!rt->err) { - g_set_error(&rt->err, rspamd_redis_stat_quark(), c->err, - "error getting reply from redis server %s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - } - } - - if (final && rt->has_event) { - rt->has_event = FALSE; - rspamd_session_remove_event(task->s, NULL, rt); - } -} - -/* Called when we have set tokens during learning */ -static void -rspamd_redis_learned(redisAsyncContext *c, gpointer r, gpointer priv) -{ - struct redis_stat_runtime *rt = REDIS_RUNTIME(priv); - struct rspamd_task *task; - - task = rt->task; - - if (c->err == 0) { - rspamd_upstream_ok(rt->selected); - } - else { - msg_err_task_check("error getting reply from redis server 
%s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - - if (rt->redis) { - rspamd_upstream_fail(rt->selected, FALSE, c->errstr); - } - - if (!rt->err) { - g_set_error(&rt->err, rspamd_redis_stat_quark(), c->err, - "cannot get learned: error getting reply from redis server %s: %s", - rspamd_upstream_name(rt->selected), c->errstr); - } - } - - if (rt->has_event) { - rt->has_event = FALSE; - rspamd_session_remove_event(task->s, NULL, rt); - } -} -static void +static bool rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, - const ucl_object_t *obj, + const ucl_object_t *statfile_obj, + const ucl_object_t *classifier_obj, struct rspamd_config *cfg) { const gchar *lua_script; const ucl_object_t *elt, *users_enabled; - users_enabled = ucl_object_lookup_any(obj, "per_user", + users_enabled = ucl_object_lookup_any(classifier_obj, "per_user", "users_enabled", NULL); if (users_enabled != NULL) { @@ -1487,7 +1128,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->cbref_user = -1; } - elt = ucl_object_lookup(obj, "prefix"); + elt = ucl_object_lookup(classifier_obj, "prefix"); if (elt == NULL || ucl_object_type(elt) != UCL_STRING) { /* Default non-users statistics */ if (backend->enable_users || backend->cbref_user != -1) { @@ -1502,7 +1143,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->redis_object = ucl_object_tostring(elt); } - elt = ucl_object_lookup(obj, "store_tokens"); + elt = ucl_object_lookup(classifier_obj, "store_tokens"); if (elt) { backend->store_tokens = ucl_object_toboolean(elt); } @@ -1510,19 +1151,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->store_tokens = FALSE; } - elt = ucl_object_lookup(obj, "new_schema"); - if (elt) { - backend->new_schema = ucl_object_toboolean(elt); - } - else { - backend->new_schema = FALSE; - - msg_warn_config("you are using old bayes schema for redis statistics, " - "please consider converting it to a new one " - 
"by using 'rspamadm configwizard statistics'"); - } - - elt = ucl_object_lookup(obj, "signatures"); + elt = ucl_object_lookup(classifier_obj, "signatures"); if (elt) { backend->enable_signatures = ucl_object_toboolean(elt); } @@ -1530,7 +1159,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->enable_signatures = FALSE; } - elt = ucl_object_lookup_any(obj, "expiry", "expire", NULL); + elt = ucl_object_lookup_any(classifier_obj, "expiry", "expire", NULL); if (elt) { backend->expiry = ucl_object_toint(elt); } @@ -1538,13 +1167,53 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->expiry = 0; } - elt = ucl_object_lookup(obj, "max_users"); + elt = ucl_object_lookup(classifier_obj, "max_users"); if (elt) { backend->max_users = ucl_object_toint(elt); } else { backend->max_users = REDIS_MAX_USERS; } + + lua_State *L = RSPAMD_LUA_CFG_STATE(cfg); + lua_pushcfunction(L, &rspamd_lua_traceback); + int err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_classifier")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_classifier"); + lua_settop(L, err_idx - 1); + + return false; + } + + /* Push arguments */ + ucl_object_push_lua(L, classifier_obj, false); + ucl_object_push_lua(L, statfile_obj, false); + + if (lua_pcall(L, 2, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_classifier " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return false; + } + + /* Results are in the stack: + * top - 1 - classifier function (idx = -2) + * top - learn function (idx = -1) + */ + + lua_pushvalue(L, -2); + backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); + + return true; } gpointer @@ -1561,90 +1230,18 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx, backend = g_malloc0(sizeof(*backend)); 
backend->L = L; - backend->timeout = REDIS_DEFAULT_TIMEOUT; backend->max_users = REDIS_MAX_USERS; - /* First search in backend configuration */ - obj = ucl_object_lookup(st->classifier->cfg->opts, "backend"); - if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) { - ret = rspamd_lua_try_load_redis(L, obj, cfg, &conf_ref); - } - - /* Now try statfiles config */ - if (!ret && stf->opts) { - ret = rspamd_lua_try_load_redis(L, stf->opts, cfg, &conf_ref); - } - - /* Now try classifier config */ - if (!ret && st->classifier->cfg->opts) { - ret = rspamd_lua_try_load_redis(L, st->classifier->cfg->opts, cfg, &conf_ref); - } - - /* Now try global redis settings */ - if (!ret) { - obj = ucl_object_lookup(cfg->cfg_ucl_obj, "redis"); - - if (obj) { - const ucl_object_t *specific_obj; - - specific_obj = ucl_object_lookup(obj, "statistics"); + backend->conf_ref = conf_ref; - if (specific_obj) { - ret = rspamd_lua_try_load_redis(L, - specific_obj, cfg, &conf_ref); - } - else { - ret = rspamd_lua_try_load_redis(L, - obj, cfg, &conf_ref); - } - } - } + lua_settop(L, 0); - if (!ret) { + if (!rspamd_redis_parse_classifier_opts(backend, st->stcf->opts, st->classifier->cfg->opts, cfg)) { msg_err_config("cannot init redis backend for %s", stf->symbol); g_free(backend); return NULL; } - backend->conf_ref = conf_ref; - - /* Check some common table values */ - lua_rawgeti(L, LUA_REGISTRYINDEX, conf_ref); - - lua_pushstring(L, "timeout"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TNUMBER) { - backend->timeout = lua_tonumber(L, -1); - } - lua_pop(L, 1); - - lua_pushstring(L, "db"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - backend->dbname = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_pushstring(L, "username"); - lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - backend->username = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_pushstring(L, "password"); - 
lua_gettable(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) { - backend->password = rspamd_mempool_strdup(cfg->cfg_pool, - lua_tostring(L, -1)); - } - lua_pop(L, 1); - - lua_settop(L, 0); - - rspamd_redis_parse_classifier_opts(backend, st->classifier->cfg->opts, cfg); stf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; backend->stcf = stf; @@ -1661,39 +1258,6 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx, return (gpointer) backend; } -/* - * This callback is called when Redis is disconnected somehow, and the structure - * itself is usually freed by hiredis itself - */ -static void -rspamd_stat_redis_on_disconnect(const struct redisAsyncContext *ac, int status) -{ - struct redis_stat_runtime *rt = (struct redis_stat_runtime *) ac->data; - - if (ev_can_stop(&rt->timeout_event)) { - ev_timer_stop(rt->task->event_loop, &rt->timeout_event); - } - rt->redis = NULL; -} - -static void -rspamd_stat_redis_on_connect(const struct redisAsyncContext *ac, int status) -{ - struct redis_stat_runtime *rt = (struct redis_stat_runtime *) ac->data; - - - if (status == REDIS_ERR) { - /* - * We also need to reset rt->redis as it will be subsequently freed without - * calling for redis_on_disconnect callback... 
- */ - if (ev_can_stop(&rt->timeout_event)) { - ev_timer_stop(rt->task->event_loop, &rt->timeout_event); - } - rt->redis = NULL; - } -} - gpointer rspamd_redis_runtime(struct rspamd_task *task, struct rspamd_statfile_config *stcf, @@ -1701,46 +1265,12 @@ rspamd_redis_runtime(struct rspamd_task *task, { struct redis_stat_ctx *ctx = REDIS_CTX(c); struct redis_stat_runtime *rt; - struct upstream *up; - struct upstream_list *ups; char *object_expanded = NULL; rspamd_inet_addr_t *addr; g_assert(ctx != NULL); g_assert(stcf != NULL); - if (learn) { - ups = rspamd_redis_get_servers(ctx, "write_servers"); - - if (!ups) { - msg_err_task("no write servers defined for %s, cannot learn", - stcf->symbol); - return NULL; - } - up = rspamd_upstream_get(ups, - RSPAMD_UPSTREAM_MASTER_SLAVE, - NULL, - 0); - } - else { - ups = rspamd_redis_get_servers(ctx, "read_servers"); - - if (!ups) { - msg_err_task("no read servers defined for %s, cannot stat", - stcf->symbol); - return NULL; - } - up = rspamd_upstream_get(ups, - RSPAMD_UPSTREAM_ROUND_ROBIN, - NULL, - 0); - } - - if (up == NULL) { - msg_err_task("no upstreams reachable"); - return NULL; - } - if (rspamd_redis_expand_object(ctx->redis_object, ctx, task, &object_expanded) == 0) { msg_err_task("expansion for %s failed for symbol %s " @@ -1751,45 +1281,10 @@ rspamd_redis_runtime(struct rspamd_task *task, } rt = rspamd_mempool_alloc0(task->task_pool, sizeof(*rt)); - rt->selected = up; rt->task = task; rt->ctx = ctx; - rt->stcf = stcf; rt->redis_object_expanded = object_expanded; - addr = rspamd_upstream_addr_next(up); - g_assert(addr != NULL); - - if (rspamd_inet_address_get_af(addr) == AF_UNIX) { - rt->redis = redisAsyncConnectUnix(rspamd_inet_address_to_string(addr)); - } - else { - rt->redis = redisAsyncConnect(rspamd_inet_address_to_string(addr), - rspamd_inet_address_get_port(addr)); - } - - if (rt->redis == NULL) { - msg_warn_task("cannot connect to redis server %s: %s", - rspamd_inet_address_to_string_pretty(addr), - 
strerror(errno)); - return NULL; - } - else if (rt->redis->err != REDIS_OK) { - msg_warn_task("cannot connect to redis server %s: %s", - rspamd_inet_address_to_string_pretty(addr), - rt->redis->errstr); - redisAsyncFree(rt->redis); - rt->redis = NULL; - - return NULL; - } - - redisLibevAttach(task->event_loop, rt->redis); - rspamd_redis_maybe_auth(ctx, rt->redis); - rt->redis->data = rt; - redisAsyncSetDisconnectCallback(rt->redis, rspamd_stat_redis_on_disconnect); - redisAsyncSetConnectCallback(rt->redis, rspamd_stat_redis_on_connect); - rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt); return rt; @@ -1804,56 +1299,113 @@ void rspamd_redis_close(gpointer p) luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref); } + if (ctx->cbref_learn) { + luaL_unref(L, LUA_REGISTRYINDEX, ctx->cbref_learn); + } + + if (ctx->cbref_classify) { + luaL_unref(L, LUA_REGISTRYINDEX, ctx->cbref_classify); + } + g_free(ctx); } +/* + * Serialise stat tokens to message pack + */ +static char * +rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len) +{ + /* Each token is int64_t that requires 9 bytes + 4 bytes array len + 1 byte array magic */ + gsize req_len = tokens->len * 9 + 5, i; + gchar *buf, *p; + rspamd_token_t *tok; + + buf = rspamd_mempool_alloc(task->task_pool, req_len); + p = buf; + + /* Array */ + *p++ = (gchar) 0xdd; + /* Length in big-endian (4 bytes) */ + *p++ = (gchar) ((tokens->len >> 24) & 0xff); + *p++ = (gchar) ((tokens->len >> 16) & 0xff); + *p++ = (gchar) ((tokens->len >> 8) & 0xff); + *p++ = (gchar) (tokens->len & 0xff); + + PTR_ARRAY_FOREACH(tokens, i, tok) + { + *p++ = (gchar) 0xd3; + + guint64 val = GUINT64_TO_BE(tok->data); + memcpy(p, &val, sizeof(val)); + p += sizeof(val); + } + + *ser_len = p - buf; + + return buf; +} + +static gint +rspamd_redis_classified(lua_State *L) +{ + const gchar *cookie = lua_tostring(L, lua_upvalueindex(1)); + struct rspamd_task *task = lua_check_task(L, 1); + struct redis_stat_runtime 
*rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + return 0; /* TODO: write it; returns no values to Lua for now */ +} + gboolean rspamd_redis_process_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, gpointer p) { struct redis_stat_runtime *rt = REDIS_RUNTIME(p); - const gchar *learned_key = "learns"; + lua_State *L = rt->ctx->L; if (rspamd_session_blocked(task->s)) { return FALSE; } - if (tokens == NULL || tokens->len == 0 || rt->redis == NULL) { + if (tokens == NULL || tokens->len == 0) { return FALSE; } - rt->id = id; - - if (rt->ctx->new_schema) { - if (rt->ctx->stcf->is_spam) { - learned_key = "learns_spam"; - } - else { - learned_key = "learns_ham"; - } - } + /* TODO: check if we have tokens for that particular id for this class */ - if (redisAsyncCommand(rt->redis, rspamd_redis_connected, rt, "HGET %s %s", - rt->redis_object_expanded, learned_key) == REDIS_OK) { + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len); - rspamd_session_add_event(task->s, NULL, rt, M); - rt->has_event = TRUE; - rt->tokens = g_ptr_array_ref(tokens); + rt->id = id; - if (ev_can_stop(&rt->timeout_event)) { - rt->timeout_event.repeat = rt->ctx->timeout; - ev_timer_again(task->event_loop, &rt->timeout_event); - } - else { - rt->timeout_event.data = rt; - ev_timer_init(&rt->timeout_event, rspamd_redis_timeout, - rt->ctx->timeout, 0.); - ev_timer_start(task->event_loop, &rt->timeout_event); - } + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify); /* callee must precede its arguments for lua_pcall */ + /* Function arguments */ + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->ctx->stcf->is_spam); /* rt->stcf is not assigned in rspamd_redis_runtime any more */ + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + gchar *cookie = rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, NULL); + /* Callback */ + lua_pushstring(L, 
cookie); + lua_pushcclosure(L, &rspamd_redis_classified, 1); + + + if (lua_pcall(L, 6, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; } - return FALSE; + lua_settop(L, err_idx - 1); + return TRUE; } gboolean @@ -1881,137 +1433,9 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, gpointer p) { struct redis_stat_runtime *rt = REDIS_RUNTIME(p); - rspamd_fstring_t *query; - const gchar *redis_cmd; - rspamd_token_t *tok; - gint ret; - goffset off; - const gchar *learned_key = "learns"; - - if (rspamd_session_blocked(task->s)) { - return FALSE; - } - - if (rt->ctx->new_schema) { - if (rt->ctx->stcf->is_spam) { - learned_key = "learns_spam"; - } - else { - learned_key = "learns_ham"; - } - } - - /* - * Add the current key to the set of learned keys - */ - redisAsyncCommand(rt->redis, NULL, NULL, "SADD %s_keys %s", - rt->stcf->symbol, rt->redis_object_expanded); - - if (rt->ctx->new_schema) { - redisAsyncCommand(rt->redis, NULL, NULL, "HSET %s version 2", - rt->redis_object_expanded); - } + lua_State *L = rt->ctx->L; - if (rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER) { - redis_cmd = "HINCRBY"; - } - else { - redis_cmd = "HINCRBYFLOAT"; - } - - rt->id = id; - query = rspamd_redis_tokens_to_query(task, rt, tokens, - redis_cmd, rt->redis_object_expanded, TRUE, id, - rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER); - g_assert(query != NULL); - query->len = 0; - - /* - * XXX: - * Dirty hack: we get a token and check if it's value is -1 or 1, so - * we could understand that we are learning or unlearning - */ - - tok = g_ptr_array_index(task->tokens, 0); - - if (tok->values[id] > 0) { - rspamd_printf_fstring(&query, "" - "*4\r\n" - "$7\r\n" - "HINCRBY\r\n" - "$%d\r\n" - "%s\r\n" - "$%d\r\n" - "%s\r\n" /* Learned key */ - "$1\r\n" - "1\r\n", - (gint) strlen(rt->redis_object_expanded), - rt->redis_object_expanded, - (gint) strlen(learned_key), - 
learned_key); - } - else { - rspamd_printf_fstring(&query, "" - "*4\r\n" - "$7\r\n" - "HINCRBY\r\n" - "$%d\r\n" - "%s\r\n" - "$%d\r\n" - "%s\r\n" /* Learned key */ - "$2\r\n" - "-1\r\n", - (gint) strlen(rt->redis_object_expanded), - rt->redis_object_expanded, - (gint) strlen(learned_key), - learned_key); - } - - ret = redisAsyncFormattedCommand(rt->redis, NULL, NULL, - query->str, query->len); - - if (ret != REDIS_OK) { - msg_err_task("call to redis failed: %s", rt->redis->errstr); - rspamd_fstring_free(query); - - return FALSE; - } - - off = query->len; - ret = rspamd_printf_fstring(&query, "*1\r\n$4\r\nEXEC\r\n"); - ret = redisAsyncFormattedCommand(rt->redis, rspamd_redis_learned, rt, - query->str + off, ret); - rspamd_mempool_add_destructor(task->task_pool, - (rspamd_mempool_destruct_t) rspamd_fstring_free, query); - - if (ret == REDIS_OK) { - - /* Add signature if needed */ - if (rt->ctx->enable_signatures) { - rspamd_redis_store_stat_signature(task, rt, tokens, - "RSIG"); - } - - rspamd_session_add_event(task->s, NULL, rt, M); - rt->has_event = TRUE; - - /* Set timeout */ - if (ev_can_stop(&rt->timeout_event)) { - rt->timeout_event.repeat = rt->ctx->timeout; - ev_timer_again(task->event_loop, &rt->timeout_event); - } - else { - rt->timeout_event.data = rt; - ev_timer_init(&rt->timeout_event, rspamd_redis_timeout, - rt->ctx->timeout, 0.); - ev_timer_start(task->event_loop, &rt->timeout_event); - } - - return TRUE; - } - else { - msg_err_task("call to redis failed: %s", rt->redis->errstr); - } + /* TODO: write learn function */ return FALSE; } @@ -2085,12 +1509,6 @@ rspamd_redis_get_stat(gpointer runtime, if (rt->ctx->stat_elt) { st = rt->ctx->stat_elt->ud; - if (rt->redis) { - redis = rt->redis; - rt->redis = NULL; - redisAsyncFree(redis); - } - if (st->stat) { return ucl_object_ref(st->stat); } -- 2.39.5