diff options
Diffstat (limited to 'lualib/redis_scripts')
-rw-r--r-- | lualib/redis_scripts/bayes_cache_learn.lua | 17
-rw-r--r-- | lualib/redis_scripts/bayes_classify.lua    | 75
-rw-r--r-- | lualib/redis_scripts/bayes_learn.lua       | 55
3 files changed, 110 insertions, 37 deletions
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua index 7d44a73ef..a7c9ac443 100644 --- a/lualib/redis_scripts/bayes_cache_learn.lua +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -1,12 +1,15 @@ --- Lua script to perform cache checking for bayes classification +-- Lua script to perform cache checking for bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - cache id --- key2 - is spam (1 or 0) +-- key2 - class_id (numeric hash of class name, computed by C side) -- key3 - configuration table in message pack local cache_id = KEYS[1] -local is_spam = KEYS[2] +local class_id = KEYS[2] local conf = cmsgpack.unpack(KEYS[3]) + +-- Use class_id directly as cache value +local cache_value = tostring(class_id) cache_id = string.sub(cache_id, 1, conf.cache_elt_len) -- Try each prefix that is in Redis (as some other instance might have set it) @@ -15,8 +18,8 @@ for i = 0, conf.cache_max_keys do local have = redis.call('HGET', prefix, cache_id) if have then - -- Already in cache, but is_spam changes when relearning - redis.call('HSET', prefix, cache_id, is_spam) + -- Already in cache, but cache_value changes when relearning + redis.call('HSET', prefix, cache_id, cache_value) return false end end @@ -30,7 +33,7 @@ for i = 0, conf.cache_max_keys do if count < lim then -- We can add it to this prefix - redis.call('HSET', prefix, cache_id, is_spam) + redis.call('HSET', prefix, cache_id, cache_value) added = true end end @@ -46,7 +49,7 @@ if not added then if exists then if not expired then redis.call('DEL', prefix) - redis.call('HSET', prefix, cache_id, is_spam) + redis.call('HSET', prefix, cache_id, cache_value) -- Do not expire anything else expired = true diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua index e94f645fd..d6132e631 100644 --- a/lualib/redis_scripts/bayes_classify.lua +++ b/lualib/redis_scripts/bayes_classify.lua @@ -1,37 
+1,68 @@ --- Lua script to perform bayes classification +-- Lua script to perform bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - prefix for bayes tokens (e.g. for per-user classification) --- key2 - set of tokens encoded in messagepack array of strings +-- key2 - class labels: table of all class labels as "TABLE:label1,label2,..." +-- key3 - set of tokens encoded in messagepack array of strings local prefix = KEYS[1] -local output_spam = {} -local output_ham = {} +local class_labels_arg = KEYS[2] +local input_tokens = cmsgpack.unpack(KEYS[3]) -local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0 -local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0 +-- Parse class labels (always expect TABLE: format) +local class_labels = {} +if string.match(class_labels_arg, "^TABLE:") then + local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix + for label in string.gmatch(labels_str, "([^,]+)") do + table.insert(class_labels, label) + end +else + -- Legacy single class - convert to array + class_labels = { class_labels_arg } +end --- Output is a set of pairs (token_index, token_count), tokens that are not --- found are not filled. --- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held +-- Get learned counts for all classes (ordered) +local learned_counts = {} +for _, label in ipairs(class_labels) do + local key = 'learns_' .. 
string.lower(label) + -- Handle legacy keys for backward compatibility + if label == 'H' then + key = 'learns_ham' + elseif label == 'S' then + key = 'learns_spam' + end + table.insert(learned_counts, tonumber(redis.call('HGET', prefix, key)) or 0) +end -if learned_ham > 0 and learned_spam > 0 then - local input_tokens = cmsgpack.unpack(KEYS[2]) - for i, token in ipairs(input_tokens) do - local token_data = redis.call('HMGET', token, 'H', 'S') +-- Get token data for all classes (ordered) +local token_results = {} +for i, _ in ipairs(class_labels) do + token_results[i] = {} +end - if token_data then - local ham_count = token_data[1] - local spam_count = token_data[2] +-- Check if we have any learning data +local has_learns = false +for _, count in ipairs(learned_counts) do + if count > 0 then + has_learns = true + break + end +end - if ham_count then - table.insert(output_ham, { i, tonumber(ham_count) }) - end +if has_learns then + -- Process each token + for i, token in ipairs(input_tokens) do + local token_data = redis.call('HMGET', token, unpack(class_labels)) - if spam_count then - table.insert(output_spam, { i, tonumber(spam_count) }) + if token_data then + for j, _ in ipairs(class_labels) do + local count = token_data[j] + if count and tonumber(count) > 0 then + table.insert(token_results[j], { i, tonumber(count) }) + end end end end end -return { learned_ham, learned_spam, output_ham, output_spam }
\ No newline at end of file +-- Always return ordered arrays: [learned_counts_array, token_results_array] +return { learned_counts, token_results } diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua index 5456165b6..ebc798fe0 100644 --- a/lualib/redis_scripts/bayes_learn.lua +++ b/lualib/redis_scripts/bayes_learn.lua @@ -1,14 +1,14 @@ --- Lua script to perform bayes learning +-- Lua script to perform bayes learning (multi-class) -- This script accepts the following parameters: -- key1 - prefix for bayes tokens (e.g. for per-user classification) --- key2 - boolean is_spam +-- key2 - class label string (e.g. "S", "H", "T") -- key3 - string symbol -- key4 - boolean is_unlearn -- key5 - set of tokens encoded in messagepack array of strings -- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`) local prefix = KEYS[1] -local is_spam = KEYS[2] == 'true' and true or false +local class_label = KEYS[2] local symbol = KEYS[3] local is_unlearn = KEYS[4] == 'true' and true or false local input_tokens = cmsgpack.unpack(KEYS[5]) @@ -18,15 +18,47 @@ if KEYS[6] then text_tokens = cmsgpack.unpack(KEYS[6]) end -local hash_key = is_spam and 'S' or 'H' -local learned_key = is_spam and 'learns_spam' or 'learns_ham' +-- Handle backward compatibility for boolean values +if class_label == 'true' then + class_label = 'S' -- spam +elseif class_label == 'false' then + class_label = 'H' -- ham +end + +local hash_key = class_label +local learned_key = 'learns_' .. string.lower(class_label) + +-- Handle legacy keys for backward compatibility +if class_label == 'S' then + learned_key = 'learns_spam' +elseif class_label == 'H' then + learned_key = 'learns_ham' +end redis.call('SADD', symbol .. 
'_keys', prefix) redis.call('HSET', prefix, 'version', '2') -- new schema -redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count + +-- Update learned count, but prevent it from going negative +if is_unlearn then + local current_count = tonumber(redis.call('HGET', prefix, learned_key)) or 0 + if current_count > 0 then + redis.call('HINCRBY', prefix, learned_key, -1) + end +else + redis.call('HINCRBY', prefix, learned_key, 1) +end for i, token in ipairs(input_tokens) do - redis.call('HINCRBY', token, hash_key, is_unlearn and -1 or 1) + -- Update token count, but prevent it from going negative + if is_unlearn then + local current_token_count = tonumber(redis.call('HGET', token, hash_key)) or 0 + if current_token_count > 0 then + redis.call('HINCRBY', token, hash_key, -1) + end + else + redis.call('HINCRBY', token, hash_key, 1) + end + if text_tokens then local tok1 = text_tokens[i * 2 - 1] local tok2 = text_tokens[i * 2] @@ -38,7 +70,14 @@ for i, token in ipairs(input_tokens) do redis.call('HSET', token, 'tokens', tok1) end - redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token) + if is_unlearn then + local current_z_score = tonumber(redis.call('ZSCORE', prefix .. '_z', token)) or 0 + if current_z_score > 0 then + redis.call('ZINCRBY', prefix .. '_z', -1, token) + end + else + redis.call('ZINCRBY', prefix .. '_z', 1, token) + end end end end |