aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/redis_scripts
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2023-12-02 15:54:04 +0000
committerVsevolod Stakhov <vsevolod@rspamd.com>2023-12-02 15:54:04 +0000
commit548251ac7ec47db643aba96b3cae0fc353495290 (patch)
tree4632cee745f8b54768940a375139ee306f025b39 /lualib/redis_scripts
parent44c6c563c2faf5ca97646e92e8382b972d2d513f (diff)
downloadrspamd-548251ac7ec47db643aba96b3cae0fc353495290.tar.gz
rspamd-548251ac7ec47db643aba96b3cae0fc353495290.zip
[Project] Optimise classify script
Diffstat (limited to 'lualib/redis_scripts')
-rw-r--r--lualib/redis_scripts/bayes_classify.lua38
1 files changed, 25 insertions, 13 deletions
diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua
index c2654e476..76e88a6f3 100644
--- a/lualib/redis_scripts/bayes_classify.lua
+++ b/lualib/redis_scripts/bayes_classify.lua
@@ -8,19 +8,31 @@ local input_tokens = cmsgpack.unpack(KEYS[2])
local output_spam = {}
local output_ham = {}
-for i, token in ipairs(input_tokens) do
- local token_data = redis.call('HMGET', prefix .. tostring(token), 'H', 'S')
-
- if token_data then
- local ham_count = tonumber(token_data[1]) or 0
- local spam_count = tonumber(token_data[2]) or 0
-
- output_ham[i] = ham_count
- output_spam[i] = spam_count
- else
- output_ham[i] = 0
- output_spam[i] = 0
+local learned_ham = redis.call('HGET', prefix, 'learned_ham') or 0
+local learned_spam = redis.call('HGET', prefix, 'learned_spam') or 0
+local prefix_underscore = prefix .. '_'
+
+-- Output is a set of pairs (token_index, token_count), tokens that are not
+-- found are not filled.
+-- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+
+if learned_ham > 0 and learned_spam > 0 then
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', prefix_underscore .. tostring(token), 'H', 'S')
+
+ if token_data then
+ local ham_count = token_data[1]
+ local spam_count = tonumber(token_data[2]) or 0
+
+ if ham_count then
+ table.insert(output_ham, { i, tonumber(ham_count) })
+ end
+
+ if spam_count then
+ table.insert(output_spam, { i, tonumber(spam_count) })
+ end
+ end
end
end
-return cmsgpack.pack({ output_ham, output_spam }) \ No newline at end of file
+return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file