diff options
Diffstat (limited to 'src/plugins/lua/bayes_expiry.lua')
-rw-r--r-- | src/plugins/lua/bayes_expiry.lua | 182 |
1 files changed, 124 insertions, 58 deletions
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua index 44ff9dafa..0d78f2272 100644 --- a/src/plugins/lua/bayes_expiry.lua +++ b/src/plugins/lua/bayes_expiry.lua @@ -41,32 +41,38 @@ local template = {} local function check_redis_classifier(cls, cfg) -- Skip old classifiers if cls.new_schema then - local symbol_spam, symbol_ham + local class_symbols = {} + local class_labels = {} local expiry = (cls.expiry or cls.expire) if type(expiry) == 'table' then expiry = expiry[1] end - -- Load symbols from statfiles + -- Extract class_labels mapping from classifier config + if cls.class_labels then + class_labels = cls.class_labels + end + -- Load symbols from statfiles for multi-class support local function check_statfile_table(tbl, def_sym) local symbol = tbl.symbol or def_sym - - local spam - if tbl.spam then - spam = tbl.spam - else - if string.match(symbol:upper(), 'SPAM') then - spam = true + local class_name = tbl.class + + -- Handle legacy spam/ham detection for backward compatibility + if not class_name then + if tbl.spam ~= nil then + class_name = tbl.spam and 'spam' or 'ham' + elseif string.match(tostring(symbol):upper(), 'SPAM') then + class_name = 'spam' + elseif string.match(tostring(symbol):upper(), 'HAM') then + class_name = 'ham' else - spam = false + class_name = def_sym end end - if spam then - symbol_spam = symbol - else - symbol_ham = symbol + if class_name then + class_symbols[class_name] = symbol end end @@ -87,10 +93,9 @@ local function check_redis_classifier(cls, cfg) end end - if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then + if next(class_symbols) == nil or type(expiry) ~= 'number' then logger.debugm(N, rspamd_config, - 'disable expiry for classifier %s: no expiry %s', - symbol_spam, cls) + 'disable expiry for classifier: no class symbols or expiry configured') return end -- Now try to load redis_params if needed @@ -108,17 +113,16 @@ local function check_redis_classifier(cls, cfg) end if redis_params['read_only'] then - logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration', - symbol_spam) + logger.infox(rspamd_config, 'disable expiry for classifier: read only redis configuration') return end - logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry", - symbol_spam, symbol_ham, expiry) + logger.debugm(N, rspamd_config, "enabled expiry for classes %s -> %s expiry", + table.concat(lutil.keys(class_symbols), ', '), expiry) table.insert(settings.classifiers, { - symbol_spam = symbol_spam, - symbol_ham = symbol_ham, + class_symbols = class_symbols, + class_labels = class_labels, redis_params = redis_params, expiry = expiry }) @@ -249,12 +253,11 @@ local expiry_script = [[ local keys = ret[2] local tokens = {} - -- Tokens occurrences distribution counters + -- Dynamic occurrence tracking for all classes local occur = { - ham = {}, - spam = {}, total = {} } + local classes_found = {} -- Expiry step statistics counters local nelts, extended, discriminated, sum, sum_squares, common, significant, @@ -264,24 +267,44 @@ local expiry_script = [[ for _,key in ipairs(keys) do local t = redis.call('TYPE', key)["ok"] if t == 'hash' then - local values = redis.call('HMGET', key, 'H', 'S') - local ham = tonumber(values[1]) or 0 - local spam = tonumber(values[2]) or 0 + -- Get all hash fields to support multi-class + local hash_data = redis.call('HGETALL', key) + local class_counts = {} + local total = 0 local ttl = redis.call('TTL', key) + + -- Parse hash data into class counts + for i = 1, #hash_data, 2 do + local class_label = hash_data[i] + local count = tonumber(hash_data[i + 1]) or 0 + class_counts[class_label] = count + total = total + count + + -- Track classes we've seen + if not classes_found[class_label] then + classes_found[class_label] = true + occur[class_label] = {} + end + end + tokens[key] = { - ham, - spam, - ttl + class_counts = class_counts, + total = total, + ttl = ttl } - local total = spam + ham + sum = sum + total sum_squares = sum_squares + total * total nelts = nelts + 1 - for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do - if tonumber(v) > 19 then v = 20 end - occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1 + -- Update occurrence counters for all classes and total + for class_label, count in pairs(class_counts) do + local bucket = count > 19 and 20 or count + occur[class_label][bucket] = (occur[class_label][bucket] or 0) + 1 end + + local total_bucket = total > 19 and 20 or total + occur.total[total_bucket] = (occur.total[total_bucket] or 0) + 1 end end @@ -293,9 +316,10 @@ local expiry_script = [[ end for key,token in pairs(tokens) do - local ham, spam, ttl = token[1], token[2], tonumber(token[3]) + local class_counts = token.class_counts + local total = token.total + local ttl = tonumber(token.ttl) local threshold = mean - local total = spam + ham local function set_ttl() if expire < 0 then @@ -310,14 +334,39 @@ local expiry_script = [[ return 0 end - if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then + -- Check if token is common (balanced across classes) + local is_common = false + if total == 0 then + is_common = true + else + -- For multi-class, check if any class dominates significantly + local max_count = 0 + for _, count in pairs(class_counts) do + if count > max_count then + max_count = count + end + end + -- Token is common if no class has more than (1 - epsilon) of total + is_common = (max_count / total) <= (1 - ${epsilon_common}) + end + + if is_common then common = common + 1 if ttl > ${common_ttl} then discriminated = discriminated + 1 redis.call('EXPIRE', key, ${common_ttl}) end elseif total >= threshold and total > 0 then - if ham / total > ${significant_factor} or spam / total > ${significant_factor} then + -- Check if any class is significant + local is_significant = false + for _, count in pairs(class_counts) do + if count / total > ${significant_factor} then + is_significant = true + break + end + end + + if is_significant then significant = significant + 1 if ttl ~= -1 then redis.call('PERSIST', key) @@ -361,33 +410,50 @@ local expiry_script = [[ redis.call('DEL', lock_key) local occ_distr = {} - for _,cl in pairs({'ham', 'spam', 'total'}) do + + -- Process all classes found plus total + local all_classes = {'total'} + for class_label in pairs(classes_found) do + table.insert(all_classes, class_label) + end + + for _, cl in ipairs(all_classes) do local occur_key = pattern_sha1 .. '_occurrence_' .. cl if cursor ~= 0 then - local n - for i,v in ipairs(redis.call('HGETALL', occur_key)) do - if i % 2 == 1 then - n = tonumber(v) - else - occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v + local existing_data = redis.call('HGETALL', occur_key) + if #existing_data > 0 then + for i = 1, #existing_data, 2 do + local bucket = tonumber(existing_data[i]) + local count = tonumber(existing_data[i + 1]) + if occur[cl] and occur[cl][bucket] then + occur[cl][bucket] = occur[cl][bucket] + count + elseif occur[cl] then + occur[cl][bucket] = count + end end end - local str = '' - if occur[cl][0] ~= nil then - str = '0:' .. occur[cl][0] .. ',' - end - for k,v in ipairs(occur[cl]) do - if k == 20 then k = '>19' end - str = str .. k .. ':' .. v .. ',' + if occur[cl] and next(occur[cl]) then + local str = '' + if occur[cl][0] then + str = '0:' .. occur[cl][0] .. ',' + end + for k = 1, 20 do + if occur[cl][k] then + local label = k == 20 and '>19' or tostring(k) + str = str .. label .. ':' .. occur[cl][k] .. ',' + end + end + table.insert(occ_distr, cl .. '=' .. str) + else + table.insert(occ_distr, cl .. '=no_data') end - table.insert(occ_distr, str) else redis.call('DEL', occur_key) end - if next(occur[cl]) ~= nil then + if occur[cl] and next(occur[cl]) then redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl]))) end end @@ -446,8 +512,8 @@ local function expire_step(cls, ev_base, worker) '%s infrequent (%s %s), %s mean, %s std', lutil.unpack(d)) if cycle then - for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do - logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i]) + for _, distr_info in ipairs(occ_distr) do + logger.infox(rspamd_config, 'tokens occurrences: {%s}', distr_info) end end end |