summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2023-12-30 21:37:42 +0000
committerVsevolod Stakhov <vsevolod@rspamd.com>2023-12-31 07:08:19 +0000
commit6d0cfe35ab90a19e932184d301d726d52248f09d (patch)
tree4b107e20167d2dce0d61e64ed4ad3f40f4181f05
parent38244b721b4ee82431c95cad69bac97dd54d44cc (diff)
downloadrspamd-6d0cfe35ab90a19e932184d301d726d52248f09d.tar.gz
rspamd-6d0cfe35ab90a19e932184d301d726d52248f09d.zip
[Project] Implement text tokens storage on C++ side
-rw-r--r--lualib/redis_scripts/bayes_learn.lua14
-rw-r--r--src/libstat/backends/redis_backend.cxx121
2 files changed, 128 insertions, 7 deletions
diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua
index b0ed1dd4b..bf2775349 100644
--- a/lualib/redis_scripts/bayes_learn.lua
+++ b/lualib/redis_scripts/bayes_learn.lua
@@ -31,12 +31,14 @@ for i, token in ipairs(input_tokens) do
local tok1 = text_tokens[i * 2 - 1]
local tok2 = text_tokens[i * 2]
- if tok2 then
- redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
- else
- redis.call('HSET', token, 'tokens', tok1)
- end
+ if tok1 then
+ if tok2 then
+ redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+ else
+ redis.call('HSET', token, 'tokens', tok1)
+ end
- redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+ redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+ end
end
end \ No newline at end of file
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
index b91d93ac8..94576224d 100644
--- a/src/libstat/backends/redis_backend.cxx
+++ b/src/libstat/backends/redis_backend.cxx
@@ -682,6 +682,112 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt
return buf;
}
+static char *
+rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
+{
+ rspamd_token_t *tok;
+ auto req_len = 5; /* Messagepack array prefix */
+ int i;
+
+ constexpr const auto msgpack_str_len = [](std::size_t len) {
+ if (len <= 0x1F) {
+ return 1 + len;
+ }
+ else if (len <= 0xff) {
+ return 2 + len;
+ }
+ else if (len <= 0xffff) {
+ return 3 + len;
+ }
+ else {
+ return 4 + len;
+ }
+ };
+ constexpr const auto msgpack_emit_str = [](const std::string_view st, char *out) -> size_t {
+ auto len = st.size();
+ constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+ auto blen = 0;
+ if (len <= 0x1F) {
+ blen = 1;
+ out[0] = (len | fix_mask) & 0xff;
+ }
+ else if (len <= 0xff) {
+ blen = 2;
+ out[0] = l8_ch;
+ out[1] = len & 0xff;
+ }
+ else if (len <= 0xffff) {
+ uint16_t bl = GUINT16_TO_BE(len);
+
+ blen = 3;
+ out[0] = l16_ch;
+ memcpy(&out[1], &bl, sizeof(bl));
+ }
+ else {
+ uint32_t bl = GUINT32_TO_BE(len);
+
+ blen = 5;
+ out[0] = l32_ch;
+ memcpy(&out[1], &bl, sizeof(bl));
+ }
+
+ memcpy(&out[blen], st.data(), st.size());
+
+ return blen + len;
+ };
+ /*
+ * First we need to determine the requested length
+ */
+ PTR_ARRAY_FOREACH(tokens, i, tok)
+ {
+ if (tok->t1 && tok->t2) {
+ /* Two tokens */
+ req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
+ }
+ else if (tok->t1) {
+ req_len += msgpack_str_len(tok->t1->stemmed.len);
+ req_len += 1; /* null */
+ }
+ else {
+ req_len += 2; /* 2 nulls */
+ }
+ }
+
+ auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+ auto *p = buf;
+
+ /* Array */
+ std::uint32_t nlen = tokens->len * 2;
+ nlen = GUINT32_TO_BE(nlen);
+ *p++ = (gchar) 0xdd;
+ /* Length in big-endian (4 bytes) */
+ memcpy(p, &nlen, sizeof(nlen));
+ p += sizeof(nlen);
+
+ PTR_ARRAY_FOREACH(tokens, i, tok)
+ {
+ if (tok->t1 && tok->t2) {
+ auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+ p += step;
+ step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
+ p += step;
+ }
+ else if (tok->t1) {
+ auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+ p += step;
+ *p++ = 0xc0;
+ }
+ else {
+ *p++ = 0xc0;
+ *p++ = 0xc0;
+ }
+ }
+
+ *ser_len = p - buf;
+
+ return buf;
+}
+
static gint
rspamd_redis_classified(lua_State *L)
{
@@ -881,8 +987,16 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
rt->id = id;
+ gsize text_tokens_len = 0;
+ gchar *text_tokens_buf = nullptr;
+
+ if (rt->ctx->store_tokens) {
+ text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
+ }
+
lua_pushcfunction(L, &rspamd_lua_traceback);
gint err_idx = lua_gettop(L);
+ auto nargs = 8;
/* Function arguments */
lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
@@ -903,6 +1017,11 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
}
lua_new_text(L, tokens_buf, tokens_len, false);
+ if (text_tokens_len) {
+ nargs = 9;
+ lua_new_text(L, text_tokens_buf, text_tokens_len, false);
+ }
+
/* Store rt in random cookie */
char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
rspamd_random_hex(cookie, 16);
@@ -912,7 +1031,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
lua_pushstring(L, cookie);
lua_pushcclosure(L, &rspamd_redis_learned, 1);
- if (lua_pcall(L, 8, 0, err_idx) != 0) {
+ if (lua_pcall(L, nargs, 0, err_idx) != 0) {
msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
lua_settop(L, err_idx - 1);
return FALSE;