]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Initial implementation of the lua counterpart
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Jan 2024 15:41:11 +0000 (15:41 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Jan 2024 15:41:11 +0000 (15:41 +0000)
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_cache_check.lua [new file with mode: 0644]
lualib/redis_scripts/bayes_cache_learn.lua [new file with mode: 0644]
src/libstat/learn_cache/redis_cache.cxx

index 6f6da339e12a423cbe86c2aaf80d90d3eee7ba8a..3321c96c3c7ac9637766c34a2a09cb5c367a26f6 100644 (file)
@@ -54,12 +54,12 @@ local function gen_learn_functor(redis_params, learn_script_id)
 
     if maybe_text_tokens then
       lua_redis.exec_redis_script(learn_script_id,
-          { task = task, is_write = false, key = expanded_key },
+          { task = task, is_write = true, key = expanded_key },
           learn_redis_cb,
           { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
     else
       lua_redis.exec_redis_script(learn_script_id,
-          { task = task, is_write = false, key = expanded_key },
+          { task = task, is_write = true, key = expanded_key },
           learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
     end
 
@@ -136,7 +136,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
   rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
 
     local function stat_redis_cb(err, data)
-      -- TODO: write this function
       lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
 
       if err then
@@ -172,12 +171,54 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
   return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
 end
 
+local function gen_cache_check_functor(redis_params, check_script_id)
+  return function(task, cache_id, callback)
+
+    local function classify_redis_cb(err, data)
+      lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
+      if err then
+        callback(task, false, err)
+      else
+        callback(task, true, data[1], data[2], data[3], data[4])
+      end
+    end
+
+    lua_redis.exec_redis_script(check_script_id,
+        { task = task, is_write = false, key = cache_id },
+        classify_redis_cb, { cache_id })
+  end
+end
+
+local function gen_cache_learn_functor(redis_params, learn_script_id)
+  return function(task, cache_id, callback)
+    local function learn_redis_cb(err, data)
+      lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
+      if err then
+        callback(task, false, err)
+      else
+        callback(task, true)
+      end
+    end
+
+    lua_redis.exec_redis_script(learn_script_id,
+        { task = task, is_write = true, key = cache_id },
+        learn_redis_cb,
+        { cache_id })
+
+  end
+end
+
 exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
   local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
 
   if not redis_params then
     return nil
   end
+
+  local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
+  local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
+
+  return gen_cache_check_functor(redis_params, check_script_id), gen_cache_learn_functor(redis_params, learn_script_id)
 end
 
 return exports
diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
new file mode 100644 (file)
index 0000000..e69de29
index 8aba739d1dcfd68fb1ffee35b85adce7c1fa3e57..b774e626ed99ecabce245cea8dbef7e2cd2e1bbe 100644 (file)
  * limitations under the License.
  */
 #include "config.h"
+// Include early to avoid `extern "C"` issues
+#include "lua/lua_common.h"
 #include "learn_cache.h"
 #include "rspamd.h"
 #include "stat_api.h"
 #include "stat_internal.h"
 #include "cryptobox.h"
 #include "ucl.h"
-#include "hiredis.h"
-#include "adapters/libev.h"
-#include "lua/lua_common.h"
 #include "libmime/message.h"
 
 #define DEFAULT_REDIS_KEY "learned_ids"
@@ -153,13 +152,45 @@ rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
 {
        std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg));
 
-       const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key");
+       auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+       lua_settop(L, 0);
+
+       lua_pushcfunction(L, &rspamd_lua_traceback);
+       auto err_idx = lua_gettop(L);
+
+       /* Obtain function */
+       if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
+               msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
+               lua_settop(L, err_idx - 1);
+
+               return nullptr;
+       }
+
+       /* Push arguments */
+       ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+       ucl_object_push_lua(L, st->stcf->opts, false);
+
+       if (lua_pcall(L, 2, 2, err_idx) != 0) {
+               msg_err("call to lua_bayes_init_cache "
+                               "script failed: %s",
+                               lua_tostring(L, -1));
+               lua_settop(L, err_idx - 1);
 
-       if (obj) {
-               cache_ctx->redis_object = ucl_object_tostring(obj);
+               return nullptr;
        }
 
-       cache_ctx->stcf = st->stcf;
+       /*
+        * Results are in the stack:
+        * top - 1 - check function (idx = -2)
+        * top - learn function (idx = -1)
+        */
+       lua_pushvalue(L, -2);
+       cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+       lua_pushvalue(L, -1);
+       cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+       lua_settop(L, err_idx - 1);
 
        return (gpointer) cache_ctx.release();
 }