From 6d0cfe35ab90a19e932184d301d726d52248f09d Mon Sep 17 00:00:00 2001
From: Vsevolod Stakhov
Date: Sat, 30 Dec 2023 21:37:42 +0000
Subject: [PATCH] [Project] Implement text tokens storage on C++ side

---
 lualib/redis_scripts/bayes_learn.lua   |  15 +--
 src/libstat/backends/redis_backend.cxx | 128 ++++++++++++++++++++++++-
 2 files changed, 136 insertions(+), 7 deletions(-)

diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua
index b0ed1dd4b..bf2775349 100644
--- a/lualib/redis_scripts/bayes_learn.lua
+++ b/lualib/redis_scripts/bayes_learn.lua
@@ -31,12 +31,15 @@ for i, token in ipairs(input_tokens) do
     local tok1 = text_tokens[i * 2 - 1]
     local tok2 = text_tokens[i * 2]
 
-    if tok2 then
-      redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
-    else
-      redis.call('HSET', token, 'tokens', tok1)
-    end
+    if tok1 then
+      if tok2 then
+        redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+      else
+        redis.call('HSET', token, 'tokens', tok1)
+      end
 
-    redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+      -- ZINCRBY expects (key, increment, member): the score delta precedes the member
+      redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token)
+    end
   end
 end
\ No newline at end of file
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index b91d93ac8..94576224d 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -682,6 +682,119 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt
 	return buf;
 }
 
+/*
+ * Serialize the stemmed text of each token in `tokens` into a MessagePack
+ * array32 with two elements per token: str(t1), str(t2) when both words are
+ * present, str(t1), nil when only t1 is present, and nil, nil otherwise.
+ * The buffer is allocated from task->task_pool; the number of bytes written
+ * is stored in *ser_len.
+ */
+static char *
+rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
+{
+	rspamd_token_t *tok;
+	std::size_t req_len = 5; /* Messagepack array prefix: 0xdd marker + 4-byte count */
+	int i;
+
+	constexpr const auto msgpack_str_len = [](std::size_t len) {
+		if (len <= 0x1F) {
+			return 1 + len; /* fixstr: 1-byte header */
+		}
+		else if (len <= 0xff) {
+			return 2 + len; /* str 8: marker + 1-byte length */
+		}
+		else if (len <= 0xffff) {
+			return 3 + len; /* str 16: marker + 2-byte length */
+		}
+		else {
+			return 5 + len; /* str 32: marker + 4-byte length, must match msgpack_emit_str */
+		}
+	};
+	constexpr const auto msgpack_emit_str = [](const std::string_view st, char *out) -> size_t {
+		auto len = st.size();
+		constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+		auto blen = 0; /* header size in bytes; the string payload follows it */
+		if (len <= 0x1F) {
+			blen = 1;
+			out[0] = (len | fix_mask) & 0xff;
+		}
+		else if (len <= 0xff) {
+			blen = 2;
+			out[0] = l8_ch;
+			out[1] = len & 0xff;
+		}
+		else if (len <= 0xffff) {
+			uint16_t bl = GUINT16_TO_BE(len);
+
+			blen = 3;
+			out[0] = l16_ch;
+			memcpy(&out[1], &bl, sizeof(bl));
+		}
+		else {
+			uint32_t bl = GUINT32_TO_BE(len);
+
+			blen = 5;
+			out[0] = l32_ch;
+			memcpy(&out[1], &bl, sizeof(bl));
+		}
+
+		memcpy(&out[blen], st.data(), st.size());
+
+		return blen + len;
+	};
+	/*
+	 * First pass: compute the exact size of the serialized buffer
+	 */
+	PTR_ARRAY_FOREACH(tokens, i, tok)
+	{
+		if (tok->t1 && tok->t2) {
+			/* Two tokens */
+			req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
+		}
+		else if (tok->t1) {
+			req_len += msgpack_str_len(tok->t1->stemmed.len);
+			req_len += 1; /* null */
+		}
+		else {
+			req_len += 2; /* 2 nulls */
+		}
+	}
+
+	auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+	auto *p = buf;
+
+	/* Array32 header: two elements are emitted per token */
+	std::uint32_t nlen = tokens->len * 2;
+	nlen = GUINT32_TO_BE(nlen);
+	*p++ = (gchar) 0xdd;
+	/* Length in big-endian (4 bytes) */
+	memcpy(p, &nlen, sizeof(nlen));
+	p += sizeof(nlen);
+
+	PTR_ARRAY_FOREACH(tokens, i, tok)
+	{
+		if (tok->t1 && tok->t2) {
+			auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+			p += step;
+			step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
+			p += step;
+		}
+		else if (tok->t1) {
+			auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+			p += step;
+			*p++ = (gchar) 0xc0; /* nil in place of the missing second word */
+		}
+		else {
+			*p++ = (gchar) 0xc0;
+			*p++ = (gchar) 0xc0;
+		}
+	}
+
+	*ser_len = p - buf;
+
+	return buf;
+}
+
 static gint
 rspamd_redis_classified(lua_State *L)
 {
@@ -881,8 +994,16 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
 
 	rt->id = id;
 
+	gsize text_tokens_len = 0;
+	gchar *text_tokens_buf = nullptr;
+
+	if (rt->ctx->store_tokens) {
+		text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
+	}
+
 	lua_pushcfunction(L, &rspamd_lua_traceback);
 	gint err_idx = lua_gettop(L);
+	auto nargs = 8;
 
 	/* Function arguments */
 	lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
@@ -903,6 +1024,11 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
 	}
 	lua_new_text(L, tokens_buf, tokens_len, false);
 
+	if (text_tokens_len) {
+		nargs = 9;
+		lua_new_text(L, text_tokens_buf, text_tokens_len, false);
+	}
+
 	/* Store rt in random cookie */
 	char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
 	rspamd_random_hex(cookie, 16);
@@ -912,7 +1038,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
 	lua_pushstring(L, cookie);
 	lua_pushcclosure(L, &rspamd_redis_learned, 1);
 
-	if (lua_pcall(L, 8, 0, err_idx) != 0) {
+	if (lua_pcall(L, nargs, 0, err_idx) != 0) {
 		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
 		lua_settop(L, err_idx - 1);
 		return FALSE;
-- 
2.39.5