struct redis_stat_ctx *ctx;
struct rspamd_task *task;
struct rspamd_statfile_config *stcf;
- GPtrArray *tokens;
+ GPtrArray *tokens = nullptr;
const char *redis_object_expanded;
std::uint64_t learned = 0;
int id;
std::vector<std::pair<int, T>> *results = nullptr;
+ bool need_redis_call = true;
using result_type = std::vector<std::pair<int, T>>;
- explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
- : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+private:
+ /* Called on connection termination */
+ static void rt_dtor(gpointer data)
{
+ auto *rt = REDIS_RUNTIME(data);
+
+ delete rt;
}
- void init()
+ /* Avoid occasional deletion */
+ ~redis_stat_runtime()
{
+ if (tokens) {
+ g_ptr_array_unref(tokens);
+ }
+
+ delete results;
}
- void set_results(std::vector<std::pair<int, T>> *_results)
+public:
+ explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+ : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
{
- results = _results;
+ rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
}
- ~redis_stat_runtime()
+ static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
+ bool is_spam) -> std::optional<redis_stat_runtime<T> *>
{
- g_ptr_array_unref(tokens);
- delete results;
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+
+ if (res) {
+ return reinterpret_cast<redis_stat_runtime<T> *>(res);
+ }
+ else {
+ return std::nullopt;
+ }
+ }
+
+ void set_results(std::vector<std::pair<int, T>> *results)
+ {
+ this->results = results;
}
/* Propagate results from internal representation to the tokens array */
tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx);
tok->values[id] = val;
}
+
+ return true;
+ }
+
+ auto save_in_mempool(bool is_spam) const
+ {
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ /* We do not set destructor for the variable, as it should be already added on creation */
+ rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
}
};
#endif
-/* Called on connection termination */
-static void
-rspamd_redis_fin(gpointer data)
-{
- auto *rt = REDIS_RUNTIME(data);
-
- delete rt;
-}
-
-
static bool
rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
const ucl_object_t *statfile_obj,
return nullptr;
}
- auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
-
/* Look for the cached results */
if (!learn) {
- auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
- auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+ auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ object_expanded, stcf->is_spam);
- if (res) {
- rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+ if (maybe_existing) {
+ /* Update stcf to correspond to what we have been asked */
+ maybe_existing.value()->stcf = stcf;
+ return maybe_existing.value();
+ }
+ }
+
+ /* No cached result, create new one */
+ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+
+ if (!learn) {
+ /*
+ * For check, we also need to create the opposite class runtime to avoid
+ * double call for Redis scripts.
+ * This runtime will be filled later.
+ */
+ auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ object_expanded,
+ !stcf->is_spam);
+
+ if (!maybe_opposite_rt) {
+ auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ opposite_rt->save_in_mempool(!stcf->is_spam);
+ opposite_rt->need_redis_call = false;
}
}
+ rt->save_in_mempool(stcf->is_spam);
+
return rt;
}
bool result = lua_toboolean(L, 2);
if (result) {
+ /* Indexes:
+ * 3 - learned_ham (int)
+ * 4 - learned_spam (int)
+ * 5 - ham_tokens (pair<int, int>)
+ * 6 - spam_tokens (pair<int, int>)
+ */
+
+ /*
+ * We need to fill our runtime AND the opposite runtime
+ */
+ auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
+ rt->learned = learned;
+ redis_stat_runtime<float>::result_type *res;
+
+ res = new redis_stat_runtime<float>::result_type(lua_objlen(L, tokens_pos));
+
+ for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
+ lua_rawgeti(L, -1, 1);
+ auto idx = lua_tointeger(L, -1);
+ lua_pop(L, 1);
+
+ lua_rawgeti(L, -1, 2);
+ auto value = lua_tonumber(L, -1);
+ lua_pop(L, 1);
+
+ res->emplace_back(idx, value);
+ }
+
+ rt->set_results(res);
+ };
+
+ auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ !rt->stcf->is_spam);
+
+ if (!opposite_rt_maybe) {
+ msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+
+ return 0;
+ }
+
+ if (rt->stcf->is_spam) {
+ filler_func(rt, L, lua_tointeger(L, 4), 6);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ }
+ else {
+ filler_func(rt, L, lua_tointeger(L, 3), 5);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+ }
+
+ /* Process all tokens */
+ g_assert(rt->tokens != nullptr);
+ rt->process_tokens(rt->tokens);
+ opposite_rt_maybe.value()->process_tokens(rt->tokens);
}
else {
+ /* Error message is on index 3 */
+ msg_err_task("cannot classify task: %s",
+ lua_tostring(L, 3));
}
return 0;
return FALSE;
}
- if (rt->results) {
- /* No need to do anything, we have results ready */
- rt->process_tokens(tokens);
+ if (!rt->need_redis_call) {
+ /* No need to do anything, as it is already done in the opposite class processing */
return TRUE;
}
lua_pushstring(L, cookie);
lua_pushcclosure(L, &rspamd_redis_classified, 1);
-
if (lua_pcall(L, 6, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);