瀏覽代碼

Merge pull request #4774 from rspamd/vstakhov-redis-cache-rework

Rewrite redis_cache logic in statistics
tags/3.8.0
Vsevolod Stakhov 4 月之前
父節點
當前提交
01d182c863
No account linked to committer's email address

+ 89
- 9
lualib/lua_bayes_redis.lua 查看文件

@@ -20,6 +20,7 @@ local exports = {}
local lua_redis = require "lua_redis"
local logger = require "rspamd_logger"
local lua_util = require "lua_util"
local ucl = require "ucl"

local N = "bayes"

@@ -54,24 +55,19 @@ local function gen_learn_functor(redis_params, learn_script_id)

if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = false, key = expanded_key },
{ task = task, is_write = true, key = expanded_key },
learn_redis_cb,
{ expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = false, key = expanded_key },
{ task = task, is_write = true, key = expanded_key },
learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
end

end
end

---
--- Init bayes classifier
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
local function load_redis_params(classifier_ucl, statfile_ucl)
local redis_params

-- Try load from statfile options
@@ -108,6 +104,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
return nil
end

return redis_params
end

---
--- Init bayes classifier
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)

local redis_params = load_redis_params(classifier_ucl, statfile_ucl)

if not redis_params then
return nil
end

local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params)
local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params)
local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params)
@@ -125,7 +137,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)

local function stat_redis_cb(err, data)
-- TODO: write this function
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)

if err then
@@ -161,4 +172,73 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
end

--- Builds the closure used to look a cache id up through a Redis script.
--- @param redis_params redis parameters (accepted for symmetry, not read here)
--- @param check_script_id script id handed to `lua_redis.exec_redis_script`
--- @param conf configuration table; serialised to msgpack once, at build time
--- @return function(task, cache_id, callback) where `callback` receives
---   (task, true, number) when the script replies with a number and
---   (task, false, reason) on an error or any other reply
local function gen_cache_check_functor(redis_params, check_script_id, conf)
  local serialized_conf = ucl.to_format(conf, 'msgpack')

  return function(task, cache_id, callback)
    local function on_check_reply(err, data)
      lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))

      if err then
        callback(task, false, err)
        return
      end

      -- Only a numeric reply counts as a cache hit
      if type(data) ~= 'number' then
        callback(task, false, 'not found')
        return
      end

      callback(task, true, data)
    end

    lua_util.debugm(N, task, 'checking cache: %s', cache_id)

    local redis_opts = { task = task, is_write = false, key = cache_id }
    local script_args = { cache_id, serialized_conf }
    lua_redis.exec_redis_script(check_script_id, redis_opts, on_check_reply, script_args)
  end
end

--- Builds the closure used to record a cache id through a Redis script.
--- @param redis_params redis parameters (accepted for symmetry, not read here)
--- @param learn_script_id script id handed to `lua_redis.exec_redis_script`
--- @param conf configuration table; serialised to msgpack once, at build time
--- @return function(task, cache_id, is_spam); the Redis reply is only debug-logged
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
  local serialized_conf = ucl.to_format(conf, 'msgpack')

  return function(task, cache_id, is_spam)
    local function on_learn_reply(err, data)
      lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
    end

    -- The script receives the class as the string "1" (spam) or "0" (ham)
    local spam_flag = "0"
    if is_spam then
      spam_flag = "1"
    end

    lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)

    local redis_opts = { task = task, is_write = true, key = cache_id }
    local script_args = { cache_id, spam_flag, serialized_conf }
    lua_redis.exec_redis_script(learn_script_id, redis_opts, on_learn_reply, script_args)
  end
end

---
--- Init the learn cache
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (check_functor, learn_functor) or `nil` if redis parameters cannot be loaded
exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
  local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
  if not redis_params then
    return nil
  end

  -- Only these keys are forwarded to the Redis scripts
  local default_conf = {
    cache_prefix = "learned_ids",
    cache_max_elt = 10000, -- Maximum number of elements in the cache key
    cache_max_keys = 5, -- Maximum number of keys in the cache
    cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
  }

  local conf = lua_util.override_defaults(default_conf, classifier_ucl)

  -- Drop every option that has no default, i.e. is not a cache setting
  -- (clearing existing fields while traversing with pairs is allowed)
  for name in pairs(conf) do
    if default_conf[name] == nil then
      conf[name] = nil
    end
  end

  local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
  local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)

  local check_functor = gen_cache_check_functor(redis_params, check_script_id, conf)
  local learn_functor = gen_cache_learn_functor(redis_params, learn_script_id, conf)

  return check_functor, learn_functor
end

return exports

+ 20
- 0
lualib/redis_scripts/bayes_cache_check.lua 查看文件

@@ -0,0 +1,20 @@
-- Redis-side script: look a message id up in the bayes learn cache
-- Arguments (all passed through KEYS):
-- KEYS[1] - cache id
-- KEYS[2] - configuration table packed with msgpack

local conf = cmsgpack.unpack(KEYS[2])
-- Ids are trimmed to the configured element length before lookup
local cache_id = string.sub(KEYS[1], 1, conf.cache_elt_len)

-- Probe the hashes "<prefix>", "<prefix>X", "<prefix>XX", ... in order
local hash_name = conf.cache_prefix
for _ = 0, conf.cache_max_keys do
  local stored = redis.call('HGET', hash_name, cache_id)

  if stored then
    return tonumber(stored)
  end

  hash_name = hash_name .. "X"
end

return nil

+ 61
- 0
lualib/redis_scripts/bayes_cache_learn.lua 查看文件

@@ -0,0 +1,61 @@
-- Lua script to store a learned message id in the cache for bayes classification
-- This script accepts the following parameters:
-- key1 - cache id
-- key2 - is spam (1 or 0)
-- key3 - configuration table in message pack
--
-- Returns false when the id was already cached, true when it has been added.

local cache_id = KEYS[1]
local is_spam = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)

-- Try each prefix that is in Redis (as some other instance might have set it)
for i = 0, conf.cache_max_keys do
  local prefix = conf.cache_prefix .. string.rep("X", i)
  local have = redis.call('HGET', prefix, cache_id)

  if have then
    -- Already in cache; when the message is re-learned as the other class
    -- the stored class must follow, otherwise later checks see stale data
    if have ~= is_spam then
      redis.call('HSET', prefix, cache_id, is_spam)
    end
    return false
  end
end

-- Store the element in the first hash that still has room
local lim = conf.cache_max_elt
for i = 0, conf.cache_max_keys do
  local prefix = conf.cache_prefix .. string.rep("X", i)
  local count = redis.call('HLEN', prefix)

  if count < lim then
    -- We can add it to this prefix
    redis.call('HSET', prefix, cache_id, is_spam)
    return true
  end
end

-- Every hash is full: drop the first existing hash (shortest prefix) and
-- shift each later hash one prefix down, which frees the longest prefix
local expired = false
for i = 0, conf.cache_max_keys do
  local prefix = conf.cache_prefix .. string.rep("X", i)
  -- EXISTS replies with the integer 0 or 1, and 0 is truthy in Lua,
  -- so the reply has to be compared explicitly
  local exists = redis.call('EXISTS', prefix) == 1

  if exists then
    if not expired then
      redis.call('DEL', prefix)

      -- Do not expire anything else
      expired = true
    elseif i > 0 then
      -- Move key to a shorter prefix, so we will rotate them eventually from lower to upper
      local new_prefix = conf.cache_prefix .. string.rep("X", i - 1)
      redis.call('RENAME', prefix, new_prefix)
    end
  end
end

-- The new element always goes to the longest prefix after rotation
local last_prefix = conf.cache_prefix .. string.rep("X", conf.cache_max_keys)
redis.call('HSET', last_prefix, cache_id, is_spam)

return true

+ 1
- 1
src/libstat/CMakeLists.txt 查看文件

@@ -15,7 +15,7 @@ SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c
${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx)

SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c
${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.c)
${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx)

SET(RSPAMD_STAT ${LIBSTATSRC}
${TOKENIZERSSRC}

+ 2
- 2
src/libstat/backends/redis_backend.cxx 查看文件

@@ -1,5 +1,5 @@
/*
* Copyright 2023 Vsevolod Stakhov
* Copyright 2024 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1032,7 +1032,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
lua_pushcclosure(L, &rspamd_redis_learned, 1);

if (lua_pcall(L, nargs, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
msg_err_task("call to script failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return FALSE;
}

+ 0
- 535
src/libstat/learn_cache/redis_cache.c 查看文件

@@ -1,535 +0,0 @@
/*
* Copyright 2023 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "config.h"
#include "learn_cache.h"
#include "rspamd.h"
#include "stat_api.h"
#include "stat_internal.h"
#include "cryptobox.h"
#include "ucl.h"
#include "hiredis.h"
#include "adapters/libev.h"
#include "lua/lua_common.h"
#include "libmime/message.h"

#define REDIS_DEFAULT_TIMEOUT 0.5
#define REDIS_STAT_TIMEOUT 30
#define REDIS_DEFAULT_PORT 6379
#define DEFAULT_REDIS_KEY "learned_ids"

static const gchar *M = "redis learn cache";

/* Per-statfile state of the redis learn cache, created by rspamd_stat_cache_redis_init() */
struct rspamd_redis_cache_ctx {
lua_State *L; /* Lua state whose registry holds the configuration table */
struct rspamd_statfile_config *stcf; /* statfile config; its symbol is used in log messages */
const gchar *username; /* optional, sent with AUTH together with the password */
const gchar *password; /* optional, sent with AUTH */
const gchar *dbname; /* optional, sent with SELECT */
const gchar *redis_object; /* name of the redis hash used by HGET/HSET */
gdouble timeout; /* timer interval armed for each request */
gint conf_ref; /* registry reference to the Lua configuration table */
};

/* Per-task state of a single cache request, allocated from the task pool */
struct rspamd_redis_cache_runtime {
struct rspamd_redis_cache_ctx *ctx; /* owning cache context */
struct rspamd_task *task; /* task this request belongs to */
struct upstream *selected; /* upstream chosen for this request */
ev_timer timer_ev; /* request timeout timer */
redisAsyncContext *redis; /* async connection, set to NULL once freed */
gboolean has_event; /* TRUE while a session event is registered for this runtime */
};

/* Returns the GQuark registered for the module name string `M` */
static GQuark
rspamd_stat_cache_redis_quark(void)
{
	GQuark module_quark = g_quark_from_static_string(M);

	return module_quark;
}

/*
 * Fetches the upstream list stored under key `what` in the Lua configuration
 * table referenced by ctx->conf_ref.
 *
 * Returns NULL when the field is absent or is not userdata, so callers can
 * report "no servers defined" instead of dereferencing a NULL pointer.
 * The Lua stack is reset to empty before returning.
 */
static inline struct upstream_list *
rspamd_redis_get_servers(struct rspamd_redis_cache_ctx *ctx,
						 const gchar *what)
{
	lua_State *L = ctx->L;
	struct upstream_list *res = NULL;
	struct upstream_list **pres;

	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->conf_ref);
	lua_pushstring(L, what);
	lua_gettable(L, -2);

	/* lua_touserdata yields NULL for nil or any non-userdata value */
	pres = (struct upstream_list **) lua_touserdata(L, -1);

	if (pres != NULL) {
		res = *pres;
	}

	lua_settop(L, 0);

	return res;
}

/*
 * Queues the optional AUTH and SELECT commands on a fresh connection.
 * No reply callback is registered for them.
 */
static void
rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx *ctx,
							  redisAsyncContext *redis)
{
	if (ctx->username != NULL && ctx->password != NULL) {
		redisAsyncCommand(redis, NULL, NULL, "AUTH %s %s", ctx->username, ctx->password);
	}
	else if (ctx->username != NULL) {
		/* A username without a password cannot be used for AUTH */
		msg_warn("Redis requires a password when username is supplied");
	}
	else if (ctx->password != NULL) {
		redisAsyncCommand(redis, NULL, NULL, "AUTH %s", ctx->password);
	}

	if (ctx->dbname != NULL) {
		redisAsyncCommand(redis, NULL, NULL, "SELECT %s", ctx->dbname);
	}
}

/*
 * Session finaliser: stops the timeout timer and releases the connection.
 */
static void
rspamd_redis_cache_fin(gpointer data)
{
	struct rspamd_redis_cache_runtime *rt = data;
	redisAsyncContext *conn = rt->redis;

	rt->has_event = FALSE;
	ev_timer_stop(rt->task->event_loop, &rt->timer_ev);

	if (conn != NULL) {
		/* Detach the pointer first: freeing invokes all pending callbacks */
		rt->redis = NULL;
		redisAsyncFree(conn);
	}
}

/*
 * Timer callback: marks the selected upstream as failed and removes the
 * pending session event, if one is registered.
 */
static void
rspamd_redis_cache_timeout(EV_P_ ev_timer *w, int revents)
{
	struct rspamd_redis_cache_runtime *rt = w->data;
	/* The logging macro below refers to a local named `task` */
	struct rspamd_task *task = rt->task;

	msg_err_task("connection to redis server %s timed out",
				 rspamd_upstream_name(rt->selected));
	rspamd_upstream_fail(rt->selected, FALSE, "timeout");

	if (!rt->has_event) {
		return;
	}

	rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
}

/*
 * Reply handler for the HGET issued by rspamd_stat_cache_redis_check().
 * A positive stored value matches a spam learn, a negative one a ham learn;
 * any other non-zero combination sets the unlearn flag.
 */
static void
rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv)
{
	struct rspamd_redis_cache_runtime *rt = priv;
	redisReply *reply = r;
	struct rspamd_task *task = rt->task;
	glong val = 0;

	if (c->err != 0) {
		rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
	}
	else {
		if (reply != NULL) {
			switch (reply->type) {
			case REDIS_REPLY_INTEGER:
				val = reply->integer;
				break;
			case REDIS_REPLY_STRING:
				rspamd_strtol(reply->str, reply->len, &val);
				break;
			case REDIS_REPLY_ERROR:
				msg_err_task("cannot learn %s: redis error: \"%s\"",
							 rt->ctx->stcf->symbol, reply->str);
				val = 0;
				break;
			case REDIS_REPLY_NIL:
				/* Not cached: nothing to report */
				val = 0;
				break;
			default:
				msg_err_task("bad learned type for %s: %d",
							 rt->ctx->stcf->symbol, reply->type);
				val = 0;
				break;
			}
		}

		gboolean same_class =
			(val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
			(val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM));

		if (same_class) {
			/* Already learned */
			msg_info_task("<%s> has been already "
						  "learned as %s, ignore it",
						  MESSAGE_FIELD(task, message_id),
						  (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
			task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
		}
		else if (val != 0) {
			/* Unlearn flag */
			task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
		}

		rspamd_upstream_ok(rt->selected);
	}

	if (rt->has_event) {
		rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
	}
}

/*
 * Reply handler for the HSET issued by rspamd_stat_cache_redis_learn().
 * Only the connection state is examined; the reply content is ignored.
 */
static void
rspamd_stat_cache_redis_set(redisAsyncContext *c, gpointer r, gpointer priv)
{
	struct rspamd_redis_cache_runtime *rt = priv;
	struct rspamd_task *task = rt->task;

	if (c->err != 0) {
		rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
	}
	else {
		/* XXX: we ignore results here */
		rspamd_upstream_ok(rt->selected);
	}

	if (!rt->has_event) {
		return;
	}

	rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
}

static void
rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
{
rspamd_cryptobox_hash_state_t st;
rspamd_token_t *tok;
guint i;
guchar out[rspamd_cryptobox_HASHBYTES];
gchar *b32out;
gchar *user = NULL;

rspamd_cryptobox_hash_init(&st, NULL, 0);

user = rspamd_mempool_get_variable(task->task_pool, "stat_user");
/* Use dedicated hash space for per users cache */
if (user != NULL) {
rspamd_cryptobox_hash_update(&st, user, strlen(user));
}

for (i = 0; i < task->tokens->len; i++) {
tok = g_ptr_array_index(task->tokens, i);
rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data,
sizeof(tok->data));
}

rspamd_cryptobox_hash_final(&st, out);

b32out = rspamd_mempool_alloc(task->task_pool,
sizeof(out) * 8 / 5 + 3);
i = rspamd_encode_base32_buf(out, sizeof(out), b32out,
sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);

if (i > 0) {
/* Zero terminate */
b32out[i] = '\0';
}

rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL);
}

/*
 * Creates the redis learn cache context for a statfile.
 *
 * Redis settings are searched in this order, the first successful load wins:
 * the classifier "backend" object, the statfile options, the classifier
 * options, and finally the global "redis" section (its "statistics"
 * subsection when present).
 *
 * Returns the allocated context, or NULL when no redis settings could be
 * loaded.
 */
gpointer
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
struct rspamd_config *cfg,
struct rspamd_statfile *st,
const ucl_object_t *cf)
{
struct rspamd_redis_cache_ctx *cache_ctx;
struct rspamd_statfile_config *stf = st->stcf;
const ucl_object_t *obj;
gboolean ret = FALSE;
lua_State *L = (lua_State *) cfg->lua_state;
gint conf_ref = -1;

cache_ctx = g_malloc0(sizeof(*cache_ctx));
cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
cache_ctx->L = L;

/* First search in backend configuration */
obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) {
ret = rspamd_lua_try_load_redis(L, obj, cfg, &conf_ref);
}

/* Now try statfiles config */
if (!ret && stf->opts) {
ret = rspamd_lua_try_load_redis(L, stf->opts, cfg, &conf_ref);
}

/* Now try classifier config */
if (!ret && st->classifier->cfg->opts) {
ret = rspamd_lua_try_load_redis(L, st->classifier->cfg->opts, cfg, &conf_ref);
}

/* Now try global redis settings */
if (!ret) {
obj = ucl_object_lookup(cfg->cfg_ucl_obj, "redis");

if (obj) {
const ucl_object_t *specific_obj;

/* Prefer the "statistics" subsection over the generic redis section */
specific_obj = ucl_object_lookup(obj, "statistics");

if (specific_obj) {
ret = rspamd_lua_try_load_redis(L,
specific_obj, cfg, &conf_ref);
}
else {
ret = rspamd_lua_try_load_redis(L,
obj, cfg, &conf_ref);
}
}
}

if (!ret) {
msg_err_config("cannot init redis cache for %s", stf->symbol);
g_free(cache_ctx);
return NULL;
}

/* Name of the redis hash; falls back to DEFAULT_REDIS_KEY */
obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key");

if (obj) {
cache_ctx->redis_object = ucl_object_tostring(obj);
}
else {
cache_ctx->redis_object = DEFAULT_REDIS_KEY;
}

cache_ctx->conf_ref = conf_ref;

/* Check some common table values */
lua_rawgeti(L, LUA_REGISTRYINDEX, conf_ref);

/* Each lookup below leaves the configuration table on top after lua_pop */
lua_pushstring(L, "timeout");
lua_gettable(L, -2);
if (lua_type(L, -1) == LUA_TNUMBER) {
cache_ctx->timeout = lua_tonumber(L, -1);
}
lua_pop(L, 1);

/* Strings are copied into the config pool before being popped */
lua_pushstring(L, "db");
lua_gettable(L, -2);
if (lua_type(L, -1) == LUA_TSTRING) {
cache_ctx->dbname = rspamd_mempool_strdup(cfg->cfg_pool,
lua_tostring(L, -1));
}
lua_pop(L, 1);

lua_pushstring(L, "username");
lua_gettable(L, -2);
if (lua_type(L, -1) == LUA_TSTRING) {
cache_ctx->username = rspamd_mempool_strdup(cfg->cfg_pool,
lua_tostring(L, -1));
}
lua_pop(L, 1);

lua_pushstring(L, "password");
lua_gettable(L, -2);
if (lua_type(L, -1) == LUA_TSTRING) {
cache_ctx->password = rspamd_mempool_strdup(cfg->cfg_pool,
lua_tostring(L, -1));
}
lua_pop(L, 1);

/* Leave the Lua stack empty */
lua_settop(L, 0);

cache_ctx->stcf = stf;

return (gpointer) cache_ctx;
}

/*
 * Creates the per-task runtime: selects an upstream, opens an async redis
 * connection and prepares the timeout timer.
 *
 * `learn` selects "write_servers" (master-slave rotation); otherwise
 * "read_servers" (round-robin) are used and the "words_hash" id is generated.
 *
 * Returns NULL when the task has no tokens, no servers or upstreams are
 * available, or the connection cannot be set up.
 */
gpointer
rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
gpointer c, gboolean learn)
{
struct rspamd_redis_cache_ctx *ctx = c;
struct rspamd_redis_cache_runtime *rt;
struct upstream *up;
struct upstream_list *ups;
rspamd_inet_addr_t *addr;

g_assert(ctx != NULL);

/* Nothing to hash without tokens */
if (task->tokens == NULL || task->tokens->len == 0) {
return NULL;
}

if (learn) {
ups = rspamd_redis_get_servers(ctx, "write_servers");

if (!ups) {
msg_err_task("no write servers defined for %s, cannot learn",
ctx->stcf->symbol);
return NULL;
}

up = rspamd_upstream_get(ups,
RSPAMD_UPSTREAM_MASTER_SLAVE,
NULL,
0);
}
else {
ups = rspamd_redis_get_servers(ctx, "read_servers");

if (!ups) {
msg_err_task("no read servers defined for %s, cannot check",
ctx->stcf->symbol);
return NULL;
}

up = rspamd_upstream_get(ups,
RSPAMD_UPSTREAM_ROUND_ROBIN,
NULL,
0);
}

if (up == NULL) {
msg_err_task("no upstreams reachable");
return NULL;
}

/* Pool allocation: the runtime is not freed explicitly in this function */
rt = rspamd_mempool_alloc0(task->task_pool, sizeof(*rt));
rt->selected = up;
rt->task = task;
rt->ctx = ctx;

addr = rspamd_upstream_addr_next(up);
g_assert(addr != NULL);

/* Unix socket addresses need the dedicated connect call */
if (rspamd_inet_address_get_af(addr) == AF_UNIX) {
rt->redis = redisAsyncConnectUnix(rspamd_inet_address_to_string(addr));
}
else {
rt->redis = redisAsyncConnect(rspamd_inet_address_to_string(addr),
rspamd_inet_address_get_port(addr));
}

if (rt->redis == NULL) {
msg_warn_task("cannot connect to redis server %s: %s",
rspamd_inet_address_to_string_pretty(addr),
strerror(errno));

return NULL;
}
else if (rt->redis->err != REDIS_OK) {
/* A context was allocated but is already in error state: release it */
msg_warn_task("cannot connect to redis server %s: %s",
rspamd_inet_address_to_string_pretty(addr),
rt->redis->errstr);
redisAsyncFree(rt->redis);
rt->redis = NULL;

return NULL;
}

redisLibevAttach(task->event_loop, rt->redis);

/* Now check stats */
/* The timer is only initialised here; check/learn start it */
rt->timer_ev.data = rt;
ev_timer_init(&rt->timer_ev, rspamd_redis_cache_timeout,
rt->ctx->timeout, 0.0);
rspamd_redis_cache_maybe_auth(ctx, rt->redis);

/* The id is produced on check only */
if (!learn) {
rspamd_stat_cache_redis_generate_id(task);
}

return rt;
}

/*
 * Asks redis whether the task's "words_hash" id has been learned already.
 * The reply is processed by rspamd_stat_cache_redis_get().
 *
 * Returns RSPAMD_LEARN_IGNORE when there is no usable runtime, the session
 * is blocked or no id is available; RSPAMD_LEARN_OK otherwise, even if the
 * command could not be queued.
 */
gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	struct rspamd_redis_cache_runtime *rt = runtime;
	gchar *h;

	/* The runtime constructor may return NULL; guard like the learn path does */
	if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked(task->s)) {
		return RSPAMD_LEARN_IGNORE;
	}

	h = rspamd_mempool_get_variable(task->task_pool, "words_hash");

	if (h == NULL) {
		return RSPAMD_LEARN_IGNORE;
	}

	if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_get, rt,
						  "HGET %s %s",
						  rt->ctx->redis_object, h) == REDIS_OK) {
		rspamd_session_add_event(task->s,
								 rspamd_redis_cache_fin,
								 rt,
								 M);
		ev_timer_start(rt->task->event_loop, &rt->timer_ev);
		rt->has_event = TRUE;
	}

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}

/*
 * Records the task's "words_hash" id in redis: 1 for a spam learn, -1 for
 * anything else. The reply is processed by rspamd_stat_cache_redis_set().
 *
 * Returns RSPAMD_LEARN_IGNORE without a usable runtime or with a blocked
 * session; RSPAMD_LEARN_OK otherwise.
 */
gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	struct rspamd_redis_cache_runtime *rt = runtime;
	gchar *words_hash;
	gint learned_as;
	int rc;

	if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked(task->s)) {
		return RSPAMD_LEARN_IGNORE;
	}

	/* The id is expected to exist already at this point */
	words_hash = rspamd_mempool_get_variable(task->task_pool, "words_hash");
	g_assert(words_hash != NULL);

	if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
		learned_as = 1;
	}
	else {
		learned_as = -1;
	}

	rc = redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_set, rt,
						   "HSET %s %s %d",
						   rt->ctx->redis_object, words_hash, learned_as);

	if (rc == REDIS_OK) {
		rspamd_session_add_event(task->s,
								 rspamd_redis_cache_fin, rt, M);
		ev_timer_start(rt->task->event_loop, &rt->timer_ev);
		rt->has_event = TRUE;
	}

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}

/*
 * Destroys a cache context: drops the registry reference to the Lua
 * configuration table and frees the context itself.
 */
void rspamd_stat_cache_redis_close(gpointer c)
{
	struct rspamd_redis_cache_ctx *cache_ctx = c;
	lua_State *L = cache_ctx->L;

	if (cache_ctx->conf_ref != 0) {
		luaL_unref(L, LUA_REGISTRYINDEX, cache_ctx->conf_ref);
	}

	g_free(cache_ctx);
}

+ 315
- 0
src/libstat/learn_cache/redis_cache.cxx 查看文件

@@ -0,0 +1,315 @@
/*
* Copyright 2024 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "config.h"
// Include early to avoid `extern "C"` issues
#include "lua/lua_common.h"
#include "learn_cache.h"
#include "rspamd.h"
#include "stat_api.h"
#include "stat_internal.h"
#include "cryptobox.h"
#include "ucl.h"
#include "libmime/message.h"

#define DEFAULT_REDIS_KEY "learned_ids"

/*
 * State of the Lua-driven redis learn cache.
 *
 * Owns two Lua registry references (the check and the learn functors) and
 * releases them on destruction; -1 marks a reference that is not set.
 */
struct rspamd_redis_cache_ctx {
	lua_State *L;
	/* Explicitly null: the constructor does not receive a statfile config */
	struct rspamd_statfile_config *stcf = nullptr;
	std::string redis_object = DEFAULT_REDIS_KEY;
	int check_ref = -1;
	int learn_ref = -1;

	rspamd_redis_cache_ctx() = delete;
	explicit rspamd_redis_cache_ctx(lua_State *L)
		: L(L)
	{
	}

	/* Copying would release the same registry references twice */
	rspamd_redis_cache_ctx(const rspamd_redis_cache_ctx &) = delete;
	rspamd_redis_cache_ctx &operator=(const rspamd_redis_cache_ctx &) = delete;

	~rspamd_redis_cache_ctx()
	{
		if (check_ref != -1) {
			luaL_unref(L, LUA_REGISTRYINDEX, check_ref);
		}

		if (learn_ref != -1) {
			luaL_unref(L, LUA_REGISTRYINDEX, learn_ref);
		}
	}
};

/*
 * Disabled code: everything up to the matching #endif is excluded from
 * compilation. It refers to rspamd_redis_cache_runtime and
 * rspamd_redis_cache_fin, neither of which is defined in the visible part
 * of this file.
 */
#if 0
/* Called when we have checked the specified message id */
static void
rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv)
{
struct rspamd_redis_cache_runtime *rt = priv;
redisReply *reply = r;
struct rspamd_task *task;
glong val = 0;

task = rt->task;

if (c->err == 0) {
if (reply) {
if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) {
val = reply->integer;
}
else if (reply->type == REDIS_REPLY_STRING) {
rspamd_strtol(reply->str, reply->len, &val);
}
else {
if (reply->type == REDIS_REPLY_ERROR) {
msg_err_task("cannot learn %s: redis error: \"%s\"",
rt->ctx->stcf->symbol, reply->str);
}
else if (reply->type != REDIS_REPLY_NIL) {
msg_err_task("bad learned type for %s: %d",
rt->ctx->stcf->symbol, reply->type);
}

val = 0;
}
}

/* Positive value: cached as spam; negative value: cached as ham */
if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
(val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
/* Already learned */
msg_info_task("<%s> has been already "
"learned as %s, ignore it",
MESSAGE_FIELD(task, message_id),
(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
}
else if (val != 0) {
/* Unlearn flag */
task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
}

rspamd_upstream_ok(rt->selected);
}
else {
rspamd_upstream_fail(rt->selected, FALSE, c->errstr);
}

if (rt->has_event) {
rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt);
}
}
#endif

static void
rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
{
rspamd_cryptobox_hash_state_t st;
rspamd_cryptobox_hash_init(&st, NULL, 0);

const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user");
/* Use dedicated hash space for per users cache */
if (user != NULL) {
rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user));
}

for (auto i = 0; i < task->tokens->len; i++) {
const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i);
rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data,
sizeof(tok->data));
}

guchar out[rspamd_cryptobox_HASHBYTES];
rspamd_cryptobox_hash_final(&st, out);

auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool,
sizeof(out) * 8 / 5 + 3, char);
auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out,
sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);

if (out_sz > 0) {
/* Zero terminate */
b32out[out_sz] = '\0';
rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL);
}
}

/*
 * Creates the learn cache context by calling
 * lua_bayes_redis.lua_bayes_init_cache(classifier_opts, statfile_opts) and
 * keeping registry references to the two functions it returns.
 *
 * Returns the context, or nullptr when the Lua function cannot be obtained,
 * raises an error, or does not return two functions (the Lua side returns
 * nil when it cannot load redis parameters).
 */
gpointer
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
							 struct rspamd_config *cfg,
							 struct rspamd_statfile *st,
							 const ucl_object_t *cf)
{
	auto *L = RSPAMD_LUA_CFG_STATE(cfg);
	auto cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(L);
	cache_ctx->stcf = st->stcf;

	lua_settop(L, 0);

	lua_pushcfunction(L, &rspamd_lua_traceback);
	auto err_idx = lua_gettop(L);

	/* Obtain function */
	if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
		msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	/* Push arguments */
	ucl_object_push_lua(L, st->classifier->cfg->opts, false);
	ucl_object_push_lua(L, st->stcf->opts, false);

	if (lua_pcall(L, 2, 2, err_idx) != 0) {
		msg_err("call to lua_bayes_init_cache "
				"script failed: %s",
				lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	/*
	 * Results are in the stack:
	 * top - 1 - check function (idx = -2)
	 * top - learn function (idx = -1)
	 */
	if (!lua_isfunction(L, -2) || !lua_isfunction(L, -1)) {
		/* Referencing nil results would make every later call fail */
		msg_err_config("lua_bayes_init_cache has not returned cache functions");
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	lua_pushvalue(L, -2);
	cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);

	lua_pushvalue(L, -1);
	cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);

	lua_settop(L, err_idx - 1);

	return (gpointer) cache_ctx.release();
}

/*
 * Per-task runtime hook. The runtime is the cache context itself; NULL is
 * returned for tasks without tokens.
 */
gpointer
rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
								gpointer c, gboolean learn)
{
	const bool has_tokens = task->tokens != NULL && task->tokens->len > 0;

	if (!has_tokens) {
		return NULL;
	}

	/* On check, we produce words_hash variable, on learn it is guaranteed to be set */
	if (!learn) {
		rspamd_stat_cache_redis_generate_id(task);
	}

	return c;
}

/*
 * Lua callback invoked with (task, found, value) once the cache lookup is
 * done.
 *
 * A positive value means the id is cached as spam; zero or a negative value
 * means ham. A match with the class being learned marks the task as already
 * learned; a mismatch requests an unlearn of the previous class.
 */
static gint
rspamd_stat_cache_checked(lua_State *L)
{
	auto *task = lua_check_task(L, 1);
	auto res = lua_toboolean(L, 2);

	if (task != nullptr && res) {
		auto val = lua_tointeger(L, 3);
		const bool cached_spam = val > 0;
		const bool learn_spam = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) != 0;
		const bool learn_ham = (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) != 0;

		if ((cached_spam && learn_spam) || (!cached_spam && learn_ham)) {
			/* Already learned */
			msg_info_task("<%s> has been already "
						  "learned as %s, ignore it",
						  MESSAGE_FIELD(task, message_id),
						  (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
			task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
		}
		else if ((cached_spam && learn_ham) || (!cached_spam && learn_spam)) {
			/* Unlearn flag: ham is cached as 0, so `val != 0` cannot be used here */
			task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
		}
	}

	/* Ignore errors for now, as we can do nothing about them at the moment */

	return 0;
}

/*
 * Calls the Lua check functor as check(task, words_hash, callback), where
 * the callback is rspamd_stat_cache_checked().
 *
 * Returns RSPAMD_LEARN_IGNORE without a runtime or a "words_hash" id, or if
 * the Lua call fails; RSPAMD_LEARN_OK otherwise. The Lua stack is restored
 * to its height on entry in every case.
 */
gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
	auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");

	/* The runtime hook returns NULL for tasks without tokens */
	if (ctx == nullptr || h == nullptr) {
		return RSPAMD_LEARN_IGNORE;
	}

	auto *L = ctx->L;

	lua_pushcfunction(L, &rspamd_lua_traceback);
	gint err_idx = lua_gettop(L);

	/* Function arguments */
	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
	rspamd_lua_task_push(L, task);
	lua_pushstring(L, h);

	lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);

	if (lua_pcall(L, 3, 0, err_idx) != 0) {
		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);
		return RSPAMD_LEARN_IGNORE;
	}

	/* Drop the traceback handler, otherwise it leaks on every successful call */
	lua_settop(L, err_idx - 1);

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}

/*
 * Calls the Lua learn functor as learn(task, words_hash, is_spam).
 *
 * Returns RSPAMD_LEARN_IGNORE without a runtime, with a blocked session, or
 * if the Lua call fails; RSPAMD_LEARN_OK otherwise. The Lua stack is
 * restored to its height on entry in every case.
 */
gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;

	/* The runtime hook returns NULL for tasks without tokens */
	if (ctx == nullptr || rspamd_session_blocked(task->s)) {
		return RSPAMD_LEARN_IGNORE;
	}

	auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
	g_assert(h != NULL);
	auto *L = ctx->L;

	lua_pushcfunction(L, &rspamd_lua_traceback);
	gint err_idx = lua_gettop(L);

	/* Function arguments */
	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
	rspamd_lua_task_push(L, task);
	lua_pushstring(L, h);
	lua_pushboolean(L, is_spam);

	if (lua_pcall(L, 3, 0, err_idx) != 0) {
		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);
		return RSPAMD_LEARN_IGNORE;
	}

	/* Drop the traceback handler, otherwise it leaks on every successful call */
	lua_settop(L, err_idx - 1);

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}

/*
 * Destroys a cache context; its destructor releases the Lua registry
 * references. Deleting a null pointer is a no-op.
 */
void rspamd_stat_cache_redis_close(gpointer c)
{
	delete static_cast<struct rspamd_redis_cache_ctx *>(c);
}

Loading…
取消
儲存