Rewrite redis_cache logic in statistics (tags/3.8.0)
@@ -20,6 +20,7 @@ local exports = {} | |||
local lua_redis = require "lua_redis" | |||
local logger = require "rspamd_logger" | |||
local lua_util = require "lua_util" | |||
local ucl = require "ucl" | |||
local N = "bayes" | |||
@@ -54,24 +55,19 @@ local function gen_learn_functor(redis_params, learn_script_id) | |||
if maybe_text_tokens then | |||
lua_redis.exec_redis_script(learn_script_id, | |||
{ task = task, is_write = false, key = expanded_key }, | |||
{ task = task, is_write = true, key = expanded_key }, | |||
learn_redis_cb, | |||
{ expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) | |||
else | |||
lua_redis.exec_redis_script(learn_script_id, | |||
{ task = task, is_write = false, key = expanded_key }, | |||
{ task = task, is_write = true, key = expanded_key }, | |||
learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) | |||
end | |||
end | |||
end | |||
--- | |||
--- Init bayes classifier | |||
--- @param classifier_ucl ucl of the classifier config | |||
--- @param statfile_ucl ucl of the statfile config | |||
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error | |||
exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) | |||
local function load_redis_params(classifier_ucl, statfile_ucl) | |||
local redis_params | |||
-- Try load from statfile options | |||
@@ -108,6 +104,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, | |||
return nil | |||
end | |||
return redis_params | |||
end | |||
--- | |||
--- Init bayes classifier | |||
--- @param classifier_ucl ucl of the classifier config | |||
--- @param statfile_ucl ucl of the statfile config | |||
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error | |||
exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) | |||
local redis_params = load_redis_params(classifier_ucl, statfile_ucl) | |||
if not redis_params then | |||
return nil | |||
end | |||
local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params) | |||
local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params) | |||
local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params) | |||
@@ -125,7 +137,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, | |||
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) | |||
local function stat_redis_cb(err, data) | |||
-- TODO: write this function | |||
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) | |||
if err then | |||
@@ -161,4 +172,73 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, | |||
return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id) | |||
end | |||
--- Build the "was this message already learned?" functor.
--- @param redis_params redis connection parameters (not used by the functor itself)
--- @param check_script_id id of the loaded cache check script
--- @param conf cache configuration table; serialised to msgpack once, at creation time
--- @return function(task, cache_id, callback); `callback` is invoked as
---   callback(task, true, number) when the script replies with a number, or
---   callback(task, false, reason) on a Redis error or any non-number reply
local function gen_cache_check_functor(redis_params, check_script_id, conf)
  local serialized_conf = ucl.to_format(conf, 'msgpack')

  return function(task, cache_id, callback)
    local function on_check_reply(err, data)
      local reply_type = type(data)
      lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, reply_type)

      if err then
        callback(task, false, err)
        return
      end

      if reply_type ~= 'number' then
        -- Anything that is not a number is treated as a cache miss
        callback(task, false, 'not found')
        return
      end

      callback(task, true, data)
    end

    lua_util.debugm(N, task, 'checking cache: %s', cache_id)

    local request_attrs = { task = task, is_write = false, key = cache_id }
    local script_args = { cache_id, serialized_conf }
    lua_redis.exec_redis_script(check_script_id, request_attrs, on_check_reply, script_args)
  end
end
--- Build the functor that records a learned message id in the cache.
--- @param redis_params redis connection parameters (not used by the functor itself)
--- @param learn_script_id id of the loaded cache learn script
--- @param conf cache configuration table; serialised to msgpack once, at creation time
--- @return function(task, cache_id, is_spam); the script reply is only logged
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
  local serialized_conf = ucl.to_format(conf, 'msgpack')

  return function(task, cache_id, is_spam)
    local function on_learn_reply(err, data)
      lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
    end

    -- The script receives the class as a string flag: "1" for spam, "0" otherwise
    local spam_flag
    if is_spam then
      spam_flag = "1"
    else
      spam_flag = "0"
    end

    lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)

    local request_attrs = { task = task, is_write = true, key = cache_id }
    local script_args = { cache_id, spam_flag, serialized_conf }
    lua_redis.exec_redis_script(learn_script_id, request_attrs, on_learn_reply, script_args)
  end
end
---
--- Init bayes learn cache
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (check_functor, learn_functor) or `nil` if redis parameters cannot be loaded
exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
  local redis_params = load_redis_params(classifier_ucl, statfile_ucl)

  if not redis_params then
    return nil
  end

  -- The keys of this table also act as the whitelist of options passed to the scripts
  local known_options = {
    cache_prefix = "learned_ids",
    cache_max_elt = 10000, -- Maximum number of elements in the cache key
    cache_max_keys = 5, -- Maximum number of keys in the cache
    cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
  }

  local cache_conf = lua_util.override_defaults(known_options, classifier_ucl)

  -- Drop every option that has no default defined above
  for option in pairs(cache_conf) do
    if known_options[option] == nil then
      cache_conf[option] = nil
    end
  end

  local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
  local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)

  local check_functor = gen_cache_check_functor(redis_params, check_script_id, cache_conf)
  local learn_functor = gen_cache_learn_functor(redis_params, learn_script_id, cache_conf)

  return check_functor, learn_functor
end
return exports |
@@ -0,0 +1,20 @@ | |||
-- Lua script to perform cache checking for bayes classification
-- This script accepts the following parameters:
-- key1 - cache id
-- key2 - configuration table in message pack
--
-- Returns the stored value converted to a number, or nil when the id is
-- not present under any of the probed prefixes.

local cache_id = KEYS[1]
local conf = cmsgpack.unpack(KEYS[2])
-- Ids are stored trimmed to the configured length
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)

-- Probe `prefix`, `prefixX`, `prefixXX`, ... appending one "X" per step
local prefix = conf.cache_prefix .. ""
local step = 0
while step <= conf.cache_max_keys do
  local stored = redis.call('HGET', prefix, cache_id)

  if stored then
    return tonumber(stored)
  end

  prefix = prefix .. "X"
  step = step + 1
end

return nil
@@ -0,0 +1,61 @@ | |||
-- Lua script to perform cache learning for bayes classification
-- (stores the class a message id has been learned as)
-- This script accepts the following parameters:
-- key1 - cache id
-- key2 - is spam (1 or 0)
-- key3 - configuration table in message pack
--
-- The cache is a set of hashes named `prefix`, `prefixX`, `prefixXX`, ...
-- (cache_max_keys + 1 names in total), each holding up to cache_max_elt ids.
-- Hashes are filled from the shortest name upwards, so the shortest name
-- holds the oldest ids.
--
-- Returns false if the id was already cached, true if it has been stored.

local cache_id = KEYS[1]
local is_spam = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
-- Ids are stored trimmed to the configured length
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)

-- Name of the i-th hash
local function bucket_name(i)
  return conf.cache_prefix .. string.rep("X", i)
end

-- Try each prefix that is in Redis (as some other instance might have set it)
for i = 0, conf.cache_max_keys do
  local prefix = bucket_name(i)
  local have = redis.call('HGET', prefix, cache_id)

  if have then
    -- Already in cache; refresh the stored class, as it changes when a
    -- message is relearned as the opposite class
    redis.call('HSET', prefix, cache_id, is_spam)
    return false
  end
end

-- Store in the first hash that still has room
local lim = conf.cache_max_elt
for i = 0, conf.cache_max_keys do
  local prefix = bucket_name(i)
  local count = redis.call('HLEN', prefix)

  if count < lim then
    redis.call('HSET', prefix, cache_id, is_spam)
    return true
  end
end

-- Every hash is full: expire the oldest one, shift the remaining hashes one
-- name down and store the id in the last name, which is now free
local last = conf.cache_max_keys
redis.call('DEL', bucket_name(0))

for i = 1, last do
  local prefix = bucket_name(i)

  -- EXISTS replies with 0 or 1 and 0 is truthy in Lua, so compare explicitly;
  -- RENAME raises an error if the source key does not exist
  if redis.call('EXISTS', prefix) == 1 then
    redis.call('RENAME', prefix, bucket_name(i - 1))
  end
end

redis.call('HSET', bucket_name(last), cache_id, is_spam)

return true
@@ -15,7 +15,7 @@ SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) | |||
SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c | |||
${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.c) | |||
${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) | |||
SET(RSPAMD_STAT ${LIBSTATSRC} | |||
${TOKENIZERSSRC} |
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright 2023 Vsevolod Stakhov | |||
* Copyright 2024 Vsevolod Stakhov | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -1032,7 +1032,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task, | |||
lua_pushcclosure(L, &rspamd_redis_learned, 1); | |||
if (lua_pcall(L, nargs, 0, err_idx) != 0) { | |||
msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); | |||
msg_err_task("call to script failed: %s", lua_tostring(L, -1)); | |||
lua_settop(L, err_idx - 1); | |||
return FALSE; | |||
} |
@@ -1,535 +0,0 @@ | |||
/* | |||
* Copyright 2023 Vsevolod Stakhov | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "config.h" | |||
#include "learn_cache.h" | |||
#include "rspamd.h" | |||
#include "stat_api.h" | |||
#include "stat_internal.h" | |||
#include "cryptobox.h" | |||
#include "ucl.h" | |||
#include "hiredis.h" | |||
#include "adapters/libev.h" | |||
#include "lua/lua_common.h" | |||
#include "libmime/message.h" | |||
#define REDIS_DEFAULT_TIMEOUT 0.5 | |||
#define REDIS_STAT_TIMEOUT 30 | |||
#define REDIS_DEFAULT_PORT 6379 | |||
#define DEFAULT_REDIS_KEY "learned_ids" | |||
static const gchar *M = "redis learn cache"; | |||
struct rspamd_redis_cache_ctx { | |||
lua_State *L; | |||
struct rspamd_statfile_config *stcf; | |||
const gchar *username; | |||
const gchar *password; | |||
const gchar *dbname; | |||
const gchar *redis_object; | |||
gdouble timeout; | |||
gint conf_ref; | |||
}; | |||
struct rspamd_redis_cache_runtime { | |||
struct rspamd_redis_cache_ctx *ctx; | |||
struct rspamd_task *task; | |||
struct upstream *selected; | |||
ev_timer timer_ev; | |||
redisAsyncContext *redis; | |||
gboolean has_event; | |||
}; | |||
static GQuark | |||
rspamd_stat_cache_redis_quark(void) | |||
{ | |||
return g_quark_from_static_string(M); | |||
} | |||
static inline struct upstream_list * | |||
rspamd_redis_get_servers(struct rspamd_redis_cache_ctx *ctx, | |||
const gchar *what) | |||
{ | |||
lua_State *L = ctx->L; | |||
struct upstream_list *res; | |||
lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->conf_ref); | |||
lua_pushstring(L, what); | |||
lua_gettable(L, -2); | |||
res = *((struct upstream_list **) lua_touserdata(L, -1)); | |||
lua_settop(L, 0); | |||
return res; | |||
} | |||
static void | |||
rspamd_redis_cache_maybe_auth(struct rspamd_redis_cache_ctx *ctx, | |||
redisAsyncContext *redis) | |||
{ | |||
if (ctx->username) { | |||
if (ctx->password) { | |||
redisAsyncCommand(redis, NULL, NULL, "AUTH %s %s", ctx->username, ctx->password); | |||
} | |||
else { | |||
msg_warn("Redis requires a password when username is supplied"); | |||
} | |||
} | |||
else if (ctx->password) { | |||
redisAsyncCommand(redis, NULL, NULL, "AUTH %s", ctx->password); | |||
} | |||
if (ctx->dbname) { | |||
redisAsyncCommand(redis, NULL, NULL, "SELECT %s", ctx->dbname); | |||
} | |||
} | |||
/* Called on connection termination */ | |||
static void | |||
rspamd_redis_cache_fin(gpointer data) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = data; | |||
redisAsyncContext *redis; | |||
rt->has_event = FALSE; | |||
ev_timer_stop(rt->task->event_loop, &rt->timer_ev); | |||
if (rt->redis) { | |||
redis = rt->redis; | |||
rt->redis = NULL; | |||
/* This calls for all callbacks pending */ | |||
redisAsyncFree(redis); | |||
} | |||
} | |||
static void | |||
rspamd_redis_cache_timeout(EV_P_ ev_timer *w, int revents) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = | |||
(struct rspamd_redis_cache_runtime *) w->data; | |||
struct rspamd_task *task; | |||
task = rt->task; | |||
msg_err_task("connection to redis server %s timed out", | |||
rspamd_upstream_name(rt->selected)); | |||
rspamd_upstream_fail(rt->selected, FALSE, "timeout"); | |||
if (rt->has_event) { | |||
rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); | |||
} | |||
} | |||
/* Called when we have checked the specified message id */ | |||
static void | |||
rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = priv; | |||
redisReply *reply = r; | |||
struct rspamd_task *task; | |||
glong val = 0; | |||
task = rt->task; | |||
if (c->err == 0) { | |||
if (reply) { | |||
if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) { | |||
val = reply->integer; | |||
} | |||
else if (reply->type == REDIS_REPLY_STRING) { | |||
rspamd_strtol(reply->str, reply->len, &val); | |||
} | |||
else { | |||
if (reply->type == REDIS_REPLY_ERROR) { | |||
msg_err_task("cannot learn %s: redis error: \"%s\"", | |||
rt->ctx->stcf->symbol, reply->str); | |||
} | |||
else if (reply->type != REDIS_REPLY_NIL) { | |||
msg_err_task("bad learned type for %s: %d", | |||
rt->ctx->stcf->symbol, reply->type); | |||
} | |||
val = 0; | |||
} | |||
} | |||
if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || | |||
(val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { | |||
/* Already learned */ | |||
msg_info_task("<%s> has been already " | |||
"learned as %s, ignore it", | |||
MESSAGE_FIELD(task, message_id), | |||
(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); | |||
task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; | |||
} | |||
else if (val != 0) { | |||
/* Unlearn flag */ | |||
task->flags |= RSPAMD_TASK_FLAG_UNLEARN; | |||
} | |||
rspamd_upstream_ok(rt->selected); | |||
} | |||
else { | |||
rspamd_upstream_fail(rt->selected, FALSE, c->errstr); | |||
} | |||
if (rt->has_event) { | |||
rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); | |||
} | |||
} | |||
/* Called when we have learned the specified message id */ | |||
static void | |||
rspamd_stat_cache_redis_set(redisAsyncContext *c, gpointer r, gpointer priv) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = priv; | |||
struct rspamd_task *task; | |||
task = rt->task; | |||
if (c->err == 0) { | |||
/* XXX: we ignore results here */ | |||
rspamd_upstream_ok(rt->selected); | |||
} | |||
else { | |||
rspamd_upstream_fail(rt->selected, FALSE, c->errstr); | |||
} | |||
if (rt->has_event) { | |||
rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); | |||
} | |||
} | |||
static void | |||
rspamd_stat_cache_redis_generate_id(struct rspamd_task *task) | |||
{ | |||
rspamd_cryptobox_hash_state_t st; | |||
rspamd_token_t *tok; | |||
guint i; | |||
guchar out[rspamd_cryptobox_HASHBYTES]; | |||
gchar *b32out; | |||
gchar *user = NULL; | |||
rspamd_cryptobox_hash_init(&st, NULL, 0); | |||
user = rspamd_mempool_get_variable(task->task_pool, "stat_user"); | |||
/* Use dedicated hash space for per users cache */ | |||
if (user != NULL) { | |||
rspamd_cryptobox_hash_update(&st, user, strlen(user)); | |||
} | |||
for (i = 0; i < task->tokens->len; i++) { | |||
tok = g_ptr_array_index(task->tokens, i); | |||
rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data, | |||
sizeof(tok->data)); | |||
} | |||
rspamd_cryptobox_hash_final(&st, out); | |||
b32out = rspamd_mempool_alloc(task->task_pool, | |||
sizeof(out) * 8 / 5 + 3); | |||
i = rspamd_encode_base32_buf(out, sizeof(out), b32out, | |||
sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT); | |||
if (i > 0) { | |||
/* Zero terminate */ | |||
b32out[i] = '\0'; | |||
} | |||
rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL); | |||
} | |||
gpointer | |||
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx, | |||
struct rspamd_config *cfg, | |||
struct rspamd_statfile *st, | |||
const ucl_object_t *cf) | |||
{ | |||
struct rspamd_redis_cache_ctx *cache_ctx; | |||
struct rspamd_statfile_config *stf = st->stcf; | |||
const ucl_object_t *obj; | |||
gboolean ret = FALSE; | |||
lua_State *L = (lua_State *) cfg->lua_state; | |||
gint conf_ref = -1; | |||
cache_ctx = g_malloc0(sizeof(*cache_ctx)); | |||
cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT; | |||
cache_ctx->L = L; | |||
/* First search in backend configuration */ | |||
obj = ucl_object_lookup(st->classifier->cfg->opts, "backend"); | |||
if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) { | |||
ret = rspamd_lua_try_load_redis(L, obj, cfg, &conf_ref); | |||
} | |||
/* Now try statfiles config */ | |||
if (!ret && stf->opts) { | |||
ret = rspamd_lua_try_load_redis(L, stf->opts, cfg, &conf_ref); | |||
} | |||
/* Now try classifier config */ | |||
if (!ret && st->classifier->cfg->opts) { | |||
ret = rspamd_lua_try_load_redis(L, st->classifier->cfg->opts, cfg, &conf_ref); | |||
} | |||
/* Now try global redis settings */ | |||
if (!ret) { | |||
obj = ucl_object_lookup(cfg->cfg_ucl_obj, "redis"); | |||
if (obj) { | |||
const ucl_object_t *specific_obj; | |||
specific_obj = ucl_object_lookup(obj, "statistics"); | |||
if (specific_obj) { | |||
ret = rspamd_lua_try_load_redis(L, | |||
specific_obj, cfg, &conf_ref); | |||
} | |||
else { | |||
ret = rspamd_lua_try_load_redis(L, | |||
obj, cfg, &conf_ref); | |||
} | |||
} | |||
} | |||
if (!ret) { | |||
msg_err_config("cannot init redis cache for %s", stf->symbol); | |||
g_free(cache_ctx); | |||
return NULL; | |||
} | |||
obj = ucl_object_lookup(st->classifier->cfg->opts, "cache_key"); | |||
if (obj) { | |||
cache_ctx->redis_object = ucl_object_tostring(obj); | |||
} | |||
else { | |||
cache_ctx->redis_object = DEFAULT_REDIS_KEY; | |||
} | |||
cache_ctx->conf_ref = conf_ref; | |||
/* Check some common table values */ | |||
lua_rawgeti(L, LUA_REGISTRYINDEX, conf_ref); | |||
lua_pushstring(L, "timeout"); | |||
lua_gettable(L, -2); | |||
if (lua_type(L, -1) == LUA_TNUMBER) { | |||
cache_ctx->timeout = lua_tonumber(L, -1); | |||
} | |||
lua_pop(L, 1); | |||
lua_pushstring(L, "db"); | |||
lua_gettable(L, -2); | |||
if (lua_type(L, -1) == LUA_TSTRING) { | |||
cache_ctx->dbname = rspamd_mempool_strdup(cfg->cfg_pool, | |||
lua_tostring(L, -1)); | |||
} | |||
lua_pop(L, 1); | |||
lua_pushstring(L, "username"); | |||
lua_gettable(L, -2); | |||
if (lua_type(L, -1) == LUA_TSTRING) { | |||
cache_ctx->username = rspamd_mempool_strdup(cfg->cfg_pool, | |||
lua_tostring(L, -1)); | |||
} | |||
lua_pop(L, 1); | |||
lua_pushstring(L, "password"); | |||
lua_gettable(L, -2); | |||
if (lua_type(L, -1) == LUA_TSTRING) { | |||
cache_ctx->password = rspamd_mempool_strdup(cfg->cfg_pool, | |||
lua_tostring(L, -1)); | |||
} | |||
lua_pop(L, 1); | |||
lua_settop(L, 0); | |||
cache_ctx->stcf = stf; | |||
return (gpointer) cache_ctx; | |||
} | |||
gpointer | |||
rspamd_stat_cache_redis_runtime(struct rspamd_task *task, | |||
gpointer c, gboolean learn) | |||
{ | |||
struct rspamd_redis_cache_ctx *ctx = c; | |||
struct rspamd_redis_cache_runtime *rt; | |||
struct upstream *up; | |||
struct upstream_list *ups; | |||
rspamd_inet_addr_t *addr; | |||
g_assert(ctx != NULL); | |||
if (task->tokens == NULL || task->tokens->len == 0) { | |||
return NULL; | |||
} | |||
if (learn) { | |||
ups = rspamd_redis_get_servers(ctx, "write_servers"); | |||
if (!ups) { | |||
msg_err_task("no write servers defined for %s, cannot learn", | |||
ctx->stcf->symbol); | |||
return NULL; | |||
} | |||
up = rspamd_upstream_get(ups, | |||
RSPAMD_UPSTREAM_MASTER_SLAVE, | |||
NULL, | |||
0); | |||
} | |||
else { | |||
ups = rspamd_redis_get_servers(ctx, "read_servers"); | |||
if (!ups) { | |||
msg_err_task("no read servers defined for %s, cannot check", | |||
ctx->stcf->symbol); | |||
return NULL; | |||
} | |||
up = rspamd_upstream_get(ups, | |||
RSPAMD_UPSTREAM_ROUND_ROBIN, | |||
NULL, | |||
0); | |||
} | |||
if (up == NULL) { | |||
msg_err_task("no upstreams reachable"); | |||
return NULL; | |||
} | |||
rt = rspamd_mempool_alloc0(task->task_pool, sizeof(*rt)); | |||
rt->selected = up; | |||
rt->task = task; | |||
rt->ctx = ctx; | |||
addr = rspamd_upstream_addr_next(up); | |||
g_assert(addr != NULL); | |||
if (rspamd_inet_address_get_af(addr) == AF_UNIX) { | |||
rt->redis = redisAsyncConnectUnix(rspamd_inet_address_to_string(addr)); | |||
} | |||
else { | |||
rt->redis = redisAsyncConnect(rspamd_inet_address_to_string(addr), | |||
rspamd_inet_address_get_port(addr)); | |||
} | |||
if (rt->redis == NULL) { | |||
msg_warn_task("cannot connect to redis server %s: %s", | |||
rspamd_inet_address_to_string_pretty(addr), | |||
strerror(errno)); | |||
return NULL; | |||
} | |||
else if (rt->redis->err != REDIS_OK) { | |||
msg_warn_task("cannot connect to redis server %s: %s", | |||
rspamd_inet_address_to_string_pretty(addr), | |||
rt->redis->errstr); | |||
redisAsyncFree(rt->redis); | |||
rt->redis = NULL; | |||
return NULL; | |||
} | |||
redisLibevAttach(task->event_loop, rt->redis); | |||
/* Now check stats */ | |||
rt->timer_ev.data = rt; | |||
ev_timer_init(&rt->timer_ev, rspamd_redis_cache_timeout, | |||
rt->ctx->timeout, 0.0); | |||
rspamd_redis_cache_maybe_auth(ctx, rt->redis); | |||
if (!learn) { | |||
rspamd_stat_cache_redis_generate_id(task); | |||
} | |||
return rt; | |||
} | |||
gint rspamd_stat_cache_redis_check(struct rspamd_task *task, | |||
gboolean is_spam, | |||
gpointer runtime) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = runtime; | |||
gchar *h; | |||
if (rspamd_session_blocked(task->s)) { | |||
return RSPAMD_LEARN_IGNORE; | |||
} | |||
h = rspamd_mempool_get_variable(task->task_pool, "words_hash"); | |||
if (h == NULL) { | |||
return RSPAMD_LEARN_IGNORE; | |||
} | |||
if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_get, rt, | |||
"HGET %s %s", | |||
rt->ctx->redis_object, h) == REDIS_OK) { | |||
rspamd_session_add_event(task->s, | |||
rspamd_redis_cache_fin, | |||
rt, | |||
M); | |||
ev_timer_start(rt->task->event_loop, &rt->timer_ev); | |||
rt->has_event = TRUE; | |||
} | |||
/* We need to return OK every time */ | |||
return RSPAMD_LEARN_OK; | |||
} | |||
gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, | |||
gboolean is_spam, | |||
gpointer runtime) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = runtime; | |||
gchar *h; | |||
gint flag; | |||
if (rt == NULL || rt->ctx == NULL || rspamd_session_blocked(task->s)) { | |||
return RSPAMD_LEARN_IGNORE; | |||
} | |||
h = rspamd_mempool_get_variable(task->task_pool, "words_hash"); | |||
g_assert(h != NULL); | |||
flag = (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? 1 : -1; | |||
if (redisAsyncCommand(rt->redis, rspamd_stat_cache_redis_set, rt, | |||
"HSET %s %s %d", | |||
rt->ctx->redis_object, h, flag) == REDIS_OK) { | |||
rspamd_session_add_event(task->s, | |||
rspamd_redis_cache_fin, rt, M); | |||
ev_timer_start(rt->task->event_loop, &rt->timer_ev); | |||
rt->has_event = TRUE; | |||
} | |||
/* We need to return OK every time */ | |||
return RSPAMD_LEARN_OK; | |||
} | |||
void rspamd_stat_cache_redis_close(gpointer c) | |||
{ | |||
struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *) c; | |||
lua_State *L; | |||
L = ctx->L; | |||
if (ctx->conf_ref) { | |||
luaL_unref(L, LUA_REGISTRYINDEX, ctx->conf_ref); | |||
} | |||
g_free(ctx); | |||
} |
@@ -0,0 +1,315 @@ | |||
/* | |||
* Copyright 2024 Vsevolod Stakhov | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "config.h" | |||
// Include early to avoid `extern "C"` issues | |||
#include "lua/lua_common.h" | |||
#include "learn_cache.h" | |||
#include "rspamd.h" | |||
#include "stat_api.h" | |||
#include "stat_internal.h" | |||
#include "cryptobox.h" | |||
#include "ucl.h" | |||
#include "libmime/message.h" | |||
#define DEFAULT_REDIS_KEY "learned_ids" | |||
struct rspamd_redis_cache_ctx { | |||
lua_State *L; | |||
struct rspamd_statfile_config *stcf; | |||
std::string redis_object = DEFAULT_REDIS_KEY; | |||
int check_ref = -1; | |||
int learn_ref = -1; | |||
rspamd_redis_cache_ctx() = delete; | |||
explicit rspamd_redis_cache_ctx(lua_State *L) | |||
: L(L) | |||
{ | |||
} | |||
~rspamd_redis_cache_ctx() | |||
{ | |||
if (check_ref != -1) { | |||
luaL_unref(L, LUA_REGISTRYINDEX, check_ref); | |||
} | |||
if (learn_ref != -1) { | |||
luaL_unref(L, LUA_REGISTRYINDEX, learn_ref); | |||
} | |||
} | |||
}; | |||
#if 0 | |||
/* Called when we have checked the specified message id */ | |||
static void | |||
rspamd_stat_cache_redis_get(redisAsyncContext *c, gpointer r, gpointer priv) | |||
{ | |||
struct rspamd_redis_cache_runtime *rt = priv; | |||
redisReply *reply = r; | |||
struct rspamd_task *task; | |||
glong val = 0; | |||
task = rt->task; | |||
if (c->err == 0) { | |||
if (reply) { | |||
if (G_LIKELY(reply->type == REDIS_REPLY_INTEGER)) { | |||
val = reply->integer; | |||
} | |||
else if (reply->type == REDIS_REPLY_STRING) { | |||
rspamd_strtol(reply->str, reply->len, &val); | |||
} | |||
else { | |||
if (reply->type == REDIS_REPLY_ERROR) { | |||
msg_err_task("cannot learn %s: redis error: \"%s\"", | |||
rt->ctx->stcf->symbol, reply->str); | |||
} | |||
else if (reply->type != REDIS_REPLY_NIL) { | |||
msg_err_task("bad learned type for %s: %d", | |||
rt->ctx->stcf->symbol, reply->type); | |||
} | |||
val = 0; | |||
} | |||
} | |||
if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || | |||
(val < 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { | |||
/* Already learned */ | |||
msg_info_task("<%s> has been already " | |||
"learned as %s, ignore it", | |||
MESSAGE_FIELD(task, message_id), | |||
(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); | |||
task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; | |||
} | |||
else if (val != 0) { | |||
/* Unlearn flag */ | |||
task->flags |= RSPAMD_TASK_FLAG_UNLEARN; | |||
} | |||
rspamd_upstream_ok(rt->selected); | |||
} | |||
else { | |||
rspamd_upstream_fail(rt->selected, FALSE, c->errstr); | |||
} | |||
if (rt->has_event) { | |||
rspamd_session_remove_event(task->s, rspamd_redis_cache_fin, rt); | |||
} | |||
} | |||
#endif | |||
static void | |||
rspamd_stat_cache_redis_generate_id(struct rspamd_task *task) | |||
{ | |||
rspamd_cryptobox_hash_state_t st; | |||
rspamd_cryptobox_hash_init(&st, NULL, 0); | |||
const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user"); | |||
/* Use dedicated hash space for per users cache */ | |||
if (user != NULL) { | |||
rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user)); | |||
} | |||
for (auto i = 0; i < task->tokens->len; i++) { | |||
const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i); | |||
rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data, | |||
sizeof(tok->data)); | |||
} | |||
guchar out[rspamd_cryptobox_HASHBYTES]; | |||
rspamd_cryptobox_hash_final(&st, out); | |||
auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool, | |||
sizeof(out) * 8 / 5 + 3, char); | |||
auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out, | |||
sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT); | |||
if (out_sz > 0) { | |||
/* Zero terminate */ | |||
b32out[out_sz] = '\0'; | |||
rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, NULL); | |||
} | |||
} | |||
/*
 * Creates the learn cache context by calling lua_bayes_redis.lua_bayes_init_cache
 * with the classifier and statfile options and keeping registry references
 * to the returned check and learn functors.
 *
 * Returns the context or nullptr on failure (the function cannot be required,
 * the call fails, or it does not return two functions).
 */
gpointer
rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
							 struct rspamd_config *cfg,
							 struct rspamd_statfile *st,
							 const ucl_object_t *cf)
{
	auto *L = RSPAMD_LUA_CFG_STATE(cfg);
	/* Owned here until the very end, so every early return releases it */
	auto cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(L);
	cache_ctx->stcf = st->stcf;

	lua_settop(L, 0);

	lua_pushcfunction(L, &rspamd_lua_traceback);
	auto err_idx = lua_gettop(L);

	/* Obtain function */
	if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
		msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	/* Push arguments */
	ucl_object_push_lua(L, st->classifier->cfg->opts, false);
	ucl_object_push_lua(L, st->stcf->opts, false);

	if (lua_pcall(L, 2, 2, err_idx) != 0) {
		msg_err("call to lua_bayes_init_cache "
				"script failed: %s",
				lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	/*
	 * Results are in the stack:
	 * top - 1 - check function (idx = -2)
	 * top - learn function (idx = -1)
	 *
	 * The Lua side returns nil when redis parameters cannot be loaded, so
	 * both results must be validated before they are referenced.
	 */
	if (!lua_isfunction(L, -2) || !lua_isfunction(L, -1)) {
		msg_err_config("cannot init redis cache for %s: "
					   "lua_bayes_init_cache has not returned check and learn functions",
					   st->stcf->symbol);
		lua_settop(L, err_idx - 1);

		return nullptr;
	}

	lua_pushvalue(L, -2);
	cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);

	lua_pushvalue(L, -1);
	cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);

	lua_settop(L, err_idx - 1);

	return (gpointer) cache_ctx.release();
}
/*
 * Returns the per-task runtime, which is the cache context itself, or NULL
 * if the task has no tokens. When called for a check (learn == FALSE) it
 * also generates the "words_hash" pool variable.
 */
gpointer
rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
								gpointer c, gboolean learn)
{
	const bool has_tokens = (task->tokens != NULL) && (task->tokens->len != 0);

	if (has_tokens) {
		if (!learn) {
			/* On check, we produce words_hash variable, on learn it is guaranteed to be set */
			rspamd_stat_cache_redis_generate_id(task);
		}

		return (void *) ((struct rspamd_redis_cache_ctx *) c);
	}

	return NULL;
}
/*
 * Lua callback for the cache check functor, called as
 * (task, found, value_or_error).
 *
 * When the id is found, the stored value tells the class it was learned as:
 * a positive value means spam, anything else means ham (the learn functor
 * stores "1" or "0"). Learning the same class again marks the task as
 * already learned; learning the opposite class requests an unlearn.
 */
static gint
rspamd_stat_cache_checked(lua_State *L)
{
	auto *task = lua_check_task(L, 1);

	if (task == nullptr) {
		/* Nothing can be flagged without a task */
		return 0;
	}

	auto res = lua_toboolean(L, 2);

	if (res) {
		auto val = lua_tointeger(L, 3);

		if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
			(val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
			/* Already learned */
			msg_info_task("<%s> has been already "
						  "learned as %s, ignore it",
						  MESSAGE_FIELD(task, message_id),
						  (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
			task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
		}
		else {
			/*
			 * Unlearn flag: the id is cached under another class. Ham is
			 * stored as 0, so the value itself must not be tested for
			 * being non-zero here.
			 */
			task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
		}
	}

	/* Ignore errors for now, as we can do nothing about them at the moment */
	return 0;
}
/*
 * Asks the Lua check functor whether the task's "words_hash" id is already
 * in the cache; the result is delivered to rspamd_stat_cache_checked.
 *
 * Returns RSPAMD_LEARN_IGNORE if there is no runtime, no id or the Lua call
 * fails, RSPAMD_LEARN_OK otherwise.
 */
gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;

	if (ctx == nullptr) {
		/* Runtime creation returns NULL for tasks without tokens */
		return RSPAMD_LEARN_IGNORE;
	}

	auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");

	if (h == NULL) {
		return RSPAMD_LEARN_IGNORE;
	}

	auto *L = ctx->L;

	lua_pushcfunction(L, &rspamd_lua_traceback);
	gint err_idx = lua_gettop(L);

	/* Function arguments */
	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
	rspamd_lua_task_push(L, task);
	lua_pushstring(L, h);

	lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);

	if (lua_pcall(L, 3, 0, err_idx) != 0) {
		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);
		return RSPAMD_LEARN_IGNORE;
	}

	/* Remove the traceback handler, otherwise it stays on the stack after each call */
	lua_settop(L, err_idx - 1);

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}
/*
 * Passes the task's "words_hash" id and the class being learned to the Lua
 * learn functor.
 *
 * Returns RSPAMD_LEARN_IGNORE if there is no runtime, the session is
 * blocked, no id is available or the Lua call fails, RSPAMD_LEARN_OK
 * otherwise.
 */
gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
								   gboolean is_spam,
								   gpointer runtime)
{
	auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;

	/* Runtime creation returns NULL for tasks without tokens */
	if (ctx == nullptr || rspamd_session_blocked(task->s)) {
		return RSPAMD_LEARN_IGNORE;
	}

	auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");

	if (h == NULL) {
		/* The id is not set when its generation fails; skip instead of aborting */
		return RSPAMD_LEARN_IGNORE;
	}

	auto *L = ctx->L;

	lua_pushcfunction(L, &rspamd_lua_traceback);
	gint err_idx = lua_gettop(L);

	/* Function arguments */
	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
	rspamd_lua_task_push(L, task);
	lua_pushstring(L, h);
	lua_pushboolean(L, is_spam);

	if (lua_pcall(L, 3, 0, err_idx) != 0) {
		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
		lua_settop(L, err_idx - 1);
		return RSPAMD_LEARN_IGNORE;
	}

	/* Remove the traceback handler, otherwise it stays on the stack after each call */
	lua_settop(L, err_idx - 1);

	/* We need to return OK every time */
	return RSPAMD_LEARN_OK;
}
/*
 * Destroys the cache context; its destructor releases the Lua registry
 * references it holds.
 */
void rspamd_stat_cache_redis_close(gpointer c)
{
	delete static_cast<struct rspamd_redis_cache_ctx *>(c);
}