if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = false, key = expanded_key },
+ { task = task, is_write = true, key = expanded_key },
learn_redis_cb,
{ expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
- { task = task, is_write = false, key = expanded_key },
+ { task = task, is_write = true, key = expanded_key },
learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
end
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
local function stat_redis_cb(err, data)
- -- TODO: write this function
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
if err then
return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
end
+local function gen_cache_check_functor(redis_params, check_script_id)
+ return function(task, cache_id, callback)
+
+ local function classify_redis_cb(err, data)
+ lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
+ if err then
+ callback(task, false, err)
+ else
+ callback(task, true, data[1], data[2], data[3], data[4])
+ end
+ end
+
+ lua_redis.exec_redis_script(check_script_id,
+ { task = task, is_write = false, key = cache_id },
+ classify_redis_cb, { cache_id })
+ end
+end
+
+local function gen_cache_learn_functor(redis_params, learn_script_id)
+ return function(task, cache_id, callback)
+ local function learn_redis_cb(err, data)
+ lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
+ if err then
+ callback(task, false, err)
+ else
+ callback(task, true)
+ end
+ end
+
+ lua_redis.exec_redis_script(learn_script_id,
+ { task = task, is_write = true, key = cache_id },
+ learn_redis_cb,
+ { cache_id })
+
+ end
+end
+
exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
if not redis_params then
return nil
end
+
+ local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
+ local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
+
+ return gen_cache_check_functor(redis_params, check_script_id), gen_cache_learn_functor(redis_params, learn_script_id)
end
return exports
* limitations under the License.
*/
#include "config.h"
+// Include early to avoid `extern "C"` issues
+#include "lua/lua_common.h"
#include "learn_cache.h"
#include "rspamd.h"
#include "stat_api.h"
#include "stat_internal.h"
#include "cryptobox.h"
#include "ucl.h"
-#include "hiredis.h"
-#include "adapters/libev.h"
-#include "lua/lua_common.h"
#include "libmime/message.h"
#define DEFAULT_REDIS_KEY "learned_ids"
{
std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg));
- const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key");
+ auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+ lua_settop(L, 0);
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ auto err_idx = lua_gettop(L);
+
+ /* Obtain function */
+ if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
+ msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
+ lua_settop(L, err_idx - 1);
+
+ return nullptr;
+ }
+
+ /* Push arguments */
+ ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+ ucl_object_push_lua(L, st->stcf->opts, false);
+
+ if (lua_pcall(L, 2, 2, err_idx) != 0) {
+ msg_err("call to lua_bayes_init_cache "
+ "script failed: %s",
+ lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
- if (obj) {
- cache_ctx->redis_object = ucl_object_tostring(obj);
+ return nullptr;
}
- cache_ctx->stcf = st->stcf;
+ /*
+ * Results are in the stack:
+ * top - 1 - check function (idx = -2)
+ * top - learn function (idx = -1)
+ */
+ lua_pushvalue(L, -2);
+ cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_pushvalue(L, -1);
+ cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_settop(L, err_idx - 1);
return (gpointer) cache_ctx.release();
}