]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Use strings for int64_t
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 8 Dec 2023 09:33:57 +0000 (09:33 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 8 Dec 2023 09:33:57 +0000 (09:33 +0000)
It seems there is no easy way to use int64 in Redis Lua, hence, we have
to use strings. It's much more expensive but still some advantage over
the previous schema.

lualib/redis_scripts/bayes_classify.lua
lualib/redis_scripts/bayes_learn.lua
lualib/redis_scripts/bayes_stat.lua [new file with mode: 0644]
src/libstat/backends/redis_backend.cxx

index c999609e5de896bb082af8ba6f47ae5073f1808e..9bef96f145fd90371ed145f24656905a68476f6a 100644 (file)
@@ -1,10 +1,9 @@
 -- Lua script to perform bayes classification
 -- This script accepts the following parameters:
 -- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of int64_t
+-- key2 - set of tokens encoded in messagepack array of strings
 
 local prefix = KEYS[1]
-local input_tokens = cmsgpack.unpack(KEYS[2])
 local output_spam = {}
 local output_ham = {}
 
@@ -17,8 +16,9 @@ local prefix_underscore = prefix .. '_'
 -- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
 
 if learned_ham > 0 and learned_spam > 0 then
+  local input_tokens = cmsgpack.unpack(KEYS[2])
   for i, token in ipairs(input_tokens) do
-    local token_data = redis.call('HMGET', prefix_underscore .. tostring(token), 'H', 'S')
+    local token_data = redis.call('HMGET', prefix_underscore .. token, 'H', 'S')
 
     if token_data then
       local ham_count = token_data[1]
index 6382547067b0561ce0988548b976d5070a97fd9d..7536f680852a397e016ea68bff15419b4a68a0ac 100644 (file)
@@ -4,7 +4,7 @@
 -- key2 - boolean is_spam
 -- key3 - string symbol
 -- key4 - boolean is_unlearn
--- key5 - set of tokens encoded in messagepack array of int64_t
+-- key5 - set of tokens encoded in messagepack array of strings
 
 local prefix = KEYS[1]
 local is_spam = KEYS[2] == 'true' and true or false
@@ -21,5 +21,5 @@ redis.call('HSET', prefix, 'version', '2') -- new schema
 redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
 
 for _, token in ipairs(input_tokens) do
-  redis.call('HINCRBY', prefix_underscore .. tostring(token), hash_key, 1)
+  redis.call('HINCRBY', prefix_underscore .. token, hash_key, 1)
 end
\ No newline at end of file
diff --git a/lualib/redis_scripts/bayes_stat.lua b/lualib/redis_scripts/bayes_stat.lua
new file mode 100644 (file)
index 0000000..e69de29
index 342fa02739227f5ddedc47331c79800140a3b713..0eddf26cbe4f7d63e631eb179d8379ff26e78e01 100644 (file)
@@ -657,13 +657,13 @@ void rspamd_redis_close(gpointer p)
 static char *
 rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
 {
-       /* Each token is int64_t that requires 9 bytes + 4 bytes array len + 1 byte array magic */
-       gsize req_len = tokens->len * 9 + 5, i;
-       gchar *buf, *p;
+       /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */
+       char max_int64_str[] = "18446744073709551615";
+       auto req_len = tokens->len * sizeof(max_int64_str) + 5;
        rspamd_token_t *tok;
 
-       buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
-       p = buf;
+       auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+       auto *p = buf;
 
        /* Array */
        *p++ = (gchar) 0xdd;
@@ -673,13 +673,15 @@ rspamd_redis_serialize_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize
        *p++ = (gchar) ((tokens->len >> 8) & 0xff);
        *p++ = (gchar) (tokens->len & 0xff);
 
+       int i;
        PTR_ARRAY_FOREACH(tokens, i, tok)
        {
-               *p++ = (gchar) 0xd3;
+               char numbuf[sizeof(max_int64_str)];
+               auto r = rspamd_snprintf(numbuf, sizeof(numbuf), "%uL", tok->data);
+               *p++ = (gchar) ((r & 0xff) | 0xa0);
 
-               guint64 val = GUINT64_TO_BE(tok->data);
-               memcpy(p, &val, sizeof(val));
-               p += sizeof(val);
+               memcpy(p, &numbuf, r);
+               p += r;
        }
 
        *ser_len = p - buf;