]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Implement text tokens storage on C++ side
author Vsevolod Stakhov <vsevolod@rspamd.com>
Sat, 30 Dec 2023 21:37:42 +0000 (21:37 +0000)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Sun, 31 Dec 2023 07:08:19 +0000 (07:08 +0000)
lualib/redis_scripts/bayes_learn.lua
src/libstat/backends/redis_backend.cxx

index b0ed1dd4be9f4a6fea7ce7fca5d1766d14480000..bf27753493d06aebd558d1bf5ccb6970356a3517 100644 (file)
@@ -31,12 +31,14 @@ for i, token in ipairs(input_tokens) do
     local tok1 = text_tokens[i * 2 - 1]
     local tok2 = text_tokens[i * 2]
 
-    if tok2 then
-      redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
-    else
-      redis.call('HSET', token, 'tokens', tok1)
-    end
+    if tok1 then
+      if tok2 then
+        redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+      else
+        redis.call('HSET', token, 'tokens', tok1)
+      end
 
-    redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+      redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+    end
   end
 end
\ No newline at end of file
index b91d93ac8aa8521f71730708293bfcfc562dabbf..94576224d337512510823fc256dee74ccb62e827 100644 (file)
@@ -682,6 +682,112 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt
        return buf;
 }
 
+static char *
+rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
+{
+       rspamd_token_t *tok;
+       auto req_len = 5; /* Messagepack array prefix */
+       int i;
+
+       constexpr const auto msgpack_str_len = [](std::size_t len) {
+               if (len <= 0x1F) {
+                       return 1 + len;
+               }
+               else if (len <= 0xff) {
+                       return 2 + len;
+               }
+               else if (len <= 0xffff) {
+                       return 3 + len;
+               }
+               else {
+                       return 4 + len;
+               }
+       };
+       constexpr const auto msgpack_emit_str = [](const std::string_view st, char *out) -> size_t {
+               auto len = st.size();
+               constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+               auto blen = 0;
+               if (len <= 0x1F) {
+                       blen = 1;
+                       out[0] = (len | fix_mask) & 0xff;
+               }
+               else if (len <= 0xff) {
+                       blen = 2;
+                       out[0] = l8_ch;
+                       out[1] = len & 0xff;
+               }
+               else if (len <= 0xffff) {
+                       uint16_t bl = GUINT16_TO_BE(len);
+
+                       blen = 3;
+                       out[0] = l16_ch;
+                       memcpy(&out[1], &bl, sizeof(bl));
+               }
+               else {
+                       uint32_t bl = GUINT32_TO_BE(len);
+
+                       blen = 5;
+                       out[0] = l32_ch;
+                       memcpy(&out[1], &bl, sizeof(bl));
+               }
+
+               memcpy(&out[blen], st.data(), st.size());
+
+               return blen + len;
+       };
+       /*
+        * First we need to determine the requested length
+        */
+       PTR_ARRAY_FOREACH(tokens, i, tok)
+       {
+               if (tok->t1 && tok->t2) {
+                       /* Two tokens */
+                       req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
+               }
+               else if (tok->t1) {
+                       req_len += msgpack_str_len(tok->t1->stemmed.len);
+                       req_len += 1; /* null */
+               }
+               else {
+                       req_len += 2; /* 2 nulls */
+               }
+       }
+
+       auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+       auto *p = buf;
+
+       /* Array */
+       std::uint32_t nlen = tokens->len * 2;
+       nlen = GUINT32_TO_BE(nlen);
+       *p++ = (gchar) 0xdd;
+       /* Length in big-endian (4 bytes) */
+       memcpy(p, &nlen, sizeof(nlen));
+       p += sizeof(nlen);
+
+       PTR_ARRAY_FOREACH(tokens, i, tok)
+       {
+               if (tok->t1 && tok->t2) {
+                       auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+                       p += step;
+                       step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
+                       p += step;
+               }
+               else if (tok->t1) {
+                       auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+                       p += step;
+                       *p++ = 0xc0;
+               }
+               else {
+                       *p++ = 0xc0;
+                       *p++ = 0xc0;
+               }
+       }
+
+       *ser_len = p - buf;
+
+       return buf;
+}
+
 static gint
 rspamd_redis_classified(lua_State *L)
 {
@@ -881,8 +987,16 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
 
        rt->id = id;
 
+       gsize text_tokens_len = 0;
+       gchar *text_tokens_buf = nullptr;
+
+       if (rt->ctx->store_tokens) {
+               text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
+       }
+
        lua_pushcfunction(L, &rspamd_lua_traceback);
        gint err_idx = lua_gettop(L);
+       auto nargs = 8;
 
        /* Function arguments */
        lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
@@ -903,6 +1017,11 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
        }
        lua_new_text(L, tokens_buf, tokens_len, false);
 
+       if (text_tokens_len) {
+               nargs = 9;
+               lua_new_text(L, text_tokens_buf, text_tokens_len, false);
+       }
+
        /* Store rt in random cookie */
        char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
        rspamd_random_hex(cookie, 16);
@@ -912,7 +1031,7 @@ rspamd_redis_learn_tokens(struct rspamd_task *task,
        lua_pushstring(L, cookie);
        lua_pushcclosure(L, &rspamd_redis_learned, 1);
 
-       if (lua_pcall(L, 8, 0, err_idx) != 0) {
+       if (lua_pcall(L, nargs, 0, err_idx) != 0) {
                msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);
                return FALSE;