]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add store tokens support in new bayes learn
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 30 Dec 2023 21:06:46 +0000 (21:06 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 31 Dec 2023 07:08:18 +0000 (07:08 +0000)
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_learn.lua

index 96a505518396209364dbcd9058691e5089e1cf86..5ad5c35148a866ac11311ad6118e7a7e038a0bf6 100644 (file)
@@ -42,7 +42,7 @@ local function gen_classify_functor(redis_params, classify_script_id)
 end
 
 local function gen_learn_functor(redis_params, learn_script_id)
-  return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback)
+  return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
     local function learn_redis_cb(err, data)
       lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
       if err then
@@ -52,9 +52,17 @@ local function gen_learn_functor(redis_params, learn_script_id)
       end
     end
 
-    lua_redis.exec_redis_script(learn_script_id,
-        { task = task, is_write = false, key = expanded_key },
-        learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+    if maybe_text_tokens then
+      lua_redis.exec_redis_script(learn_script_id,
+          { task = task, is_write = false, key = expanded_key },
+          learn_redis_cb,
+          { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+    else
+      lua_redis.exec_redis_script(learn_script_id,
+          { task = task, is_write = false, key = expanded_key },
+          learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+    end
+
   end
 end
 
index 244be43f66bc1115b64dcea28803cf186b80776c..b0ed1dd4be9f4a6fea7ce7fca5d1766d14480000 100644 (file)
@@ -5,12 +5,18 @@
 -- key3 - string symbol
 -- key4 - boolean is_unlearn
 -- key5 - set of tokens encoded in messagepack array of strings
+-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
 
 local prefix = KEYS[1]
 local is_spam = KEYS[2] == 'true' and true or false
 local symbol = KEYS[3]
 local is_unlearn = KEYS[4] == 'true' and true or false
 local input_tokens = cmsgpack.unpack(KEYS[5])
+local text_tokens
+
+if KEYS[6] then
+  text_tokens = cmsgpack.unpack(KEYS[6])
+end
 
 local hash_key = is_spam and 'S' or 'H'
 local learned_key = is_spam and 'learns_spam' or 'learns_ham'
@@ -19,6 +25,18 @@ redis.call('SADD', symbol .. '_keys', prefix)
 redis.call('HSET', prefix, 'version', '2') -- new schema
 redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
 
-for _, token in ipairs(input_tokens) do
+for i, token in ipairs(input_tokens) do
   redis.call('HINCRBY', token, hash_key, 1)
+  if text_tokens then
+    local tok1 = text_tokens[i * 2 - 1]
+    local tok2 = text_tokens[i * 2]
+
+    if tok2 then
+      redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+    else
+      redis.call('HSET', token, 'tokens', tok1)
+    end
+
+    redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1)
+  end
 end
\ No newline at end of file