]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Further caching logic modifications
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 6 Dec 2023 15:36:52 +0000 (15:36 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 6 Dec 2023 15:36:52 +0000 (15:36 +0000)
lualib/lua_bayes_redis.lua
src/libstat/backends/redis_backend.cxx

index 5dca2db430581004d5057c3f9773333b0b5b23b2..25c56d58b52eb6cf7033316d117fc3c2ca581ee1 100644 (file)
@@ -31,7 +31,7 @@ local function gen_classify_functor(redis_params, classify_script_id)
       if err then
         callback(task, false, err)
       else
-        callback(task, true, data[1], data[2], data[3])
+        callback(task, true, data[1], data[2], data[3], data[4])
       end
     end
 
index 46b27cb155e5acb3fcc625849f519312aba8987b..46f88de19ebfeda6e34b6bb96e4923a4aec92177 100644 (file)
@@ -63,32 +63,58 @@ struct redis_stat_runtime {
        struct redis_stat_ctx *ctx;
        struct rspamd_task *task;
        struct rspamd_statfile_config *stcf;
-       GPtrArray *tokens;
+       GPtrArray *tokens = nullptr;
        const char *redis_object_expanded;
        std::uint64_t learned = 0;
        int id;
        std::vector<std::pair<int, T>> *results = nullptr;
+       bool need_redis_call = true;
 
        using result_type = std::vector<std::pair<int, T>>;
 
-       explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
-               : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+private:
+       /* Called on connection termination */
+       static void rt_dtor(gpointer data)
        {
+               auto *rt = REDIS_RUNTIME(data);
+
+               delete rt;
        }
 
-       void init()
+       /* Avoid occasional deletion */
+       ~redis_stat_runtime()
        {
+               if (tokens) {
+                       g_ptr_array_unref(tokens);
+               }
+
+               delete results;
        }
 
-       void set_results(std::vector<std::pair<int, T>> *_results)
+public:
+       explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+               : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
        {
-               results = _results;
+               rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
        }
 
-       ~redis_stat_runtime()
+       static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
+                                                                                  bool is_spam) -> std::optional<redis_stat_runtime<T> *>
        {
-               g_ptr_array_unref(tokens);
-               delete results;
+               auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+               auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+
+               if (res) {
+                       return reinterpret_cast<redis_stat_runtime<T> *>(res);
+               }
+               else {
+                       return std::nullopt;
+               }
+       }
+
+       void set_results(std::vector<std::pair<int, T>> *results)
+       {
+               this->results = results;
        }
 
        /* Propagate results from internal representation to the tokens array */
@@ -104,6 +130,15 @@ struct redis_stat_runtime {
                        tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx);
                        tok->values[id] = val;
                }
+
+               return true;
+       }
+
+       auto save_in_mempool(bool is_spam) const
+       {
+               auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+               /* We do not set destructor for the variable, as it should be already added on creation */
+               rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
        }
 };
 
@@ -1095,16 +1130,6 @@ rspamd_redis_async_stat_fin(struct rspamd_stat_async_elt *elt, gpointer d)
 #endif
 
 
-/* Called on connection termination */
-static void
-rspamd_redis_fin(gpointer data)
-{
-       auto *rt = REDIS_RUNTIME(data);
-
-       delete rt;
-}
-
-
 static bool
 rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
                                                                   const ucl_object_t *statfile_obj,
@@ -1296,19 +1321,40 @@ rspamd_redis_runtime(struct rspamd_task *task,
                return nullptr;
        }
 
-       auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-       rspamd_mempool_add_destructor(task->task_pool, rspamd_redis_fin, rt);
-
        /* Look for the cached results */
        if (!learn) {
-               auto var_name = fmt::format("{}_{}", object_expanded, stcf->is_spam ? "S" : "H");
-               auto *res = rspamd_mempool_steal_variable(task->task_pool, var_name.c_str());
+               auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                       object_expanded, stcf->is_spam);
 
-               if (res) {
-                       rt->set_results(reinterpret_cast<redis_stat_runtime<float>::result_type *>(res));
+               if (maybe_existing) {
+                       /* Update stcf to correspond to what we have been asked */
+                       maybe_existing.value()->stcf = stcf;
+                       return maybe_existing.value();
+               }
+       }
+
+       /* No cached result, create new one */
+       auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+
+       if (!learn) {
+               /*
+                * For check, we also need to create the opposite class runtime to avoid
+                * double call for Redis scripts.
+                * This runtime will be filled later.
+                */
+               auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                          object_expanded,
+                                                                                                                                                                          !stcf->is_spam);
+
+               if (!maybe_opposite_rt) {
+                       auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+                       opposite_rt->save_in_mempool(!stcf->is_spam);
+                       opposite_rt->need_redis_call = false;
                }
        }
 
+       rt->save_in_mempool(stcf->is_spam);
+
        return rt;
 }
 
@@ -1385,8 +1431,65 @@ rspamd_redis_classified(lua_State *L)
        bool result = lua_toboolean(L, 2);
 
        if (result) {
+               /* Indexes:
+                * 3 - learned_ham (int)
+                * 4 - learned_spam (int)
+                * 5 - ham_tokens (pair<int, int>)
+                * 6 - spam_tokens (pair<int, int>)
+                */
+
+               /*
+                * We need to fill our runtime AND the opposite runtime
+                */
+               auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
+                       rt->learned = learned;
+                       redis_stat_runtime<float>::result_type *res;
+
+                       res = new redis_stat_runtime<float>::result_type(lua_objlen(L, tokens_pos));
+
+                       for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
+                               lua_rawgeti(L, -1, 1);
+                               auto idx = lua_tointeger(L, -1);
+                               lua_pop(L, 1);
+
+                               lua_rawgeti(L, -1, 2);
+                               auto value = lua_tonumber(L, -1);
+                               lua_pop(L, 1);
+
+                               res->emplace_back(idx, value);
+                       }
+
+                       rt->set_results(res);
+               };
+
+               auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+                                                                                                                                                                          rt->redis_object_expanded,
+                                                                                                                                                                          !rt->stcf->is_spam);
+
+               if (!opposite_rt_maybe) {
+                       msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+
+                       return 0;
+               }
+
+               if (rt->stcf->is_spam) {
+                       filler_func(rt, L, lua_tointeger(L, 4), 6);
+                       filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+               }
+               else {
+                       filler_func(rt, L, lua_tointeger(L, 3), 5);
+                       filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+               }
+
+               /* Process all tokens */
+               g_assert(rt->tokens != nullptr);
+               rt->process_tokens(rt->tokens);
+               opposite_rt_maybe.value()->process_tokens(rt->tokens);
        }
        else {
+               /* Error message is on index 3 */
+               msg_err_task("cannot classify task: %s",
+                                        lua_tostring(L, 3));
        }
 
        return 0;
@@ -1408,9 +1511,8 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
                return FALSE;
        }
 
-       if (rt->results) {
-               /* No need to do anything, we have results ready */
-               rt->process_tokens(tokens);
+       if (!rt->need_redis_call) {
+               /* No need to do anything, as it is already done in the opposite class processing */
 
                return TRUE;
        }
@@ -1440,7 +1542,6 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
        lua_pushstring(L, cookie);
        lua_pushcclosure(L, &rspamd_redis_classified, 1);
 
-
        if (lua_pcall(L, 6, 0, err_idx) != 0) {
                msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);