12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139 |
- /*
- * Copyright 2024 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "config.h"
- #include "lua/lua_common.h"
- #include "rspamd.h"
- #include "stat_internal.h"
- #include "upstream.h"
- #include "libserver/mempool_vars_internal.h"
- #include "fmt/core.h"
-
- #include "libutil/cxx/error.hxx"
-
- #include <string>
- #include <cstdint>
- #include <vector>
- #include <optional>
-
- #define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
- rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
- RSPAMD_LOG_FUNC, \
- __VA_ARGS__)
-
- INIT_LOG_MODULE(stat_redis)
-
- #define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
- #define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
- #define REDIS_DEFAULT_OBJECT "%s%l"
- #define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
- #define REDIS_DEFAULT_TIMEOUT 0.5
- #define REDIS_STAT_TIMEOUT 30
- #define REDIS_MAX_USERS 1000
-
- struct redis_stat_ctx {
- lua_State *L;
- struct rspamd_statfile_config *stcf;
- const char *redis_object = REDIS_DEFAULT_OBJECT;
- bool enable_users = false;
- bool store_tokens = false;
- bool enable_signatures = false;
- int cbref_user = -1;
-
- int cbref_classify = -1;
- int cbref_learn = -1;
-
- ucl_object_t *cur_stat = nullptr;
-
- explicit redis_stat_ctx(lua_State *_L)
- : L(_L)
- {
- }
-
- ~redis_stat_ctx()
- {
- if (cbref_user != -1) {
- luaL_unref(L, LUA_REGISTRYINDEX, cbref_user);
- }
-
- if (cbref_classify != -1) {
- luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify);
- }
-
- if (cbref_learn != -1) {
- luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn);
- }
- }
- };
-
-
- template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
- struct redis_stat_runtime {
- struct redis_stat_ctx *ctx;
- struct rspamd_task *task;
- struct rspamd_statfile_config *stcf;
- GPtrArray *tokens = nullptr;
- const char *redis_object_expanded;
- std::uint64_t learned = 0;
- int id;
- std::vector<std::pair<int, T>> *results = nullptr;
- bool need_redis_call = true;
- std::optional<rspamd::util::error> err;
-
- using result_type = std::vector<std::pair<int, T>>;
-
- private:
- /* Called on connection termination */
- static void rt_dtor(gpointer data)
- {
- auto *rt = REDIS_RUNTIME(data);
-
- delete rt;
- }
-
- /* Avoid occasional deletion */
- ~redis_stat_runtime()
- {
- if (tokens) {
- g_ptr_array_unref(tokens);
- }
-
- delete results;
- }
-
- public:
- explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
- : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
- {
- rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
- }
-
- static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
- bool is_spam) -> std::optional<redis_stat_runtime<T> *>
- {
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
- auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
-
- if (res) {
- msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str());
- return reinterpret_cast<redis_stat_runtime<T> *>(res);
- }
- else {
- msg_debug_bayes("no runtime at %s", var_name.c_str());
- return std::nullopt;
- }
- }
-
- void set_results(std::vector<std::pair<int, T>> *results)
- {
- this->results = results;
- }
-
- /* Propagate results from internal representation to the tokens array */
- auto process_tokens(GPtrArray *tokens) const -> bool
- {
- rspamd_token_t *tok;
-
- if (!results) {
- return false;
- }
-
- for (auto [idx, val]: *results) {
- tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
- tok->values[id] = val;
- }
-
- return true;
- }
-
- auto save_in_mempool(bool is_spam) const
- {
- auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
- /* We do not set destructor for the variable, as it should be already added on creation */
- rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
- msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
- }
- };
-
- #define GET_TASK_ELT(task, elt) (task == nullptr ? nullptr : (task)->elt)
-
- static const gchar *M = "redis statistics";
-
- static GQuark
- rspamd_redis_stat_quark(void)
- {
- return g_quark_from_static_string(M);
- }
-
- /*
- * Non-static for lua unit testing
- */
- gsize rspamd_redis_expand_object(const gchar *pattern,
- struct redis_stat_ctx *ctx,
- struct rspamd_task *task,
- gchar **target)
- {
- gsize tlen = 0;
- const gchar *p = pattern, *elt;
- gchar *d, *end;
- enum {
- just_char,
- percent_char,
- mod_char
- } state = just_char;
- struct rspamd_statfile_config *stcf;
- lua_State *L = nullptr;
- struct rspamd_task **ptask;
- const gchar *rcpt = nullptr;
- gint err_idx;
-
- g_assert(ctx != nullptr);
- g_assert(task != nullptr);
- stcf = ctx->stcf;
-
- L = RSPAMD_LUA_CFG_STATE(task->cfg);
- g_assert(L != nullptr);
-
- if (ctx->enable_users) {
- if (ctx->cbref_user == -1) {
- rcpt = rspamd_task_get_principal_recipient(task);
- }
- else {
- /* Execute lua function to get userdata */
- lua_pushcfunction(L, &rspamd_lua_traceback);
- err_idx = lua_gettop(L);
-
- lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user);
- ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *));
- *ptask = task;
- rspamd_lua_setclass(L, rspamd_task_classname, -1);
-
- if (lua_pcall(L, 1, 1, err_idx) != 0) {
- msg_err_task("call to user extraction script failed: %s",
- lua_tostring(L, -1));
- }
- else {
- rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1));
- }
-
- /* Result + error function */
- lua_settop(L, err_idx - 1);
- }
-
- if (rcpt) {
- rspamd_mempool_set_variable(task->task_pool, "stat_user",
- (gpointer) rcpt, nullptr);
- }
- }
-
- /* Length calculation */
- while (*p) {
- switch (state) {
- case just_char:
- if (*p == '%') {
- state = percent_char;
- }
- else {
- tlen++;
- }
- p++;
- break;
- case percent_char:
- switch (*p) {
- case '%':
- tlen++;
- state = just_char;
- break;
- case 'u':
- elt = GET_TASK_ELT(task, auth_user);
- if (elt) {
- tlen += strlen(elt);
- }
- break;
- case 'r':
-
- if (rcpt == nullptr) {
- elt = rspamd_task_get_principal_recipient(task);
- }
- else {
- elt = rcpt;
- }
-
- if (elt) {
- tlen += strlen(elt);
- }
- break;
- case 'l':
- if (stcf->label) {
- tlen += strlen(stcf->label);
- }
- /* Label miss is OK */
- break;
- case 's':
- tlen += sizeof("RS") - 1;
- break;
- default:
- state = just_char;
- tlen++;
- break;
- }
-
- if (state == percent_char) {
- state = mod_char;
- }
- p++;
- break;
-
- case mod_char:
- switch (*p) {
- case 'd':
- p++;
- state = just_char;
- break;
- default:
- state = just_char;
- break;
- }
- break;
- }
- }
-
-
- if (target == nullptr) {
- return -1;
- }
-
- *target = (gchar *) rspamd_mempool_alloc(task->task_pool, tlen + 1);
- d = *target;
- end = d + tlen + 1;
- d[tlen] = '\0';
- p = pattern;
- state = just_char;
-
- /* Expand string */
- while (*p && d < end) {
- switch (state) {
- case just_char:
- if (*p == '%') {
- state = percent_char;
- }
- else {
- *d++ = *p;
- }
- p++;
- break;
- case percent_char:
- switch (*p) {
- case '%':
- *d++ = *p;
- state = just_char;
- break;
- case 'u':
- elt = GET_TASK_ELT(task, auth_user);
- if (elt) {
- d += rspamd_strlcpy(d, elt, end - d);
- }
- break;
- case 'r':
- if (rcpt == nullptr) {
- elt = rspamd_task_get_principal_recipient(task);
- }
- else {
- elt = rcpt;
- }
-
- if (elt) {
- d += rspamd_strlcpy(d, elt, end - d);
- }
- break;
- case 'l':
- if (stcf->label) {
- d += rspamd_strlcpy(d, stcf->label, end - d);
- }
- break;
- case 's':
- d += rspamd_strlcpy(d, "RS", end - d);
- break;
- default:
- state = just_char;
- *d++ = *p;
- break;
- }
-
- if (state == percent_char) {
- state = mod_char;
- }
- p++;
- break;
-
- case mod_char:
- switch (*p) {
- case 'd':
- /* TODO: not supported yet */
- p++;
- state = just_char;
- break;
- default:
- state = just_char;
- break;
- }
- break;
- }
- }
-
- return tlen;
- }
-
- static int
- rspamd_redis_stat_cb(lua_State *L)
- {
- const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
- auto *cfg = lua_check_config(L, 1);
- auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie));
-
- if (backend == nullptr) {
- msg_err("internal error: cookie %s is not found", cookie);
-
- return 0;
- }
-
- auto *cur_obj = ucl_object_lua_import(L, 2);
- msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol);
- /* Enrich with some default values that are meaningless for redis */
- ucl_object_insert_key(cur_obj,
- ucl_object_typed_new(UCL_INT), "used", 0, false);
- ucl_object_insert_key(cur_obj,
- ucl_object_typed_new(UCL_INT), "total", 0, false);
- ucl_object_insert_key(cur_obj,
- ucl_object_typed_new(UCL_INT), "size", 0, false);
- ucl_object_insert_key(cur_obj,
- ucl_object_fromstring(backend->stcf->symbol),
- "symbol", 0, false);
- ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"),
- "type", 0, false);
- ucl_object_insert_key(cur_obj, ucl_object_fromint(0),
- "languages", 0, false);
-
- if (backend->cur_stat) {
- ucl_object_unref(backend->cur_stat);
- }
-
- backend->cur_stat = cur_obj;
-
- return 0;
- }
-
- static void
- rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
- const ucl_object_t *statfile_obj,
- const ucl_object_t *classifier_obj,
- struct rspamd_config *cfg)
- {
- const gchar *lua_script;
- const ucl_object_t *elt, *users_enabled;
- auto *L = RSPAMD_LUA_CFG_STATE(cfg);
-
- users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
- "users_enabled", nullptr);
-
- if (users_enabled != nullptr) {
- if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
- backend->enable_users = ucl_object_toboolean(users_enabled);
- backend->cbref_user = -1;
- }
- else if (ucl_object_type(users_enabled) == UCL_STRING) {
- lua_script = ucl_object_tostring(users_enabled);
-
- if (luaL_dostring(L, lua_script) != 0) {
- msg_err_config("cannot execute lua script for users "
- "extraction: %s",
- lua_tostring(L, -1));
- }
- else {
- if (lua_type(L, -1) == LUA_TFUNCTION) {
- backend->enable_users = TRUE;
- backend->cbref_user = luaL_ref(L,
- LUA_REGISTRYINDEX);
- }
- else {
- msg_err_config("lua script must return "
- "function(task) and not %s",
- lua_typename(L, lua_type(L, -1)));
- }
- }
- }
- }
- else {
- backend->enable_users = FALSE;
- backend->cbref_user = -1;
- }
-
- elt = ucl_object_lookup(classifier_obj, "prefix");
- if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) {
- /* Default non-users statistics */
- if (backend->enable_users || backend->cbref_user != -1) {
- backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
- }
- else {
- backend->redis_object = REDIS_DEFAULT_OBJECT;
- }
- }
- else {
- /* XXX: sanity check */
- backend->redis_object = ucl_object_tostring(elt);
- }
-
- elt = ucl_object_lookup(classifier_obj, "store_tokens");
- if (elt) {
- backend->store_tokens = ucl_object_toboolean(elt);
- }
- else {
- backend->store_tokens = FALSE;
- }
-
- elt = ucl_object_lookup(classifier_obj, "signatures");
- if (elt) {
- backend->enable_signatures = ucl_object_toboolean(elt);
- }
- else {
- backend->enable_signatures = FALSE;
- }
- }
-
- gpointer
- rspamd_redis_init(struct rspamd_stat_ctx *ctx,
- struct rspamd_config *cfg, struct rspamd_statfile *st)
- {
- auto *L = RSPAMD_LUA_CFG_STATE(cfg);
-
- auto backend = std::make_unique<struct redis_stat_ctx>(L);
- lua_settop(L, 0);
-
- rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg);
-
- st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
- backend->stcf = st->stcf;
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- auto err_idx = lua_gettop(L);
-
- /* Obtain function */
- if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) {
- msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile");
- lua_settop(L, err_idx - 1);
-
- return nullptr;
- }
-
- /* Push arguments */
- ucl_object_push_lua(L, st->classifier->cfg->opts, false);
- ucl_object_push_lua(L, st->stcf->opts, false);
- lua_pushstring(L, backend->stcf->symbol);
- lua_pushboolean(L, backend->stcf->is_spam);
-
- /* Push event loop if there is one available (e.g. we are not in rspamadm mode) */
- if (ctx->event_loop) {
- auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *));
- *pev_base = ctx->event_loop;
- rspamd_lua_setclass(L, "rspamd{ev_base}", -1);
- }
- else {
- lua_pushnil(L);
- }
-
- /* Store backend in random cookie */
- char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
- rspamd_random_hex(cookie, 16);
- cookie[15] = '\0';
- rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
- /* Callback + 1 upvalue */
- lua_pushstring(L, cookie);
- lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
-
- if (lua_pcall(L, 6, 2, err_idx) != 0) {
- msg_err("call to lua_bayes_init_classifier "
- "script failed: %s",
- lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
-
- return nullptr;
- }
-
- /* Results are in the stack:
- * top - 1 - classifier function (idx = -2)
- * top - learn function (idx = -1)
- */
-
- lua_pushvalue(L, -2);
- backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX);
-
- lua_pushvalue(L, -1);
- backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX);
-
- lua_settop(L, err_idx - 1);
-
- return backend.release();
- }
-
- gpointer
- rspamd_redis_runtime(struct rspamd_task *task,
- struct rspamd_statfile_config *stcf,
- gboolean learn, gpointer c, gint _id)
- {
- struct redis_stat_ctx *ctx = REDIS_CTX(c);
- char *object_expanded = nullptr;
-
- g_assert(ctx != nullptr);
- g_assert(stcf != nullptr);
-
- if (rspamd_redis_expand_object(ctx->redis_object, ctx, task,
- &object_expanded) == 0) {
- msg_err_task("expansion for %s failed for symbol %s "
- "(maybe learning per user classifier with no user or recipient)",
- learn ? "learning" : "classifying",
- stcf->symbol);
- return nullptr;
- }
-
- /* Look for the cached results */
- if (!learn) {
- auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded, stcf->is_spam);
-
- if (maybe_existing) {
- auto *rt = maybe_existing.value();
- /* Update stcf and ctx to correspond to what we have been asked */
- rt->stcf = stcf;
- rt->ctx = ctx;
- return rt;
- }
- }
-
- /* No cached result (or learn), create new one */
- auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
-
- if (!learn) {
- /*
- * For check, we also need to create the opposite class runtime to avoid
- * double call for Redis scripts.
- * This runtime will be filled later.
- */
- auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- object_expanded,
- !stcf->is_spam);
-
- if (!maybe_opposite_rt) {
- auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
- opposite_rt->save_in_mempool(!stcf->is_spam);
- opposite_rt->need_redis_call = false;
- }
- }
-
- rt->save_in_mempool(stcf->is_spam);
-
- return rt;
- }
-
- void rspamd_redis_close(gpointer p)
- {
- struct redis_stat_ctx *ctx = REDIS_CTX(p);
- delete ctx;
- }
-
/*
 * Write `st` to `out` as a MessagePack string, choosing fixstr, str8, str16
 * or str32 by length (lengths are big-endian). `out` must have room for
 * msgpack_str_len(st.size()) bytes.
 * Returns the number of bytes written (header + payload).
 *
 * Not constexpr: memcpy cannot be used in constant expressions.
 */
static auto
msgpack_emit_str(const std::string_view st, char *out) -> std::size_t
{
	const std::size_t len = st.size();
	std::size_t hdr_len;

	if (len <= 0x1F) {
		out[0] = static_cast<char>(0xA0 | len);
		hdr_len = 1;
	}
	else if (len <= 0xFF) {
		out[0] = static_cast<char>(0xD9);
		out[1] = static_cast<char>(len);
		hdr_len = 2;
	}
	else if (len <= 0xFFFF) {
		out[0] = static_cast<char>(0xDA);
		out[1] = static_cast<char>((len >> 8) & 0xFF);
		out[2] = static_cast<char>(len & 0xFF);
		hdr_len = 3;
	}
	else {
		out[0] = static_cast<char>(0xDB);
		out[1] = static_cast<char>((len >> 24) & 0xFF);
		out[2] = static_cast<char>((len >> 16) & 0xFF);
		out[3] = static_cast<char>((len >> 8) & 0xFF);
		out[4] = static_cast<char>(len & 0xFF);
		hdr_len = 5;
	}

	/* An empty view may carry a null data pointer: memcpy(.., nullptr, 0) is UB */
	if (len > 0) {
		memcpy(out + hdr_len, st.data(), len);
	}

	return hdr_len + len;
}
-
/*
 * Number of bytes msgpack_emit_str() produces for a string of `len` bytes:
 * the header (1 for fixstr, 2 for str8, 3 for str16, 5 for str32) plus the
 * payload. Used to size serialisation buffers, so it must never under-count.
 */
static constexpr auto
msgpack_str_len(std::size_t len) -> std::size_t
{
	if (len <= 0x1F) {
		return 1 + len;
	}
	else if (len <= 0xff) {
		return 2 + len;
	}
	else if (len <= 0xffff) {
		return 3 + len;
	}
	else {
		/* str32: 1 byte marker + 4 bytes big-endian length */
		return 5 + len;
	}
}
-
- /*
- * Serialise stat tokens to message pack
- */
- static char *
- rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPtrArray *tokens, gsize *ser_len)
- {
- /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */
- char max_int64_str[] = "18446744073709551615";
- auto prefix_len = strlen(prefix);
- std::size_t req_len = 5;
- rspamd_token_t *tok;
-
- /* Calculate required length */
- req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1);
-
- auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
- auto *p = buf;
-
- /* Array */
- *p++ = (gchar) 0xdd;
- /* Length in big-endian (4 bytes) */
- *p++ = (gchar) ((tokens->len >> 24) & 0xff);
- *p++ = (gchar) ((tokens->len >> 16) & 0xff);
- *p++ = (gchar) ((tokens->len >> 8) & 0xff);
- *p++ = (gchar) (tokens->len & 0xff);
-
-
- int i;
- auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
- auto *numbuf = (char *) g_alloca(numbuf_len);
-
- PTR_ARRAY_FOREACH(tokens, i, tok)
- {
- std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
- auto shift = msgpack_emit_str({numbuf, r}, p);
- p += shift;
- }
-
- *ser_len = p - buf;
-
- return buf;
- }
-
- static char *
- rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
- {
- rspamd_token_t *tok;
- auto req_len = 5; /* Messagepack array prefix */
- int i;
-
- /*
- * First we need to determine the requested length
- */
- PTR_ARRAY_FOREACH(tokens, i, tok)
- {
- if (tok->t1 && tok->t2) {
- /* Two tokens */
- req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
- }
- else if (tok->t1) {
- req_len += msgpack_str_len(tok->t1->stemmed.len);
- req_len += 1; /* null */
- }
- else {
- req_len += 2; /* 2 nulls */
- }
- }
-
- auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
- auto *p = buf;
-
- /* Array */
- std::uint32_t nlen = tokens->len * 2;
- nlen = GUINT32_TO_BE(nlen);
- *p++ = (gchar) 0xdd;
- /* Length in big-endian (4 bytes) */
- memcpy(p, &nlen, sizeof(nlen));
- p += sizeof(nlen);
-
- PTR_ARRAY_FOREACH(tokens, i, tok)
- {
- if (tok->t1 && tok->t2) {
- auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
- p += step;
- step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
- p += step;
- }
- else if (tok->t1) {
- auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
- p += step;
- *p++ = 0xc0;
- }
- else {
- *p++ = 0xc0;
- *p++ = 0xc0;
- }
- }
-
- *ser_len = p - buf;
-
- return buf;
- }
-
- static gint
- rspamd_redis_classified(lua_State *L)
- {
- const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
- auto *task = lua_check_task(L, 1);
- auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
-
- if (rt == nullptr) {
- msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
- return 0;
- }
-
- bool result = lua_toboolean(L, 2);
-
- if (result) {
- /* Indexes:
- * 3 - learned_ham (int)
- * 4 - learned_spam (int)
- * 5 - ham_tokens (pair<int, int>)
- * 6 - spam_tokens (pair<int, int>)
- */
-
- /*
- * We need to fill our runtime AND the opposite runtime
- */
- auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
- rt->learned = learned;
- redis_stat_runtime<float>::result_type *res;
-
- res = new redis_stat_runtime<float>::result_type();
-
- for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
- lua_rawgeti(L, -1, 1);
- auto idx = lua_tointeger(L, -1);
- lua_pop(L, 1);
-
- lua_rawgeti(L, -1, 2);
- auto value = lua_tonumber(L, -1);
- lua_pop(L, 1);
-
- res->emplace_back(idx, value);
- }
-
- rt->set_results(res);
- };
-
- auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
- rt->redis_object_expanded,
- !rt->stcf->is_spam);
-
- if (!opposite_rt_maybe) {
- msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
-
- return 0;
- }
-
- if (rt->stcf->is_spam) {
- filler_func(rt, L, lua_tointeger(L, 4), 6);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
- }
- else {
- filler_func(rt, L, lua_tointeger(L, 3), 5);
- filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
- }
-
- /* Mark task as being processed */
- task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
-
- /* Process all tokens */
- g_assert(rt->tokens != nullptr);
- rt->process_tokens(rt->tokens);
- opposite_rt_maybe.value()->process_tokens(rt->tokens);
- }
- else {
- /* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot classify task: %s",
- err_msg);
- }
-
- return 0;
- }
-
- gboolean
- rspamd_redis_process_tokens(struct rspamd_task *task,
- GPtrArray *tokens,
- gint id, gpointer p)
- {
- auto *rt = REDIS_RUNTIME(p);
- auto *L = rt->ctx->L;
-
- if (rspamd_session_blocked(task->s)) {
- return FALSE;
- }
-
- if (tokens == nullptr || tokens->len == 0) {
- return FALSE;
- }
-
- if (!rt->need_redis_call) {
- /* No need to do anything, as it is already done in the opposite class processing */
- /* However, we need to store id as it is needed for further tokens processing */
- rt->id = id;
- rt->tokens = g_ptr_array_ref(tokens);
-
- return TRUE;
- }
-
- gsize tokens_len;
- gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
- rt->id = id;
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- gint err_idx = lua_gettop(L);
-
- /* Function arguments */
- lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify);
- rspamd_lua_task_push(L, task);
- lua_pushstring(L, rt->redis_object_expanded);
- lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
- lua_new_text(L, tokens_buf, tokens_len, false);
-
- /* Store rt in random cookie */
- char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
- rspamd_random_hex(cookie, 16);
- cookie[15] = '\0';
- rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
- /* Callback */
- lua_pushstring(L, cookie);
- lua_pushcclosure(L, &rspamd_redis_classified, 1);
-
- if (lua_pcall(L, 6, 0, err_idx) != 0) {
- msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
- return FALSE;
- }
-
- rt->tokens = g_ptr_array_ref(tokens);
-
- lua_settop(L, err_idx - 1);
- return TRUE;
- }
-
- gboolean
- rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- return !rt->err.has_value();
- }
-
-
- static gint
- rspamd_redis_learned(lua_State *L)
- {
- const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
- auto *task = lua_check_task(L, 1);
- auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
-
- if (rt == nullptr) {
- msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
-
- return 0;
- }
-
- bool result = lua_toboolean(L, 2);
-
- if (result) {
- /* TODO: write it */
- }
- else {
- /* Error message is on index 3 */
- const auto *err_msg = lua_tostring(L, 3);
- rt->err = rspamd::util::error(err_msg, 500);
- msg_err_task("cannot learn task: %s", err_msg);
- }
-
- return 0;
- }
-
- gboolean
- rspamd_redis_learn_tokens(struct rspamd_task *task,
- GPtrArray *tokens,
- gint id, gpointer p)
- {
- auto *rt = REDIS_RUNTIME(p);
- auto *L = rt->ctx->L;
-
- if (rspamd_session_blocked(task->s)) {
- return FALSE;
- }
-
- if (tokens == nullptr || tokens->len == 0) {
- return FALSE;
- }
-
- gsize tokens_len;
- gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
-
- rt->id = id;
-
- gsize text_tokens_len = 0;
- gchar *text_tokens_buf = nullptr;
-
- if (rt->ctx->store_tokens) {
- text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
- }
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- gint err_idx = lua_gettop(L);
- auto nargs = 8;
-
- /* Function arguments */
- lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
- rspamd_lua_task_push(L, task);
- lua_pushstring(L, rt->redis_object_expanded);
- lua_pushinteger(L, id);
- lua_pushboolean(L, rt->stcf->is_spam);
- lua_pushstring(L, rt->stcf->symbol);
-
- /* Detect unlearn */
- auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0);
-
- if (tok->values[id] > 0) {
- lua_pushboolean(L, FALSE);// Learn
- }
- else {
- lua_pushboolean(L, TRUE);// Unlearn
- }
- lua_new_text(L, tokens_buf, tokens_len, false);
-
- /* Store rt in random cookie */
- char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
- rspamd_random_hex(cookie, 16);
- cookie[15] = '\0';
- rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
- /* Callback */
- lua_pushstring(L, cookie);
- lua_pushcclosure(L, &rspamd_redis_learned, 1);
-
- if (text_tokens_len) {
- nargs = 9;
- lua_new_text(L, text_tokens_buf, text_tokens_len, false);
- }
-
- if (lua_pcall(L, nargs, 0, err_idx) != 0) {
- msg_err_task("call to script failed: %s", lua_tostring(L, -1));
- lua_settop(L, err_idx - 1);
- return FALSE;
- }
-
- rt->tokens = g_ptr_array_ref(tokens);
-
- lua_settop(L, err_idx - 1);
- return TRUE;
- }
-
-
- gboolean
- rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
- gpointer ctx, GError **err)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- if (rt->err.has_value()) {
- rt->err->into_g_error_set(rspamd_redis_stat_quark(), err);
-
- return FALSE;
- }
-
- return TRUE;
- }
-
- gulong
- rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- return rt->learned;
- }
-
- gulong
- rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- /* XXX: may cause races */
- return rt->learned + 1;
- }
-
- gulong
- rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- /* XXX: may cause races */
- return rt->learned + 1;
- }
-
- gulong
- rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- return rt->learned;
- }
-
- ucl_object_t *
- rspamd_redis_get_stat(gpointer runtime,
- gpointer ctx)
- {
- auto *rt = REDIS_RUNTIME(runtime);
-
- return ucl_object_ref(rt->ctx->cur_stat);
- }
-
- gpointer
- rspamd_redis_load_tokenizer_config(gpointer runtime,
- gsize *len)
- {
- return nullptr;
- }
|