]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix learning with long prefixes
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 24 Jan 2024 14:44:12 +0000 (14:44 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 24 Jan 2024 14:44:42 +0000 (14:44 +0000)
Issue: #4786
Closes: #4786
src/libstat/backends/redis_backend.cxx

index cff6baf8cd0931146484875f0d04c78194077317..f5de686005a7ede60085c156c2550cd3bea1369e 100644 (file)
@@ -645,6 +645,58 @@ void rspamd_redis_close(gpointer p)
        delete ctx;
 }
 
+static constexpr auto
+msgpack_emit_str(const std::string_view st, char *out) -> std::size_t
+{
+       auto len = st.size();
+       constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+       auto blen = 0;
+       if (len <= 0x1F) {
+               blen = 1;
+               out[0] = (len | fix_mask) & 0xff;
+       }
+       else if (len <= 0xff) {
+               blen = 2;
+               out[0] = l8_ch;
+               out[1] = len & 0xff;
+       }
+       else if (len <= 0xffff) {
+               uint16_t bl = GUINT16_TO_BE(len);
+
+               blen = 3;
+               out[0] = l16_ch;
+               memcpy(&out[1], &bl, sizeof(bl));
+       }
+       else {
+               uint32_t bl = GUINT32_TO_BE(len);
+
+               blen = 5;
+               out[0] = l32_ch;
+               memcpy(&out[1], &bl, sizeof(bl));
+       }
+
+       memcpy(&out[blen], st.data(), st.size());
+
+       return blen + len;
+}
+
+static constexpr auto
+msgpack_str_len(std::size_t len) -> std::size_t
+{
+       if (len <= 0x1F) {
+               return 1 + len;
+       }
+       else if (len <= 0xff) {
+               return 2 + len;
+       }
+       else if (len <= 0xffff) {
+               return 3 + len;
+       }
+       else {
+               return 4 + len;
+       }
+}
+
 /*
  * Serialise stat tokens to message pack
  */
@@ -654,9 +706,12 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt
        /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */
        char max_int64_str[] = "18446744073709551615";
        auto prefix_len = strlen(prefix);
-       auto req_len = tokens->len * (sizeof(max_int64_str) + prefix_len + 1) + 5;
+       std::size_t req_len = 5;
        rspamd_token_t *tok;
 
+       /* Calculate required length */
+       req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1);
+
        auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
        auto *p = buf;
 
@@ -668,16 +723,16 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt
        *p++ = (gchar) ((tokens->len >> 8) & 0xff);
        *p++ = (gchar) (tokens->len & 0xff);
 
+
        int i;
+       auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
+       auto *numbuf = (char *) g_alloca(numbuf_len);
+
        PTR_ARRAY_FOREACH(tokens, i, tok)
        {
-               auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
-               auto *numbuf = (char *) g_alloca(numbuf_len);
-               auto r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
-               *p++ = (gchar) ((r & 0xff) | 0xa0);
-
-               memcpy(p, numbuf, r);
-               p += r;
+               std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
+               auto shift = msgpack_emit_str({numbuf, r}, p);
+               p += shift;
        }
 
        *ser_len = p - buf;
@@ -692,52 +747,6 @@ rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens,
        auto req_len = 5; /* Messagepack array prefix */
        int i;
 
-       constexpr const auto msgpack_str_len = [](std::size_t len) {
-               if (len <= 0x1F) {
-                       return 1 + len;
-               }
-               else if (len <= 0xff) {
-                       return 2 + len;
-               }
-               else if (len <= 0xffff) {
-                       return 3 + len;
-               }
-               else {
-                       return 4 + len;
-               }
-       };
-       constexpr const auto msgpack_emit_str = [](const std::string_view st, char *out) -> size_t {
-               auto len = st.size();
-               constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
-               auto blen = 0;
-               if (len <= 0x1F) {
-                       blen = 1;
-                       out[0] = (len | fix_mask) & 0xff;
-               }
-               else if (len <= 0xff) {
-                       blen = 2;
-                       out[0] = l8_ch;
-                       out[1] = len & 0xff;
-               }
-               else if (len <= 0xffff) {
-                       uint16_t bl = GUINT16_TO_BE(len);
-
-                       blen = 3;
-                       out[0] = l16_ch;
-                       memcpy(&out[1], &bl, sizeof(bl));
-               }
-               else {
-                       uint32_t bl = GUINT32_TO_BE(len);
-
-                       blen = 5;
-                       out[0] = l32_ch;
-                       memcpy(&out[1], &bl, sizeof(bl));
-               }
-
-               memcpy(&out[blen], st.data(), st.size());
-
-               return blen + len;
-       };
        /*
         * First we need to determine the requested length
         */