diff options
Diffstat (limited to 'lualib/lua_bayes_redis.lua')
-rw-r--r-- | lualib/lua_bayes_redis.lua | 67 |
1 files changed, 49 insertions, 18 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 782e6fc47..a7af80bf1 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -25,27 +25,44 @@ local ucl = require "ucl" local N = "bayes" local function gen_classify_functor(redis_params, classify_script_id) - return function(task, expanded_key, id, is_spam, stat_tokens, callback) - + return function(task, expanded_key, id, class_labels, stat_tokens, callback) local function classify_redis_cb(err, data) lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data) if err then callback(task, false, err) else - callback(task, true, data[1], data[2], data[3], data[4]) + -- Pass the raw data table to the C++ callback for processing + -- The C++ callback will handle both binary and multi-class formats + callback(task, true, data) + end + end + + -- Determine class labels to send to Redis script + local script_class_labels + if type(class_labels) == "table" then + -- Use simple comma-separated string instead of messagepack + script_class_labels = "TABLE:" .. table.concat(class_labels, ",") + else + -- Single class label or boolean compatibility + if class_labels == true or class_labels == "true" then + script_class_labels = "S" -- spam + elseif class_labels == false or class_labels == "false" then + script_class_labels = "H" -- ham + else + script_class_labels = class_labels -- string class label end end lua_redis.exec_redis_script(classify_script_id, { task = task, is_write = false, key = expanded_key }, - classify_redis_cb, { expanded_key, stat_tokens }) + classify_redis_cb, { expanded_key, script_class_labels, stat_tokens }) end end local function gen_learn_functor(redis_params, learn_script_id) - return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) + return function(task, expanded_key, id, class_label, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) local function learn_redis_cb(err, data) - lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data) + lua_util.debugm(N, task, 'learn redis cb: %s, %s for class %s', err, data, class_label) if err then callback(task, false, err) else @@ -53,17 +70,24 @@ local function gen_learn_functor(redis_params, learn_script_id) end end + -- Convert class_label for backward compatibility + local script_class_label = class_label + if class_label == true or class_label == "true" then + script_class_label = "S" -- spam + elseif class_label == false or class_label == "false" then + script_class_label = "H" -- ham + end + if maybe_text_tokens then lua_redis.exec_redis_script(learn_script_id, { task = task, is_write = true, key = expanded_key }, learn_redis_cb, - { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) + { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) else lua_redis.exec_redis_script(learn_script_id, { task = task, is_write = true, key = expanded_key }, - learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) + learn_redis_cb, { expanded_key, script_class_label, symbol, tostring(is_unlearn), stat_tokens }) end - end end @@ -112,8 +136,7 @@ end --- @param classifier_ucl ucl of the classifier config --- @param statfile_ucl ucl of the statfile config --- @return a pair of (classify_functor, learn_functor) or `nil` in case of error -exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) - +exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, class_label, ev_base, stat_periodic_cb) local redis_params = load_redis_params(classifier_ucl, statfile_ucl) if not redis_params then @@ -137,7 +160,6 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, if ev_base then rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) - local function stat_redis_cb(err, data) lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) @@ -162,11 +184,22 @@ exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, end end + -- Convert class_label to learn key + local learn_key + if class_label == true or class_label == "true" or class_label == "S" then + learn_key = "learns_spam" + elseif class_label == false or class_label == "false" or class_label == "H" then + learn_key = "learns_ham" + else + -- For other class labels, use learns_<class_label> + learn_key = "learns_" .. string.lower(tostring(class_label)) + end + lua_redis.exec_redis_script(stat_script_id, { ev_base = ev_base, cfg = cfg, is_write = false }, stat_redis_cb, { tostring(cursor), symbol, - is_spam and "learns_spam" or "learns_ham", + learn_key, tostring(max_users) }) return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0 end) @@ -178,7 +211,6 @@ end local function gen_cache_check_functor(redis_params, check_script_id, conf) local packed_conf = ucl.to_format(conf, 'msgpack') return function(task, cache_id, callback) - local function classify_redis_cb(err, data) lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data)) if err then @@ -201,17 +233,16 @@ end local function gen_cache_learn_functor(redis_params, learn_script_id, conf) local packed_conf = ucl.to_format(conf, 'msgpack') - return function(task, cache_id, is_spam) + return function(task, cache_id, class_name, class_id) local function learn_redis_cb(err, data) lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) end - lua_util.debugm(N, task, 'try to learn cache: %s', cache_id) + lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id) lua_redis.exec_redis_script(learn_script_id, { task = task, is_write = true, key = cache_id }, learn_redis_cb, - { cache_id, is_spam and "1" or "0", packed_conf }) - + { cache_id, tostring(class_id), packed_conf }) end end |