about | summary | refs | log | tree | commit | diff | stats
path: root/lualib/lua_bayes_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/lua_bayes_redis.lua')
-rw-r--r-- lualib/lua_bayes_redis.lua | 67
1 files changed, 49 insertions, 18 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index 782e6fc47..a7af80bf1 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua
@@ -25,27 +25,44 @@ local ucl = require "ucl"
local N = "bayes"
local function gen_classify_functor(redis_params, classify_script_id)
- return function(task, expanded_key, id, is_spam, stat_tokens, callback)
-
+ return function(task, expanded_key, id, class_labels, stat_tokens, callback)
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
if err then
callback(task, false, err)
else
- callback(task, true, data[1], data[2], data[3], data[4])
+ -- Pass the raw data table to the C++ callback for processing
+ -- The C++ callback will handle both binary and multi-class formats
+ callback(task, true, data)
+ end
+ end
+
+ -- Determine class labels to send to Redis script
+ local script_class_labels
+ if type(class_labels) == "table" then
+ -- Use simple comma-separated string instead of messagepack
+ script_class_labels = "TABLE:" .. table.concat(class_labels, ",")
+ else
+ -- Single class label or boolean compatibility
+ if class_labels == true or class_labels == "true" then
+ script_class_labels = "S" -- spam
+ elseif class_labels == false or class_labels == "false" then
+ script_class_labels = "H" -- ham
+ else
+ script_class_labels = class_labels -- string class label
end
end
lua_redis.exec_redis_script(classify_script_id,
{ task = task, is_write = false, key = expanded_key },
- classify_redis_cb, { expanded_key, stat_tokens })
+ classify_redis_cb, { expanded_key, script_class_labels, stat_tokens })
end
end
local function gen_learn_functor(redis_params, learn_script_id)
- return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
+ return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
local function learn_redis_cb(err, data)
- lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
+ lua_util.debugm(N, task, 'learn redis cb: %s, %s for class %s', err, data, class_label)
if err then
callback(task, false, err)
else
@@ -53,17 +70,24 @@ local function gen_learn_functor(redis_params, learn_script_id)
end
end
+ -- Convert class_label for backward compatibility
+ local script_class_label = class_label
+ if class_label == true or class_label == "true" then
+ script_class_label = "S" -- spam
+ elseif class_label == false or class_label == "false" then
+ script_class_label = "H" -- ham
+ end
+
if maybe_text_tokens then
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = expanded_key },
learn_redis_cb,
- { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+ { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
else
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = expanded_key },
- learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+ learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens })
end
-
end
end
@@ -112,8 +136,7 @@ end
--- @param classifier_ucl ucl of the classifier config
--- @param statfile_ucl ucl of the statfile config
--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
-exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
-
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb)
local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
if not redis_params then
@@ -137,7 +160,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
if ev_base then
rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
-
local function stat_redis_cb(err, data)
lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
@@ -162,11 +184,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol,
end
end
+ -- Convert class_label to learn key
+ local learn_key
+ if class_label == true or class_label == "true" or class_label == "S" then
+ learn_key = "learns_spam"
+ elseif class_label == false or class_label == "false" or class_label == "H" then
+ learn_key = "learns_ham"
+ else
+ -- For other class labels, use learns_<class_label>
+ learn_key = "learns_" .. string.lower(tostring(class_label))
+ end
+
lua_redis.exec_redis_script(stat_script_id,
{ ev_base = ev_base, cfg = cfg, is_write = false },
stat_redis_cb, { tostring(cursor),
symbol,
- is_spam and "learns_spam" or "learns_ham",
+ learn_key,
tostring(max_users) })
return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
end)
@@ -178,7 +211,6 @@ end
local function gen_cache_check_functor(redis_params, check_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
return function(task, cache_id, callback)
-
local function classify_redis_cb(err, data)
lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
if err then
@@ -201,17 +233,16 @@ end
local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
local packed_conf = ucl.to_format(conf, 'msgpack')
- return function(task, cache_id, is_spam)
+ return function(task, cache_id, class_name, class_id)
local function learn_redis_cb(err, data)
lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
end
- lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
+ lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id)
lua_redis.exec_redis_script(learn_script_id,
{ task = task, is_write = true, key = cache_id },
learn_redis_cb,
- { cache_id, is_spam and "1" or "0", packed_conf })
-
+ { cache_id, tostring(class_id), packed_conf })
end
end