diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-01-24 14:44:12 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-01-24 14:44:42 +0000 |
commit | f4aa0ba5a8a290b2bbc1e58b1ca302a5feb05ee5 (patch) | |
tree | 9eacd2e470ac467be9965d343e9e5d1f1094ab4d /src/libstat | |
parent | b1299ac9108f4568959a2ef61d2f7992ee131c88 (diff) | |
download | rspamd-f4aa0ba5a8a290b2bbc1e58b1ca302a5feb05ee5.tar.gz rspamd-f4aa0ba5a8a290b2bbc1e58b1ca302a5feb05ee5.zip |
[Fix] Fix learning with long prefixes
Issue: #4786
Closes: #4786
Diffstat (limited to 'src/libstat')
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 117 |
1 files changed, 63 insertions, 54 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index cff6baf8c..f5de68600 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -645,6 +645,58 @@ void rspamd_redis_close(gpointer p) delete ctx; } +static constexpr auto +msgpack_emit_str(const std::string_view st, char *out) -> std::size_t +{ + auto len = st.size(); + constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb; + auto blen = 0; + if (len <= 0x1F) { + blen = 1; + out[0] = (len | fix_mask) & 0xff; + } + else if (len <= 0xff) { + blen = 2; + out[0] = l8_ch; + out[1] = len & 0xff; + } + else if (len <= 0xffff) { + uint16_t bl = GUINT16_TO_BE(len); + + blen = 3; + out[0] = l16_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + else { + uint32_t bl = GUINT32_TO_BE(len); + + blen = 5; + out[0] = l32_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + + memcpy(&out[blen], st.data(), st.size()); + + return blen + len; +} + +static constexpr auto +msgpack_str_len(std::size_t len) -> std::size_t +{ + if (len <= 0x1F) { + return 1 + len; + } + else if (len <= 0xff) { + return 2 + len; + } + else if (len <= 0xffff) { + return 3 + len; + } + else { + return 4 + len; + } +} + /* * Serialise stat tokens to message pack */ @@ -654,9 +706,12 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */ char max_int64_str[] = "18446744073709551615"; auto prefix_len = strlen(prefix); - auto req_len = tokens->len * (sizeof(max_int64_str) + prefix_len + 1) + 5; + std::size_t req_len = 5; rspamd_token_t *tok; + /* Calculate required length */ + req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1); + auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len); auto *p = buf; @@ -668,16 +723,16 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPt *p++ = (gchar) ((tokens->len >> 8) & 0xff); *p++ = (gchar) (tokens->len & 0xff); + int i; + auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1; + auto *numbuf = (char *) g_alloca(numbuf_len); + PTR_ARRAY_FOREACH(tokens, i, tok) { - auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1; - auto *numbuf = (char *) g_alloca(numbuf_len); - auto r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data); - *p++ = (gchar) ((r & 0xff) | 0xa0); - - memcpy(p, numbuf, r); - p += r; + std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data); + auto shift = msgpack_emit_str({numbuf, r}, p); + p += shift; } *ser_len = p - buf; @@ -692,52 +747,6 @@ rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, auto req_len = 5; /* Messagepack array prefix */ int i; - constexpr const auto msgpack_str_len = [](std::size_t len) { - if (len <= 0x1F) { - return 1 + len; - } - else if (len <= 0xff) { - return 2 + len; - } - else if (len <= 0xffff) { - return 3 + len; - } - else { - return 4 + len; - } - }; - constexpr const auto msgpack_emit_str = [](const std::string_view st, char *out) -> size_t { - auto len = st.size(); - constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb; - auto blen = 0; - if (len <= 0x1F) { - blen = 1; - out[0] = (len | fix_mask) & 0xff; - } - else if (len <= 0xff) { - blen = 2; - out[0] = l8_ch; - out[1] = len & 0xff; - } - else if (len <= 0xffff) { - uint16_t bl = GUINT16_TO_BE(len); - - blen = 3; - out[0] = l16_ch; - memcpy(&out[1], &bl, sizeof(bl)); - } - else { - uint32_t bl = GUINT32_TO_BE(len); - - blen = 5; - out[0] = l32_ch; - memcpy(&out[1], &bl, sizeof(bl)); - } - - memcpy(&out[blen], st.data(), st.size()); - - return blen + len; - }; /* * First we need to determine the requested length */ |