diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-12-07 15:40:08 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-12-07 15:40:08 +0000 |
commit | 752414a1f585b5ad3b03d812eb2aa995e899e9b3 (patch) | |
tree | 70faa7e3b483118473fa83c8f10ba59795a09d40 | |
parent | 51e123f58ff32e362ba35f077d37e3a7d04de5a5 (diff) | |
download | rspamd-752414a1f585b5ad3b03d812eb2aa995e899e9b3.tar.gz rspamd-752414a1f585b5ad3b03d812eb2aa995e899e9b3.zip |
[Project] Add some basic learning
-rw-r--r-- | lualib/lua_bayes_redis.lua | 2 | ||||
-rw-r--r-- | lualib/redis_scripts/bayes_learn.lua | 4 | ||||
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 167 |
3 files changed, 122 insertions, 51 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 2286295d5..ee4ec80b6 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -54,7 +54,7 @@ local function gen_learn_functor(redis_params, learn_script_id) lua_redis.exec_redis_script(learn_script_id, { task = task, is_write = false, key = expanded_key }, - learn_redis_cb, { expanded_key, is_spam, symbol, is_unlearn, stat_tokens }) + learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) end end diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua index 2b74fcca9..638254706 100644 --- a/lualib/redis_scripts/bayes_learn.lua +++ b/lualib/redis_scripts/bayes_learn.lua @@ -7,9 +7,9 @@ -- key5 - set of tokens encoded in messagepack array of int64_t local prefix = KEYS[1] -local is_spam = KEYS[2] +local is_spam = KEYS[2] == 'true' and true or false local symbol = KEYS[3] -local is_unlearn = KEYS[4] +local is_unlearn = KEYS[4] == 'true' and true or false local input_tokens = cmsgpack.unpack(KEYS[5]) local prefix_underscore = prefix .. '_' diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index f3104967b..342fa0273 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -497,6 +497,32 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, backend->max_users = REDIS_MAX_USERS; } + return true; +} + +gpointer +rspamd_redis_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, struct rspamd_statfile *st) +{ + gint conf_ref = -1; + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + + auto backend = std::make_unique<struct redis_stat_ctx>(); + backend->L = L; + backend->max_users = REDIS_MAX_USERS; + + backend->conf_ref = conf_ref; + + lua_settop(L, 0); + + if (!rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg)) { + msg_err_config("cannot init redis backend for %s", st->stcf->symbol); + return nullptr; + } + + st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + backend->stcf = st->stcf; + lua_pushcfunction(L, &rspamd_lua_traceback); auto err_idx = lua_gettop(L); @@ -505,19 +531,19 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_classifier"); lua_settop(L, err_idx - 1); - return false; + return nullptr; } /* Push arguments */ - ucl_object_push_lua(L, classifier_obj, false); - ucl_object_push_lua(L, statfile_obj, false); + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); lua_pushstring(L, backend->stcf->symbol); /* Store backend in random cookie */ char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16); rspamd_random_hex(cookie, 16); cookie[15] = '\0'; - rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend, nullptr); + rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr); /* Callback */ lua_pushstring(L, cookie); lua_pushcclosure(L, &rspamd_redis_stat_cb, 1); @@ -528,7 +554,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, lua_tostring(L, -1)); lua_settop(L, err_idx - 1); - return false; + return nullptr; } /* Results are in the stack: @@ -544,43 +570,7 @@ rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, lua_settop(L, err_idx - 1); - return true; -} - -gpointer -rspamd_redis_init(struct rspamd_stat_ctx *ctx, - struct rspamd_config *cfg, struct rspamd_statfile *st) -{ - gint conf_ref = -1; - auto *L = (lua_State *) cfg->lua_state; - - auto *backend = g_new0(struct redis_stat_ctx, 1); - backend->L = L; - backend->max_users = REDIS_MAX_USERS; - - backend->conf_ref = conf_ref; - - lua_settop(L, 0); - - if (!rspamd_redis_parse_classifier_opts(backend, st->stcf->opts, st->classifier->cfg->opts, cfg)) { - msg_err_config("cannot init redis backend for %s", st->stcf->symbol); - g_free(backend); - return nullptr; - } - - st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; - backend->stcf = st->stcf; - -#if 0 - backend->stat_elt = rspamd_stat_ctx_register_async( - rspamd_redis_async_stat_cb, - rspamd_redis_async_stat_fin, - st_elt, - REDIS_STAT_TIMEOUT); - st_elt->async = backend->stat_elt; -#endif - - return (gpointer) backend; + return backend.release(); } gpointer @@ -615,7 +605,7 @@ rspamd_redis_runtime(struct rspamd_task *task, } } - /* No cached result, create new one */ + /* No cached result (or learn), create new one */ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); if (!learn) { @@ -645,6 +635,7 @@ void rspamd_redis_close(gpointer p) struct redis_stat_ctx *ctx = REDIS_CTX(p); lua_State *L = ctx->L; + /* TODO: move to dtor */ if (ctx->conf_ref) { luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref); } @@ -657,7 +648,7 @@ void rspamd_redis_close(gpointer p) luaL_unref(L, LUA_REGISTRYINDEX, ctx->cbref_classify); } - g_free(ctx); + delete ctx; } /* @@ -702,7 +693,6 @@ rspamd_redis_classified(lua_State *L) const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); auto *task = lua_check_task(L, 1); auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); - /* TODO: write it */ if (rt == nullptr) { msg_err_task("internal error: cannot find runtime for cookie %s", cookie); @@ -843,15 +833,96 @@ rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime, return TRUE; } + +static gint +rspamd_redis_learned(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + + if (rt == nullptr) { + msg_err_task("internal error: cannot find runtime for cookie %s", cookie); + + return 0; + } + + bool result = lua_toboolean(L, 2); + + if (result) { + /* TODO: write it */ + } + else { + /* Error message is on index 3 */ + msg_err_task("cannot learn task: %s", + lua_tostring(L, 3)); + } + + return 0; +} + gboolean -rspamd_redis_learn_tokens(struct rspamd_task *task, GPtrArray *tokens, +rspamd_redis_learn_tokens(struct rspamd_task *task, + GPtrArray *tokens, gint id, gpointer p) { auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; + + if (rspamd_session_blocked(task->s)) { + return FALSE; + } + + if (tokens == nullptr || tokens->len == 0) { + return FALSE; + } + + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, tokens, &tokens_len); + + rt->id = id; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn); + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, rt->stcf->symbol); + + /* Detect unlearn */ + auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0); + + if (tok->values[id] > 0) { + lua_pushboolean(L, FALSE);// Learn + } + else { + lua_pushboolean(L, TRUE);// Unlearn + } + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr); + /* Callback */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_learned, 1); - /* TODO: write learn function */ + if (lua_pcall(L, 8, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; + } - return FALSE; + rt->tokens = g_ptr_array_ref(tokens); + + lua_settop(L, err_idx - 1); + return TRUE; } |