123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- /*
- * Copyright 2024 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "config.h"
- // Include early to avoid `extern "C"` issues
- #include "lua/lua_common.h"
- #include "learn_cache.h"
- #include "rspamd.h"
- #include "stat_api.h"
- #include "stat_internal.h"
- #include "cryptobox.h"
- #include "ucl.h"
- #include "libmime/message.h"
-
- #include <memory>
-
- struct rspamd_redis_cache_ctx {
- lua_State *L;
- struct rspamd_statfile_config *stcf;
- int check_ref = -1;
- int learn_ref = -1;
-
- rspamd_redis_cache_ctx() = delete;
- explicit rspamd_redis_cache_ctx(lua_State *L)
- : L(L)
- {
- }
-
- ~rspamd_redis_cache_ctx()
- {
- if (check_ref != -1) {
- luaL_unref(L, LUA_REGISTRYINDEX, check_ref);
- }
-
- if (learn_ref != -1) {
- luaL_unref(L, LUA_REGISTRYINDEX, learn_ref);
- }
- }
- };
-
- static void
- rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
- {
- rspamd_cryptobox_hash_state_t st;
- rspamd_cryptobox_hash_init(&st, nullptr, 0);
-
- const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user");
- /* Use dedicated hash space for per users cache */
- if (user != nullptr) {
- rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user));
- }
-
- for (auto i = 0; i < task->tokens->len; i++) {
- const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i);
- rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data,
- sizeof(tok->data));
- }
-
- unsigned char out[rspamd_cryptobox_HASHBYTES];
- rspamd_cryptobox_hash_final(&st, out);
-
- auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool,
- sizeof(out) * 8 / 5 + 3, char);
- auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out,
- sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
-
- if (out_sz > 0) {
- /* Zero terminate */
- b32out[out_sz] = '\0';
- rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, nullptr);
- }
- }
-
- gpointer
- rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
- struct rspamd_config *cfg,
- struct rspamd_statfile *st,
- const ucl_object_t *cf)
- {
- std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg));
-
- auto *L = RSPAMD_LUA_CFG_STATE(cfg);
- lua_settop(L, 0);
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- auto err_idx = lua_gettop(L);
-
- /* Obtain function */
- if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
- msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
- lua_settop(L, err_idx - 1);
-
- return nullptr;
- }
-
- /* Push arguments */
- ucl_object_push_lua(L, st->classifier->cfg->opts, false);
- ucl_object_push_lua(L, st->stcf->opts, false);
-
- if (lua_pcall(L, 2, 2, err_idx) != 0) {
- msg_err("call to lua_bayes_init_cache "
- "script failed: %s",
- lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
-
- return nullptr;
- }
-
- /*
- * Results are in the stack:
- * top - 1 - check function (idx = -2)
- * top - learn function (idx = -1)
- */
- lua_pushvalue(L, -2);
- cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
-
- lua_pushvalue(L, -1);
- cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
-
- lua_settop(L, err_idx - 1);
-
- return (gpointer) cache_ctx.release();
- }
-
- gpointer
- rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
- gpointer c, gboolean learn)
- {
- auto *ctx = (struct rspamd_redis_cache_ctx *) c;
-
- if (task->tokens == nullptr || task->tokens->len == 0) {
- return nullptr;
- }
-
- if (!learn) {
- /* On check, we produce words_hash variable, on learn it is guaranteed to be set */
- rspamd_stat_cache_redis_generate_id(task);
- }
-
- return (void *) ctx;
- }
-
- static int
- rspamd_stat_cache_checked(lua_State *L)
- {
- auto *task = lua_check_task(L, 1);
- auto res = lua_toboolean(L, 2);
-
- if (res) {
- auto val = lua_tointeger(L, 3);
-
- if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
- (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
- /* Already learned */
- msg_info_task("<%s> has been already "
- "learned as %s, ignore it",
- MESSAGE_FIELD(task, message_id),
- (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
- task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
- }
- else {
- /* Unlearn flag */
- task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
- }
- }
-
- /* Ignore errors for now, as we can do nothing about them at the moment */
-
- return 0;
- }
-
- int rspamd_stat_cache_redis_check(struct rspamd_task *task,
- gboolean is_spam,
- gpointer runtime)
- {
- auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
- auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
-
- if (h == nullptr) {
- return RSPAMD_LEARN_IGNORE;
- }
-
- auto *L = ctx->L;
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- int err_idx = lua_gettop(L);
-
- /* Function arguments */
- lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
- rspamd_lua_task_push(L, task);
- lua_pushstring(L, h);
-
- lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);
-
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
- msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
- return RSPAMD_LEARN_IGNORE;
- }
-
- /* We need to return OK every time */
- return RSPAMD_LEARN_OK;
- }
-
- int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
- gboolean is_spam,
- gpointer runtime)
- {
- auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
-
- if (rspamd_session_blocked(task->s)) {
- return RSPAMD_LEARN_IGNORE;
- }
-
- auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
- g_assert(h != nullptr);
- auto *L = ctx->L;
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- int err_idx = lua_gettop(L);
-
- /* Function arguments */
- lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
- rspamd_lua_task_push(L, task);
- lua_pushstring(L, h);
- lua_pushboolean(L, is_spam);
-
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
- msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
- return RSPAMD_LEARN_IGNORE;
- }
-
- /* We need to return OK every time */
- return RSPAMD_LEARN_OK;
- }
-
- void rspamd_stat_cache_redis_close(gpointer c)
- {
- auto *ctx = (struct rspamd_redis_cache_ctx *) c;
- delete ctx;
- }
|