about summary refs log tree commit diff stats
path: root/lualib/redis_scripts/bayes_classify.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/redis_scripts/bayes_classify.lua')
-rw-r--r--  lualib/redis_scripts/bayes_classify.lua  75
1 file changed, 53 insertions(+), 22 deletions(-)
diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua
index e94f645fd..d6132e631 100644
--- a/lualib/redis_scripts/bayes_classify.lua
+++ b/lualib/redis_scripts/bayes_classify.lua
@@ -1,37 +1,68 @@
--- Lua script to perform bayes classification
+-- Lua script to perform bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of strings
+-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..."
+-- key3 - set of tokens encoded in messagepack array of strings
local prefix = KEYS[1]
-local output_spam = {}
-local output_ham = {}
+local class_labels_arg = KEYS[2]
+local input_tokens = cmsgpack.unpack(KEYS[3])
-local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
-local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+-- Parse class labels (always expect TABLE: format)
+local class_labels = {}
+if string.match(class_labels_arg, "^TABLE:") then
+ local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
+ for label in string.gmatch(labels_str, "([^,]+)") do
+ table.insert(class_labels, label)
+ end
+else
+ -- Legacy single class - convert to array
+ class_labels = { class_labels_arg }
+end
--- Output is a set of pairs (token_index, token_count), tokens that are not
--- found are not filled.
--- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+-- Get learned counts for all classes (ordered)
+local learned_counts = {}
+for _, label in ipairs(class_labels) do
+ local key = 'learns_' .. string.lower(label)
+ -- Handle legacy keys for backward compatibility
+ if label == 'H' then
+ key = 'learns_ham'
+ elseif label == 'S' then
+ key = 'learns_spam'
+ end
+ table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0)
+end
-if learned_ham > 0 and learned_spam > 0 then
- local input_tokens = cmsgpack.unpack(KEYS[2])
- for i, token in ipairs(input_tokens) do
- local token_data = redis.call('HMGET', token, 'H', 'S')
+-- Get token data for all classes (ordered)
+local token_results = {}
+for i, _ in ipairs(class_labels) do
+ token_results[i] = {}
+end
- if token_data then
- local ham_count = token_data[1]
- local spam_count = token_data[2]
+-- Check if we have any learning data
+local has_learns = false
+for _, count in ipairs(learned_counts) do
+ if count > 0 then
+ has_learns = true
+ break
+ end
+end
- if ham_count then
- table.insert(output_ham, { i, tonumber(ham_count) })
- end
+if has_learns then
+ -- Process each token
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', token, unpack(class_labels))
- if spam_count then
- table.insert(output_spam, { i, tonumber(spam_count) })
+ if token_data then
+ for j, _ in ipairs(class_labels) do
+ local count = token_data[j]
+ if count and tonumber(count) > 0 then
+ table.insert(token_results[j], { i, tonumber(count) })
+ end
end
end
end
end
-return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file
+-- Always return ordered arrays: [learned_counts_array, token_results_array]
+return { learned_counts, token_results }