diff options
Diffstat (limited to 'src/libstat/backends/redis_backend.cxx')
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 161 |
1 files changed, 131 insertions, 30 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 46b27cb15..46f88de19 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -63,32 +63,58 @@ struct redis_stat_runtime { struct redis_stat_ctx *ctx; struct rspamd_task *task; struct rspamd_statfile_config *stcf; - GPtrArray *tokens; + GPtrArray *tokens = nullptr; const char *redis_object_expanded; std::uint64_t learned = 0; int id; std::vector<std::pair<int, T>> *results = nullptr; + bool need_redis_call = true; using result_type = std::vector<std::pair<int, T>>; - explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded) - : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded) +private: + /* Called on connection termination */ + static void rt_dtor(gpointer data) { + auto *rt = REDIS_RUNTIME(data); + + delete rt; } - void init() + /* Avoid occasional deletion */ + ~redis_stat_runtime() { + if (tokens) { + g_ptr_array_unref(tokens); + } + + delete results; } - void set_results(std::vector<std::pair<int, T>> *_results) +public: + explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded) + : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded) { - results = _results; + rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this); } - ~redis_stat_runtime() + static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, + bool is_spam) -> std::optional<redis_stat_runtime<T> *> { - g_ptr_array_unref(tokens); - delete results; + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str()); + + if (res) { + return reinterpret_cast<redis_stat_runtime<T> *>(res); + } + else { + return std::nullopt; + } + } + + void set_results(std::vector<std::pair<int, T>> *results) + { + this->results = results; } /* Propagate results from internal representation to the tokens array */ @@ -104,6 +130,15 @@ struct redis_stat_runtime { tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx); tok->values[id] = val; } + + return true; + } + + auto save_in_mempool(bool is_spam) const + { + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + /* We do not set destructor for the variable, as it should be already added on creation */ + rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); } }; @@ -1095,16 +1130,6 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d) #endif -/* Called on connection termination */ -static void -rspamd_redis_fin(gpointer data) -{ - auto *rt = REDIS_RUNTIME(data); - - delete rt; -} - - static bool rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, const ucl_object_t *statfile_obj, @@ -1296,19 +1321,40 @@ rspamd_redis_runtime(struct rspamd_task *task, return nullptr; } - auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); - rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt); - /* Look for the cached results */ if (!learn) { - auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H"); - auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str()); + auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, stcf->is_spam); - if (res) { - rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res)); + if (maybe_existing) { + /* Update stcf to correspond to what we have been asked */ + maybe_existing.value()->stcf = stcf; + return maybe_existing.value(); + } + } + + /* No cached result, create new one */ + auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + + if (!learn) { + /* + * For check, we also need to create the opposite class runtime to avoid + * double call for Redis scripts. + * This runtime will be filled later. + */ + auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, + !stcf->is_spam); + + if (!maybe_opposite_rt) { + auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + opposite_rt->save_in_mempool(!stcf->is_spam); + opposite_rt->need_redis_call = false; } } + rt->save_in_mempool(stcf->is_spam); + return rt; } @@ -1385,8 +1431,65 @@ rspamd_redis_classified(lua_State *L) bool result = lua_toboolean(L, 2); if (result) { + /* Indexes: + * 3 - learned_ham (int) + * 4 - learned_spam (int) + * 5 - ham_tokens (pair<int, int>) + * 6 - spam_tokens (pair<int, int>) + */ + + /* + * We need to fill our runtime AND the opposite runtime + */ + auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) { + rt->learned = learned; + redis_stat_runtime<float>::result_type *res; + + res = new redis_stat_runtime<float>::result_type(lua_objlen(L, tokens_pos)); + + for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { + lua_rawgeti(L, -1, 1); + auto idx = lua_tointeger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 2); + auto value = lua_tonumber(L, -1); + lua_pop(L, 1); + + res->emplace_back(idx, value); + } + + rt->set_results(res); + }; + + auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + !rt->stcf->is_spam); + + if (!opposite_rt_maybe) { + msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + + return 0; + } + + if (rt->stcf->is_spam) { + filler_func(rt, L, lua_tointeger(L, 4), 6); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + } + else { + filler_func(rt, L, lua_tointeger(L, 3), 5); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + } + + /* Process all tokens */ + g_assert(rt->tokens != nullptr); + rt->process_tokens(rt->tokens); + opposite_rt_maybe.value()->process_tokens(rt->tokens); } else { + /* Error message is on index 3 */ + msg_err_task("cannot classify task: %s", + lua_tostring(L, 3)); } return 0; @@ -1408,9 +1511,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task, return FALSE; } - if (rt->results) { - /* No need to do anything, we have results ready */ - rt->process_tokens(tokens); + if (!rt->need_redis_call) { + /* No need to do anything, as it is already done in the opposite class processing */ return TRUE; } @@ -1440,7 +1542,6 @@ rspamd_redis_process_tokens(struct rspamd_task *task, lua_pushstring(L, cookie); lua_pushcclosure(L, &rspamd_redis_classified, 1); - if (lua_pcall(L, 6, 0, err_idx) != 0) { msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); |