source.dussan.org Git - rspamd.git/commitdiff
[Project] Add some basic learning
author Vsevolod Stakhov <vsevolod@rspamd.com>
Thu, 7 Dec 2023 15:40:08 +0000 (15:40 +0000)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Thu, 7 Dec 2023 15:40:08 +0000 (15:40 +0000)
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_learn.lua
src/libstat/backends/redis_backend.cxx

index 2286295d5f049e3422b3dc426e2b9efec9737e70..ee4ec80b6b7e077bb172b588caab2f45cb1b6e36 100644 (file)
@@ -54,7 +54,7 @@ local function gen_learn_functor(redis_params, learn_script_id)
 
     lua_redis.exec_redis_script(learn_script_id,
         { task = task, is_write = false, key = expanded_key },
-        learn_redis_cb, { expanded_key, is_spam, symbol, is_unlearn, stat_tokens })
+        learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
   end
 end
 
index 2b74fcca97fd7762bcee558db6f464329c2be207..6382547067b0561ce0988548b976d5070a97fd9d 100644 (file)
@@ -7,9 +7,9 @@
 -- key5 - set of tokens encoded in messagepack array of int64_t
 
 local prefix = KEYS[1]
-local is_spam = KEYS[2]
+local is_spam = KEYS[2] == 'true' and true or false
 local symbol = KEYS[3]
-local is_unlearn = KEYS[4]
+local is_unlearn = KEYS[4] == 'true' and true or false
 local input_tokens = cmsgpack.unpack(KEYS[5])
 
 local prefix_underscore = prefix .. '_'
index f3104967bbab6df792cdecfa32dd8e57f7eaa914..342fa02739227f5ddedc47331c79800140a3b713 100644 (file)
@@ -497,6 +497,32 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
                backend->max_users = REDIS_MAX_USERS;
        }
 
+       return true;
+}
+
+gpointer
+rspamd_redis_init(struct rspamd_stat_ctx *ctx,
+                                 struct rspamd_config *cfg, struct rspamd_statfile *st)
+{
+       gint conf_ref = -1;
+       auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+
+       auto backend = std::make_unique<struct redis_stat_ctx>();
+       backend->L = L;
+       backend->max_users = REDIS_MAX_USERS;
+
+       backend->conf_ref = conf_ref;
+
+       lua_settop(L, 0);
+
+       if (!rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg)) {
+               msg_err_config("cannot init redis backend for %s", st->stcf->symbol);
+               return nullptr;
+       }
+
+       st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+       backend->stcf = st->stcf;
+
        lua_pushcfunction(L, &rspamd_lua_traceback);
        auto err_idx = lua_gettop(L);
 
@@ -505,19 +531,19 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
                msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_classifier");
                lua_settop(L, err_idx - 1);
 
-               return false;
+               return nullptr;
        }
 
        /* Push arguments */
-       ucl_object_push_lua(L, classifier_obj, false);
-       ucl_object_push_lua(L, statfile_obj, false);
+       ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+       ucl_object_push_lua(L, st->stcf->opts, false);
        lua_pushstring(L, backend->stcf->symbol);
 
        /* Store backend in random cookie */
        char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
        rspamd_random_hex(cookie, 16);
        cookie[15] = '\0';
-       rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend, nullptr);
+       rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
        /* Callback */
        lua_pushstring(L, cookie);
        lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
@@ -528,7 +554,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
                                lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);
 
-               return false;
+               return nullptr;
        }
 
        /* Results are in the stack:
@@ -544,43 +570,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
 
        lua_settop(L, err_idx - 1);
 
-       return true;
-}
-
-gpointer
-rspamd_redis_init(struct rspamd_stat_ctx *ctx,
-                                 struct rspamd_config *cfg, struct rspamd_statfile *st)
-{
-       gint conf_ref = -1;
-       auto *L = (lua_State *) cfg->lua_state;
-
-       auto *backend = g_new0(struct redis_stat_ctx, 1);
-       backend->L = L;
-       backend->max_users = REDIS_MAX_USERS;
-
-       backend->conf_ref = conf_ref;
-
-       lua_settop(L, 0);
-
-       if (!rspamd_redis_parse_classifier_opts(backend, st->stcf->opts, st->classifier->cfg->opts, cfg)) {
-               msg_err_config("cannot init redis backend for %s", st->stcf->symbol);
-               g_free(backend);
-               return nullptr;
-       }
-
-       st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
-       backend->stcf = st->stcf;
-
-#if 0
-       backend->stat_elt = rspamd_stat_ctx_register_async(
-               rspamd_redis_async_stat_cb,
-               rspamd_redis_async_stat_fin,
-               st_elt,
-               REDIS_STAT_TIMEOUT);
-       st_elt->async = backend->stat_elt;
-#endif
-
-       return (gpointer) backend;
+       return backend.release();
 }
 
 gpointer
@@ -615,7 +605,7 @@ rspamd_redis_runtime(struct rspamd_task *task,
                }
        }
 
-       /* No cached result, create new one */
+       /* No cached result (or learn), create new one */
        auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
 
        if (!learn) {
@@ -645,6 +635,7 @@ void rspamd_redis_close(gpointer p)
        struct redis_stat_ctx *ctx = REDIS_CTX(p);
        lua_State *L = ctx->L;
 
+       /* TODO: move to dtor */
        if (ctx->conf_ref) {
                luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref);
        }
@@ -657,7 +648,7 @@ void rspamd_redis_close(gpointer p)
                luaL_unref(L, LUA_REGISTRYINDEX, ctx->cbref_classify);
        }
 
-       g_free(ctx);
+       delete ctx;
 }
 
 /*
@@ -702,7 +693,6 @@ rspamd_redis_classified(lua_State *L)
        const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
        auto *task = lua_check_task(L, 1);
        auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
-       /* TODO: write it */
 
        if (rt == nullptr) {
                msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
@@ -843,15 +833,96 @@ rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
        return TRUE;
 }
 
+
+static gint
+rspamd_redis_learned(lua_State *L)
+{
+       const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+       auto *task = lua_check_task(L, 1);
+       auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+
+       if (rt == nullptr) {
+               msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
+
+               return 0;
+       }
+
+       bool result = lua_toboolean(L, 2);
+
+       if (result) {
+               /* TODO: write it */
+       }
+       else {
+               /* Error message is on index 3 */
+               msg_err_task("cannot learn task: %s",
+                                        lua_tostring(L, 3));
+       }
+
+       return 0;
+}
+
 gboolean
-rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
+rspamd_redis_learn_tokens(struct rspamd_task *task,
+                                                 GPtrArray *tokens,
                                                  gint id, gpointer p)
 {
        auto *rt = REDIS_RUNTIME(p);
+       auto *L = rt->ctx->L;
+
+       if (rspamd_session_blocked(task->s)) {
+               return FALSE;
+       }
+
+       if (tokens == nullptr || tokens->len == 0) {
+               return FALSE;
+       }
+
+       gsize tokens_len;
+       gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len);
+
+       rt->id = id;
+
+       lua_pushcfunction(L, &rspamd_lua_traceback);
+       gint err_idx = lua_gettop(L);
+
+       /* Function arguments */
+       lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
+       rspamd_lua_task_push(L, task);
+       lua_pushstring(L, rt->redis_object_expanded);
+       lua_pushinteger(L, id);
+       lua_pushboolean(L, rt->stcf->is_spam);
+       lua_pushstring(L, rt->stcf->symbol);
+
+       /* Detect unlearn */
+       auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0);
+
+       if (tok->values[id] > 0) {
+               lua_pushboolean(L, FALSE);// Learn
+       }
+       else {
+               lua_pushboolean(L, TRUE);// Unlearn
+       }
+       lua_new_text(L, tokens_buf, tokens_len, false);
+
+       /* Store rt in random cookie */
+       char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
+       rspamd_random_hex(cookie, 16);
+       cookie[15] = '\0';
+       rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
+       /* Callback */
+       lua_pushstring(L, cookie);
+       lua_pushcclosure(L, &rspamd_redis_learned, 1);
 
-       /* TODO: write learn function */
+       if (lua_pcall(L, 8, 0, err_idx) != 0) {
+               msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+               lua_settop(L, err_idx - 1);
+               return FALSE;
+       }
 
-       return FALSE;
+       rt->tokens = g_ptr_array_ref(tokens);
+
+       lua_settop(L, err_idx - 1);
+       return TRUE;
 }