aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- lualib/lua_bayes_redis.lua 54
-rw-r--r-- lualib/redis_scripts/bayes_stat.lua 19
-rw-r--r-- src/libstat/backends/redis_backend.cxx 43
-rw-r--r-- src/libstat/classifiers/classifiers.h 20
4 files changed, 115 insertions, 21 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index ee4ec80b6..e085694a9 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua
@@ -63,7 +63,7 @@ end
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_classifier = function(classifier_ucl, statfile_ucl, symbol, stat_periodic_cb)
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
local redis_params
if classifier_ucl.backend then
@@ -92,20 +92,50 @@ exports.lua_bayes_init_classifier = function(classifier_ucl, statfile_ucl, symbo
local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params)
local max_users = classifier_ucl.max_users or 1000
- rspamd_config:add_on_load(function(_, ev_base, _)
-
- rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
-
- local function stat_redis_cb(err, data)
- -- TODO: write this function
+ local current_data = {
+ users = 0,
+ revision = 0,
+ }
+ local final_data = {
+ users = 0,
+ revision = 0, -- number of learns
+ }
+ local cursor = 0
+ rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
+
+ local function stat_redis_cb(err, data)
+   -- Redis reply handler: data = { new_cursor, nkeys, learns } for one SSCAN batch
+   lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
+   if err then
+     logger.warn(cfg, 'cannot get bayes statistics for %s: %s', symbol, err)
+   else
+     local new_cursor = data[1]
+     -- Always account for this batch: SSCAN may return elements together
+     -- with a zero cursor, so the final batch must not be dropped
+     current_data.users = current_data.users + data[2]
+     current_data.revision = current_data.revision + data[3]
+     if new_cursor == 0 then
+       -- Done iteration: publish the totals and start a new accumulation
+       final_data = current_data
+       current_data = {
+         users = 0,
+         revision = 0,
+       }
+       stat_periodic_cb(cfg, final_data)
+     end
+
+     cursor = new_cursor
end
+ end
- lua_redis.exec_redis_script(stat_script_id,
- { ev_base = ev_base, cfg = cfg, is_write = false },
- stat_redis_cb, { symbol, max_users })
- return 30.0 -- TODO: make configurable
- end)
+ lua_redis.exec_redis_script(stat_script_id,
+ { ev_base = ev_base, cfg = cfg, is_write = false },
+ stat_redis_cb, { tostring(cursor),
+ symbol,
+ is_spam and "learns_spam" or "learns_ham",
+ tostring(max_users) })
+ return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
end)
return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
diff --git a/lualib/redis_scripts/bayes_stat.lua b/lualib/redis_scripts/bayes_stat.lua
index e69de29bb..31e51280c 100644
--- a/lualib/redis_scripts/bayes_stat.lua
+++ b/lualib/redis_scripts/bayes_stat.lua
@@ -0,0 +1,19 @@
+-- Collects bayes statistics for a single SSCAN batch of statfile keys
+-- All parameters are passed in the KEYS array:
+-- KEYS[1] - cursor to continue the scan from
+-- KEYS[2] - symbol to examine; the set `<symbol>_keys` is scanned
+-- KEYS[3] - hash field with the learns count (e.g. learns_ham or learns_spam)
+-- KEYS[4] - max users, used as the COUNT argument of SSCAN
+
+local scan_reply = redis.call('SSCAN', KEYS[2] .. '_keys', tonumber(KEYS[1]),
+    'COUNT', tonumber(KEYS[4]))
+local members = scan_reply[2]
+local total_learns = 0
+
+for i = 1, #members do
+  -- A missing or non-numeric field counts as zero learns
+  local field_value = redis.call('HGET', members[i], KEYS[3])
+  total_learns = total_learns + (tonumber(field_value) or 0)
+end
+
+return { tonumber(scan_reply[1]), #members, total_learns } \ No newline at end of file
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index 5d6a438d8..30dd13107 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -53,6 +53,8 @@ struct redis_stat_ctx {
int cbref_classify = -1;
int cbref_learn = -1;
+ ucl_object_t *cur_stat = nullptr;
+
explicit redis_stat_ctx(lua_State *_L)
: L(_L)
{
@@ -404,6 +406,29 @@ rspamd_redis_stat_cb(lua_State *L)
return 0;
}
+ auto *cur_obj = ucl_object_lua_import(L, 2);
+ msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol);
+ /* Enrich with some default values that are meaningless for redis */
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "used", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "total", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "size", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_fromstring(backend->stcf->symbol),
+ "symbol", 0, false);
+ ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"),
+ "type", 0, false);
+ ucl_object_insert_key(cur_obj, ucl_object_fromint(0),
+ "languages", 0, false);
+
+ if (backend->cur_stat) {
+ ucl_object_unref(backend->cur_stat);
+ }
+
+ backend->cur_stat = cur_obj;
+
return 0;
}
@@ -502,8 +527,8 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx,
auto err_idx = lua_gettop(L);
/* Obtain function */
- if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_classifier")) {
- msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_classifier");
+ if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) {
+ msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile");
lua_settop(L, err_idx - 1);
return nullptr;
@@ -513,17 +538,21 @@ rspamd_redis_init(struct rspamd_stat_ctx *ctx,
ucl_object_push_lua(L, st->classifier->cfg->opts, false);
ucl_object_push_lua(L, st->stcf->opts, false);
lua_pushstring(L, backend->stcf->symbol);
+ lua_pushboolean(L, backend->stcf->is_spam);
+ auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *));
+ *pev_base = ctx->event_loop;
+ rspamd_lua_setclass(L, "rspamd{ev_base}", -1);
/* Store backend in random cookie */
char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
rspamd_random_hex(cookie, 16);
cookie[15] = '\0';
rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
- /* Callback */
+ /* Callback + 1 upvalue */
lua_pushstring(L, cookie);
lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
- if (lua_pcall(L, 4, 2, err_idx) != 0) {
+ if (lua_pcall(L, 6, 2, err_idx) != 0) {
msg_err("call to lua_bayes_init_classifier "
"script failed: %s",
lua_tostring(L, -1));
@@ -942,12 +971,8 @@ rspamd_redis_get_stat(gpointer runtime,
gpointer ctx)
{
auto *rt = REDIS_RUNTIME(runtime);
- struct rspamd_redis_stat_elt *st;
- redisAsyncContext *redis;
- /* TODO: write extraction */
-
- return nullptr;
+ return ucl_object_ref(rt->ctx->cur_stat);
}
gpointer
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index f6109c3e5..949408c6b 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#ifndef CLASSIFIERS_H
#define CLASSIFIERS_H
@@ -80,6 +96,10 @@ extern gint rspamd_bayes_log_id;
rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \
G_STRFUNC, \
__VA_ARGS__)
+#define msg_debug_bayes_cfg(...) rspamd_conditional_debug_fast(NULL, NULL, \
+ rspamd_bayes_log_id, "bayes", cfg->cfg_pool->tag.uid, \
+ G_STRFUNC, \
+ __VA_ARGS__) /* reads cfg->cfg_pool, so a `cfg` pointer must be in scope where the macro expands */
#ifdef __cplusplus