about summary refs log tree commit diff stats
path: root/lualib/redis_scripts
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/redis_scripts')
-rw-r--r--lualib/redis_scripts/bayes_cache_learn.lua17
-rw-r--r--lualib/redis_scripts/bayes_classify.lua75
-rw-r--r--lualib/redis_scripts/bayes_learn.lua55
3 files changed, 110 insertions, 37 deletions
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
index 7d44a73ef..a7c9ac443 100644
--- a/lualib/redis_scripts/bayes_cache_learn.lua
+++ b/lualib/redis_scripts/bayes_cache_learn.lua
@@ -1,12 +1,15 @@
--- Lua script to perform cache checking for bayes classification
+-- Lua script to perform cache checking for bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - cache id
--- key2 - is spam (1 or 0)
+-- key2 - class_id (numeric hash of class name, computed by C side)
-- key3 - configuration table in message pack
local cache_id = KEYS[1]
-local is_spam = KEYS[2]
+local class_id = KEYS[2]
local conf = cmsgpack.unpack(KEYS[3])
+
+-- Use class_id directly as cache value
+local cache_value = tostring(class_id)
cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
-- Try each prefix that is in Redis (as some other instance might have set it)
@@ -15,8 +18,8 @@ for i = 0, conf.cache_max_keys do
local have = redis.call('HGET', prefix, cache_id)
if have then
- -- Already in cache, but is_spam changes when relearning
- redis.call('HSET', prefix, cache_id, is_spam)
+ -- Already in cache, but cache_value changes when relearning
+ redis.call('HSET', prefix, cache_id, cache_value)
return false
end
end
@@ -30,7 +33,7 @@ for i = 0, conf.cache_max_keys do
if count < lim then
-- We can add it to this prefix
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, cache_value)
added = true
end
end
@@ -46,7 +49,7 @@ if not added then
if exists then
if not expired then
redis.call('DEL', prefix)
- redis.call('HSET', prefix, cache_id, is_spam)
+ redis.call('HSET', prefix, cache_id, cache_value)
-- Do not expire anything else
expired = true
diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua
index e94f645fd..d6132e631 100644
--- a/lualib/redis_scripts/bayes_classify.lua
+++ b/lualib/redis_scripts/bayes_classify.lua
@@ -1,37 +1,68 @@
--- Lua script to perform bayes classification
+-- Lua script to perform bayes classification (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - set of tokens encoded in messagepack array of strings
+-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..."
+-- key3 - set of tokens encoded in messagepack array of strings
local prefix = KEYS[1]
-local output_spam = {}
-local output_ham = {}
+local class_labels_arg = KEYS[2]
+local input_tokens = cmsgpack.unpack(KEYS[3])
-local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
-local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+-- Parse class labels (always expect TABLE: format)
+local class_labels = {}
+if string.match(class_labels_arg, "^TABLE:") then
+ local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
+ for label in string.gmatch(labels_str, "([^,]+)") do
+ table.insert(class_labels, label)
+ end
+else
+ -- Legacy single class - convert to array
+ class_labels = { class_labels_arg }
+end
--- Output is a set of pairs (token_index, token_count), tokens that are not
--- found are not filled.
--- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+-- Get learned counts for all classes (ordered)
+local learned_counts = {}
+for _, label in ipairs(class_labels) do
+ local key = 'learns_' .. string.lower(label)
+ -- Handle legacy keys for backward compatibility
+ if label == 'H' then
+ key = 'learns_ham'
+ elseif label == 'S' then
+ key = 'learns_spam'
+ end
+ table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0)
+end
-if learned_ham > 0 and learned_spam > 0 then
- local input_tokens = cmsgpack.unpack(KEYS[2])
- for i, token in ipairs(input_tokens) do
- local token_data = redis.call('HMGET', token, 'H', 'S')
+-- Get token data for all classes (ordered)
+local token_results = {}
+for i, _ in ipairs(class_labels) do
+ token_results[i] = {}
+end
- if token_data then
- local ham_count = token_data[1]
- local spam_count = token_data[2]
+-- Check if we have any learning data
+local has_learns = false
+for _, count in ipairs(learned_counts) do
+ if count > 0 then
+ has_learns = true
+ break
+ end
+end
- if ham_count then
- table.insert(output_ham, { i, tonumber(ham_count) })
- end
+if has_learns then
+ -- Process each token
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', token, unpack(class_labels))
- if spam_count then
- table.insert(output_spam, { i, tonumber(spam_count) })
+ if token_data then
+ for j, _ in ipairs(class_labels) do
+ local count = token_data[j]
+ if count and tonumber(count) > 0 then
+ table.insert(token_results[j], { i, tonumber(count) })
+ end
end
end
end
end
-return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file
+-- Always return ordered arrays: [learned_counts_array, token_results_array]
+return { learned_counts, token_results }
diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua
index 5456165b6..ebc798fe0 100644
--- a/lualib/redis_scripts/bayes_learn.lua
+++ b/lualib/redis_scripts/bayes_learn.lua
@@ -1,14 +1,14 @@
--- Lua script to perform bayes learning
+-- Lua script to perform bayes learning (multi-class)
-- This script accepts the following parameters:
-- key1 - prefix for bayes tokens (e.g. for per-user classification)
--- key2 - boolean is_spam
+-- key2 - class label string (e.g. "S", "H", "T")
-- key3 - string symbol
-- key4 - boolean is_unlearn
-- key5 - set of tokens encoded in messagepack array of strings
-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
local prefix = KEYS[1]
-local is_spam = KEYS[2] == 'true' and true or false
+local class_label = KEYS[2]
local symbol = KEYS[3]
local is_unlearn = KEYS[4] == 'true' and true or false
local input_tokens = cmsgpack.unpack(KEYS[5])
@@ -18,15 +18,47 @@ if KEYS[6] then
text_tokens = cmsgpack.unpack(KEYS[6])
end
-local hash_key = is_spam and 'S' or 'H'
-local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+-- Handle backward compatibility for boolean values
+if class_label == 'true' then
+ class_label = 'S' -- spam
+elseif class_label == 'false' then
+ class_label = 'H' -- ham
+end
+
+local hash_key = class_label
+local learned_key = 'learns_' .. string.lower(class_label)
+
+-- Handle legacy keys for backward compatibility
+if class_label == 'S' then
+ learned_key = 'learns_spam'
+elseif class_label == 'H' then
+ learned_key = 'learns_ham'
+end
redis.call('SADD', symbol .. '_keys', prefix)
redis.call('HSET', prefix, 'version', '2') -- new schema
-redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
+
+-- Update learned count, but prevent it from going negative
+if is_unlearn then
+ local current_count = tonumber(redis.call('HGET', prefix, learned_key)) or 0
+ if current_count > 0 then
+ redis.call('HINCRBY', prefix, learned_key, -1)
+ end
+else
+ redis.call('HINCRBY', prefix, learned_key, 1)
+end
for i, token in ipairs(input_tokens) do
- redis.call('HINCRBY', token, hash_key, is_unlearn and -1 or 1)
+ -- Update token count, but prevent it from going negative
+ if is_unlearn then
+ local current_token_count = tonumber(redis.call('HGET', token, hash_key)) or 0
+ if current_token_count > 0 then
+ redis.call('HINCRBY', token, hash_key, -1)
+ end
+ else
+ redis.call('HINCRBY', token, hash_key, 1)
+ end
+
if text_tokens then
local tok1 = text_tokens[i * 2 - 1]
local tok2 = text_tokens[i * 2]
@@ -38,7 +70,14 @@ for i, token in ipairs(input_tokens) do
redis.call('HSET', token, 'tokens', tok1)
end
- redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token)
+ if is_unlearn then
+ local current_z_score = tonumber(redis.call('ZSCORE', prefix .. '_z', token)) or 0
+ if current_z_score > 0 then
+ redis.call('ZINCRBY', prefix .. '_z', -1, token)
+ end
+ else
+ redis.call('ZINCRBY', prefix .. '_z', 1, token)
+ end
end
end
end