From: Vsevolod Stakhov Date: Fri, 12 Jan 2024 15:41:11 +0000 (+0000) Subject: [Project] Initial implementation of the lua counterpart X-Git-Tag: 3.8.0~8^2~6 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=7541d281d3376afd1216426d843981b4014569cc;p=rspamd.git [Project] Initial implementation of the lua counterpart --- diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 6f6da339e..3321c96c3 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -54,12 +54,12 @@ local function gen_learn_functor(redis_params, learn_script_id) if maybe_text_tokens then lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = false, key = expanded_key }, + { task = task, is_write = true, key = expanded_key }, learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) else lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = false, key = expanded_key }, + { task = task, is_write = true, key = expanded_key }, learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) end @@ -136,7 +136,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) local function stat_redis_cb(err, data) - -- TODO: write this function lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) if err then @@ -172,12 +171,54 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id) end +local function gen_cache_check_functor(redis_params, check_script_id) + return function(task, cache_id, callback) + + local function classify_redis_cb(err, data) + lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data) + if err then + callback(task, false, err) + else + callback(task, true, data[1], data[2], data[3], data[4]) + end + end + + lua_redis.exec_redis_script(check_script_id, + { task = task, is_write = false, key = cache_id }, + classify_redis_cb, { cache_id }) + end +end + +local function gen_cache_learn_functor(redis_params, learn_script_id) + return function(task, cache_id, callback) + local function learn_redis_cb(err, data) + lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) + if err then + callback(task, false, err) + else + callback(task, true) + end + end + + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = cache_id }, + learn_redis_cb, + { cache_id }) + + end +end + exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl) local redis_params = load_redis_params(classifier_ucl, statfile_ucl) if not redis_params then return nil end + + local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params) + local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params) + + return gen_cache_check_functor(redis_params, check_script_id), gen_cache_learn_functor(redis_params, learn_script_id) end return exports diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua new file mode 100644 index 000000000..e69de29bb diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua new file mode 100644 index 000000000..e69de29bb diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx index 8aba739d1..b774e626e 100644 --- a/src/libstat/learn_cache/redis_cache.cxx +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -14,15 +14,14 @@ * limitations under the License. */ #include "config.h" +// Include early to avoid `extern "C"` issues +#include "lua/lua_common.h" #include "learn_cache.h" #include "rspamd.h" #include "stat_api.h" #include "stat_internal.h" #include "cryptobox.h" #include "ucl.h" -#include "hiredis.h" -#include "adapters/libev.h" -#include "lua/lua_common.h" #include "libmime/message.h" #define DEFAULT_REDIS_KEY "learned_ids" @@ -153,13 +152,45 @@ rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx, { std::unique_ptr cache_ctx = std::make_unique(RSPAMD_LUA_CFG_STATE(cfg)); - const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key"); + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + lua_settop(L, 0); + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache"); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Push arguments */ + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); + + if (lua_pcall(L, 2, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_cache " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); - if (obj) { - cache_ctx->redis_object = ucl_object_tostring(obj); + return nullptr; } - cache_ctx->stcf = st->stcf; + /* + * Results are in the stack: + * top - 1 - check function (idx = -2) + * top - learn function (idx = -1) + */ + lua_pushvalue(L, -2); + cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); return (gpointer) cache_ctx.release(); }