aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/arc.lua198
-rw-r--r--src/plugins/lua/bayes_expiry.lua182
2 files changed, 252 insertions, 128 deletions
diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua
index 45da1f5a2..954583ed0 100644
--- a/src/plugins/lua/arc.lua
+++ b/src/plugins/lua/arc.lua
@@ -72,12 +72,13 @@ local settings = {
use_domain = 'header',
use_esld = true,
use_redis = false,
- key_prefix = 'arc_keys', -- default hash name
- reuse_auth_results = false, -- Reuse the existing authentication results
+ key_prefix = 'arc_keys', -- default hash name
+ reuse_auth_results = false, -- Reuse the existing authentication results
whitelisted_signers_map = nil, -- Trusted signers domains
- adjust_dmarc = true, -- Adjust DMARC rejected policy for trusted forwarders
- allowed_ids = nil, -- Allowed settings id
- forbidden_ids = nil, -- Banned settings id
+ whitelist = nil, -- Domains with broken ARC implementations to trust despite validation failures
+ adjust_dmarc = true, -- Adjust DMARC rejected policy for trusted forwarders
+ allowed_ids = nil, -- Allowed settings id
+ forbidden_ids = nil, -- Banned settings id
}
-- To match normal AR
@@ -86,15 +87,15 @@ local ar_settings = lua_auth_results.default_settings
local function parse_arc_header(hdr, target, is_aar)
-- Split elements by ';' and trim spaces
local arr = fun.totable(fun.map(
- function(val)
- return fun.totable(fun.map(lua_util.rspamd_str_trim,
- fun.filter(function(v)
- return v and #v > 0
- end,
- lua_util.rspamd_str_split(val.decoded, ';')
- )
- ))
- end, hdr
+ function(val)
+ return fun.totable(fun.map(lua_util.rspamd_str_trim,
+ fun.filter(function(v)
+ return v and #v > 0
+ end,
+ lua_util.rspamd_str_split(val.decoded, ';')
+ )
+ ))
+ end, hdr
))
-- v[1] is the key and v[2] is the value
@@ -115,11 +116,11 @@ local function parse_arc_header(hdr, target, is_aar)
if not is_aar then
-- For normal ARC headers we split by kv pair, like k=v
fun.each(function(v)
- fill_arc_header_table(v, target[i])
- end,
- fun.map(function(elt)
- return lua_util.rspamd_str_split(elt, '=')
- end, elts)
+ fill_arc_header_table(v, target[i])
+ end,
+ fun.map(function(elt)
+ return lua_util.rspamd_str_split(elt, '=')
+ end, elts)
)
else
-- For AAR we check special case of i=%d and pass everything else to
@@ -156,14 +157,14 @@ local function arc_validate_seals(task, seals, sigs, seal_headers, sig_headers)
for i = 1, #seals do
if (sigs[i].i or 0) ~= i then
fail_reason = string.format('bad i for signature: %d, expected %d; d=%s',
- sigs[i].i, i, sigs[i].d)
+ sigs[i].i, i, sigs[i].d)
rspamd_logger.infox(task, fail_reason)
task:insert_result(arc_symbols['invalid'], 1.0, fail_reason)
return false, fail_reason
end
if (seals[i].i or 0) ~= i then
fail_reason = string.format('bad i for seal: %d, expected %d; d=%s',
- seals[i].i, i, seals[i].d)
+ seals[i].i, i, seals[i].d)
rspamd_logger.infox(task, fail_reason)
task:insert_result(arc_symbols['invalid'], 1.0, fail_reason)
return false, fail_reason
@@ -207,7 +208,7 @@ local function arc_callback(task)
if #arc_sig_headers ~= #arc_seal_headers then
-- We mandate that count of seals is equal to count of signatures
rspamd_logger.infox(task, 'number of seals (%s) is not equal to number of signatures (%s)',
- #arc_seal_headers, #arc_sig_headers)
+ #arc_seal_headers, #arc_sig_headers)
task:insert_result(arc_symbols['invalid'], 1.0, 'invalid count of seals and signatures')
return
end
@@ -249,7 +250,7 @@ local function arc_callback(task)
-- Now check sanity of what we have
local valid, validation_error = arc_validate_seals(task, cbdata.seals, cbdata.sigs,
- arc_seal_headers, arc_sig_headers)
+ arc_seal_headers, arc_sig_headers)
if not valid then
task:cache_set('arc-failure', validation_error)
return
@@ -267,12 +268,20 @@ local function arc_callback(task)
local function gen_arc_seal_cb(index, sig)
return function(_, res, err, domain)
lua_util.debugm(N, task, 'checked arc seal: %s(%s), %s processed',
- res, err, index)
+ res, err, index)
if not res then
- cbdata.res = 'fail'
- if err and domain then
- table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err))
+ -- Check if this domain is whitelisted for broken ARC implementations
+ if settings.whitelist and domain and settings.whitelist:get_key(domain) then
+ rspamd_logger.infox(task, 'ARC seal validation failed for whitelisted domain %s, treating as valid: %s',
+ domain, err)
+ lua_util.debugm(N, task, 'whitelisted domain %s ARC seal failure ignored', domain)
+ res = true -- Treat as valid to continue the chain
+ else
+ cbdata.res = 'fail'
+ if err and domain then
+ table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err))
+ end
end
end
@@ -283,7 +292,7 @@ local function arc_callback(task)
local cur_aar = cbdata.ars[index]
if not cur_aar then
rspamd_logger.warnx(task, "cannot find Arc-Authentication-Results for trusted " ..
- "forwarder %s on i=%s", domain, cbdata.index)
+ "forwarder %s on i=%s", domain, cbdata.index)
else
task:cache_set(AR_TRUSTED_CACHE_KEY, cur_aar)
local seen_dmarc
@@ -309,20 +318,20 @@ local function arc_callback(task)
end
end
task:insert_result(arc_symbols.trusted_allow, mult,
- string.format('%s:s=%s:i=%d', domain, sig.s, index))
+ string.format('%s:s=%s:i=%d', domain, sig.s, index))
end
end
if index == #arc_sig_headers then
if cbdata.res == 'success' then
local arc_allow_result = string.format('%s:s=%s:i=%d',
- domain, sig.s, index)
+ domain, sig.s, index)
task:insert_result(arc_symbols.allow, 1.0, arc_allow_result)
task:cache_set('arc-allow', arc_allow_result)
else
task:insert_result(arc_symbols.reject, 1.0,
- rspamd_logger.slog('seal check failed: %s, %s', cbdata.res,
- cbdata.errors))
+ rspamd_logger.slog('seal check failed: %s, %s', cbdata.res,
+ cbdata.errors))
end
end
end
@@ -330,12 +339,20 @@ local function arc_callback(task)
local function arc_signature_cb(_, res, err, domain)
lua_util.debugm(N, task, 'checked arc signature %s: %s(%s)',
- domain, res, err)
+ domain, res, err)
if not res then
- cbdata.res = 'fail'
- if err and domain then
- table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err))
+ -- Check if this domain is whitelisted for broken ARC implementations
+ if settings.whitelist and domain and settings.whitelist:get_key(domain) then
+ rspamd_logger.infox(task, 'ARC signature validation failed for whitelisted domain %s, treating as valid: %s',
+ domain, err)
+ lua_util.debugm(N, task, 'whitelisted domain %s ARC signature failure ignored', domain)
+ res = true -- Treat as valid to continue the chain
+ else
+ cbdata.res = 'fail'
+ if err and domain then
+ table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err))
+ end
end
end
if cbdata.res == 'success' then
@@ -343,17 +360,24 @@ local function arc_callback(task)
for i, sig in ipairs(cbdata.seals) do
local ret, lerr = dkim_verify(task, sig.header, gen_arc_seal_cb(i, sig), 'arc-seal')
if not ret then
- cbdata.res = 'fail'
- table.insert(cbdata.errors, string.format('seal:%s:s=%s:i=%s:%s',
+ -- Check if this domain is whitelisted for broken ARC implementations
+ if settings.whitelist and sig.d and settings.whitelist:get_key(sig.d) then
+ rspamd_logger.infox(task, 'ARC seal dkim_verify failed for whitelisted domain %s, treating as valid: %s',
+ sig.d, lerr)
+ lua_util.debugm(N, task, 'whitelisted domain %s ARC seal dkim_verify failure ignored', sig.d)
+ else
+ cbdata.res = 'fail'
+ table.insert(cbdata.errors, string.format('seal:%s:s=%s:i=%s:%s',
sig.d or '', sig.s or '', sig.i or '', lerr))
- lua_util.debugm(N, task, 'checked arc seal %s: %s(%s), %s processed',
+ lua_util.debugm(N, task, 'checked arc seal %s: %s(%s), %s processed',
sig.d, ret, lerr, i)
+ end
end
end
else
task:insert_result(arc_symbols['reject'], 1.0,
- rspamd_logger.slog('signature check failed: %s, %s', cbdata.res,
- cbdata.errors))
+ rspamd_logger.slog('signature check failed: %s, %s', cbdata.res,
+ cbdata.errors))
end
end
@@ -397,25 +421,33 @@ local function arc_callback(task)
is "fail" and the algorithm stops here.
9. If the algorithm reaches this step, then the Chain Validation
Status is "pass", and the algorithm is complete.
- ]]--
+ ]] --
local processed = 0
local sig = cbdata.sigs[#cbdata.sigs] -- last AMS
local ret, err = dkim_verify(task, sig.header, arc_signature_cb, 'arc-sign')
if not ret then
- cbdata.res = 'fail'
- table.insert(cbdata.errors, string.format('sig:%s:%s', sig.d or '', err))
+ -- Check if this domain is whitelisted for broken ARC implementations
+ if settings.whitelist and sig.d and settings.whitelist:get_key(sig.d) then
+ rspamd_logger.infox(task, 'ARC signature dkim_verify failed for whitelisted domain %s, treating as valid: %s',
+ sig.d, err)
+ lua_util.debugm(N, task, 'whitelisted domain %s ARC signature dkim_verify failure ignored', sig.d)
+ processed = processed + 1
+ else
+ cbdata.res = 'fail'
+ table.insert(cbdata.errors, string.format('sig:%s:%s', sig.d or '', err))
+ end
else
processed = processed + 1
lua_util.debugm(N, task, 'processed arc signature %s[%s]: %s(%s), %s total',
- sig.d, sig.i, ret, err, #cbdata.seals)
+ sig.d, sig.i, ret, err, #cbdata.seals)
end
if processed == 0 then
task:insert_result(arc_symbols['reject'], 1.0,
- rspamd_logger.slog('cannot verify %s of %s signatures: %s',
- #arc_sig_headers - processed, #arc_sig_headers, cbdata.errors))
+ rspamd_logger.slog('cannot verify %s of %s signatures: %s',
+ #arc_sig_headers - processed, #arc_sig_headers, cbdata.errors))
end
end
@@ -538,13 +570,13 @@ local function arc_sign_seal(task, params, header)
for i = 1, #arc_seals, 1 do
if arc_auth_results[i] then
local s = dkim_canonicalize('ARC-Authentication-Results',
- arc_auth_results[i].raw_header)
+ arc_auth_results[i].raw_header)
sha_ctx:update(s)
lua_util.debugm(N, task, 'update signature with header: %s', s)
end
if arc_sigs[i] then
local s = dkim_canonicalize('ARC-Message-Signature',
- arc_sigs[i].raw_header)
+ arc_sigs[i].raw_header)
sha_ctx:update(s)
lua_util.debugm(N, task, 'update signature with header: %s', s)
end
@@ -557,16 +589,16 @@ local function arc_sign_seal(task, params, header)
end
header = lua_util.fold_header(task,
- 'ARC-Message-Signature',
- header)
+ 'ARC-Message-Signature',
+ header)
cur_auth_results = string.format('i=%d; %s', cur_idx, cur_auth_results)
cur_auth_results = lua_util.fold_header(task,
- 'ARC-Authentication-Results',
- cur_auth_results, ';')
+ 'ARC-Authentication-Results',
+ cur_auth_results, ';')
local s = dkim_canonicalize('ARC-Authentication-Results',
- cur_auth_results)
+ cur_auth_results)
sha_ctx:update(s)
lua_util.debugm(N, task, 'update signature with header: %s', s)
s = dkim_canonicalize('ARC-Message-Signature', header)
@@ -574,10 +606,10 @@ local function arc_sign_seal(task, params, header)
lua_util.debugm(N, task, 'update signature with header: %s', s)
local cur_arc_seal = string.format('i=%d; s=%s; d=%s; t=%d; a=rsa-sha256; cv=%s; b=',
- cur_idx,
- params.selector,
- params.domain,
- math.floor(rspamd_util.get_time()), params.arc_cv)
+ cur_idx,
+ params.selector,
+ params.domain,
+ math.floor(rspamd_util.get_time()), params.arc_cv)
s = string.format('%s:%s', 'arc-seal', cur_arc_seal)
sha_ctx:update(s)
lua_util.debugm(N, task, 'initial update signature with header: %s', s)
@@ -591,20 +623,23 @@ local function arc_sign_seal(task, params, header)
local sig = rspamd_rsa.sign_memory(privkey, sha_ctx:bin())
cur_arc_seal = string.format('%s%s', cur_arc_seal,
- sig:base64(70, nl_type))
+ sig:base64(70, nl_type))
lua_mime.modify_headers(task, {
add = {
['ARC-Authentication-Results'] = { order = 1, value = cur_auth_results },
['ARC-Message-Signature'] = { order = 1, value = header },
- ['ARC-Seal'] = { order = 1, value = lua_util.fold_header(task,
- 'ARC-Seal', cur_arc_seal) }
+ ['ARC-Seal'] = {
+ order = 1,
+ value = lua_util.fold_header(task,
+ 'ARC-Seal', cur_arc_seal)
+ }
},
-- RFC requires a strict order for these headers to be inserted
order = { 'ARC-Authentication-Results', 'ARC-Message-Signature', 'ARC-Seal' },
})
task:insert_result(settings.sign_symbol, 1.0,
- string.format('%s:s=%s:i=%d', params.domain, params.selector, cur_idx))
+ string.format('%s:s=%s:i=%d', params.domain, params.selector, cur_idx))
end
local function prepare_arc_selector(task, sel)
@@ -668,7 +703,6 @@ local function prepare_arc_selector(task, sel)
else
default_arc_cv()
end
-
end
return true
@@ -696,18 +730,17 @@ local function do_sign(task, sign_params)
sign_params.strict_pubkey_check = not settings.allow_pubkey_mismatch
elseif not settings.allow_pubkey_mismatch then
rspamd_logger.errx(task, 'public key for domain %s/%s is not found: %s, skip signing',
- sign_params.domain, sign_params.selector, err)
+ sign_params.domain, sign_params.selector, err)
return
else
rspamd_logger.infox(task, 'public key for domain %s/%s is not found: %s',
- sign_params.domain, sign_params.selector, err)
+ sign_params.domain, sign_params.selector, err)
end
local dret, hdr = dkim_sign(task, sign_params)
if dret then
arc_sign_seal(task, sign_params, hdr)
end
-
end,
forced = true
})
@@ -768,6 +801,31 @@ end
dkim_sign_tools.process_signing_settings(N, settings, opts)
+-- Process ARC-specific maps that aren't handled by dkim_sign_tools
+local lua_maps = require "lua_maps"
+
+if opts.whitelisted_signers_map then
+ settings.whitelisted_signers_map = lua_maps.map_add_from_ucl(opts.whitelisted_signers_map, 'set',
+ 'ARC trusted signers domains')
+ if not settings.whitelisted_signers_map then
+ rspamd_logger.errx(rspamd_config, 'cannot load whitelisted_signers_map')
+ settings.whitelisted_signers_map = nil
+ else
+ rspamd_logger.infox(rspamd_config, 'loaded ARC whitelisted signers map')
+ end
+end
+
+if opts.whitelist then
+ settings.whitelist = lua_maps.map_add_from_ucl(opts.whitelist, 'set',
+ 'ARC domains with broken implementations')
+ if not settings.whitelist then
+ rspamd_logger.errx(rspamd_config, 'cannot load ARC whitelist map')
+ settings.whitelist = nil
+ else
+ rspamd_logger.infox(rspamd_config, 'loaded ARC whitelist map')
+ end
+end
+
if not dkim_sign_tools.validate_signing_settings(settings) then
rspamd_logger.infox(rspamd_config, 'mandatory parameters missing, disable arc signing')
return
@@ -780,7 +838,7 @@ if ar_opts and ar_opts.routines then
if routines['authentication-results'] then
ar_settings = lua_util.override_defaults(ar_settings,
- routines['authentication-results'])
+ routines['authentication-results'])
end
end
@@ -789,7 +847,7 @@ if settings.use_redis then
if not redis_params then
rspamd_logger.errx(rspamd_config, 'no servers are specified, ' ..
- 'but module is configured to load keys from redis, disable arc signing')
+ 'but module is configured to load keys from redis, disable arc signing')
return
end
@@ -845,9 +903,9 @@ if settings.adjust_dmarc and settings.whitelisted_signers_map then
local dmarc_fwd = ar.dmarc
if dmarc_fwd == 'pass' then
rspamd_logger.infox(task, "adjust dmarc reject score as trusted forwarder "
- .. "proved DMARC validity for %s", ar['header.from'])
+ .. "proved DMARC validity for %s", ar['header.from'])
task:adjust_result(sym_to_adjust, 0.1,
- 'ARC trusted')
+ 'ARC trusted')
end
end
end
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua
index 44ff9dafa..0d78f2272 100644
--- a/src/plugins/lua/bayes_expiry.lua
+++ b/src/plugins/lua/bayes_expiry.lua
@@ -41,32 +41,38 @@ local template = {}
local function check_redis_classifier(cls, cfg)
-- Skip old classifiers
if cls.new_schema then
- local symbol_spam, symbol_ham
+ local class_symbols = {}
+ local class_labels = {}
local expiry = (cls.expiry or cls.expire)
if type(expiry) == 'table' then
expiry = expiry[1]
end
- -- Load symbols from statfiles
+ -- Extract class_labels mapping from classifier config
+ if cls.class_labels then
+ class_labels = cls.class_labels
+ end
+ -- Load symbols from statfiles for multi-class support
local function check_statfile_table(tbl, def_sym)
local symbol = tbl.symbol or def_sym
-
- local spam
- if tbl.spam then
- spam = tbl.spam
- else
- if string.match(symbol:upper(), 'SPAM') then
- spam = true
+ local class_name = tbl.class
+
+ -- Handle legacy spam/ham detection for backward compatibility
+ if not class_name then
+ if tbl.spam ~= nil then
+ class_name = tbl.spam and 'spam' or 'ham'
+ elseif string.match(tostring(symbol):upper(), 'SPAM') then
+ class_name = 'spam'
+ elseif string.match(tostring(symbol):upper(), 'HAM') then
+ class_name = 'ham'
else
- spam = false
+ class_name = def_sym
end
end
- if spam then
- symbol_spam = symbol
- else
- symbol_ham = symbol
+ if class_name then
+ class_symbols[class_name] = symbol
end
end
@@ -87,10 +93,9 @@ local function check_redis_classifier(cls, cfg)
end
end
- if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then
+ if next(class_symbols) == nil or type(expiry) ~= 'number' then
logger.debugm(N, rspamd_config,
- 'disable expiry for classifier %s: no expiry %s',
- symbol_spam, cls)
+ 'disable expiry for classifier: no class symbols or expiry configured')
return
end
-- Now try to load redis_params if needed
@@ -108,17 +113,16 @@ local function check_redis_classifier(cls, cfg)
end
if redis_params['read_only'] then
- logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration',
- symbol_spam)
+ logger.infox(rspamd_config, 'disable expiry for classifier: read only redis configuration')
return
end
- logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry",
- symbol_spam, symbol_ham, expiry)
+ logger.debugm(N, rspamd_config, "enabled expiry for classes %s -> %s expiry",
+ table.concat(lutil.keys(class_symbols), ', '), expiry)
table.insert(settings.classifiers, {
- symbol_spam = symbol_spam,
- symbol_ham = symbol_ham,
+ class_symbols = class_symbols,
+ class_labels = class_labels,
redis_params = redis_params,
expiry = expiry
})
@@ -249,12 +253,11 @@ local expiry_script = [[
local keys = ret[2]
local tokens = {}
- -- Tokens occurrences distribution counters
+ -- Dynamic occurrence tracking for all classes
local occur = {
- ham = {},
- spam = {},
total = {}
}
+ local classes_found = {}
-- Expiry step statistics counters
local nelts, extended, discriminated, sum, sum_squares, common, significant,
@@ -264,24 +267,44 @@ local expiry_script = [[
for _,key in ipairs(keys) do
local t = redis.call('TYPE', key)["ok"]
if t == 'hash' then
- local values = redis.call('HMGET', key, 'H', 'S')
- local ham = tonumber(values[1]) or 0
- local spam = tonumber(values[2]) or 0
+ -- Get all hash fields to support multi-class
+ local hash_data = redis.call('HGETALL', key)
+ local class_counts = {}
+ local total = 0
local ttl = redis.call('TTL', key)
+
+ -- Parse hash data into class counts
+ for i = 1, #hash_data, 2 do
+ local class_label = hash_data[i]
+ local count = tonumber(hash_data[i + 1]) or 0
+ class_counts[class_label] = count
+ total = total + count
+
+ -- Track classes we've seen
+ if not classes_found[class_label] then
+ classes_found[class_label] = true
+ occur[class_label] = {}
+ end
+ end
+
tokens[key] = {
- ham,
- spam,
- ttl
+ class_counts = class_counts,
+ total = total,
+ ttl = ttl
}
- local total = spam + ham
+
sum = sum + total
sum_squares = sum_squares + total * total
nelts = nelts + 1
- for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do
- if tonumber(v) > 19 then v = 20 end
- occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1
+ -- Update occurrence counters for all classes and total
+ for class_label, count in pairs(class_counts) do
+ local bucket = count > 19 and 20 or count
+ occur[class_label][bucket] = (occur[class_label][bucket] or 0) + 1
end
+
+ local total_bucket = total > 19 and 20 or total
+ occur.total[total_bucket] = (occur.total[total_bucket] or 0) + 1
end
end
@@ -293,9 +316,10 @@ local expiry_script = [[
end
for key,token in pairs(tokens) do
- local ham, spam, ttl = token[1], token[2], tonumber(token[3])
+ local class_counts = token.class_counts
+ local total = token.total
+ local ttl = tonumber(token.ttl)
local threshold = mean
- local total = spam + ham
local function set_ttl()
if expire < 0 then
@@ -310,14 +334,39 @@ local expiry_script = [[
return 0
end
- if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then
+ -- Check if token is common (balanced across classes)
+ local is_common = false
+ if total == 0 then
+ is_common = true
+ else
+ -- For multi-class, check if any class dominates significantly
+ local max_count = 0
+ for _, count in pairs(class_counts) do
+ if count > max_count then
+ max_count = count
+ end
+ end
+ -- Token is common if no class has more than (1 - epsilon) of total
+ is_common = (max_count / total) <= (1 - ${epsilon_common})
+ end
+
+ if is_common then
common = common + 1
if ttl > ${common_ttl} then
discriminated = discriminated + 1
redis.call('EXPIRE', key, ${common_ttl})
end
elseif total >= threshold and total > 0 then
- if ham / total > ${significant_factor} or spam / total > ${significant_factor} then
+ -- Check if any class is significant
+ local is_significant = false
+ for _, count in pairs(class_counts) do
+ if count / total > ${significant_factor} then
+ is_significant = true
+ break
+ end
+ end
+
+ if is_significant then
significant = significant + 1
if ttl ~= -1 then
redis.call('PERSIST', key)
@@ -361,33 +410,50 @@ local expiry_script = [[
redis.call('DEL', lock_key)
local occ_distr = {}
- for _,cl in pairs({'ham', 'spam', 'total'}) do
+
+ -- Process all classes found plus total
+ local all_classes = {'total'}
+ for class_label in pairs(classes_found) do
+ table.insert(all_classes, class_label)
+ end
+
+ for _, cl in ipairs(all_classes) do
local occur_key = pattern_sha1 .. '_occurrence_' .. cl
if cursor ~= 0 then
- local n
- for i,v in ipairs(redis.call('HGETALL', occur_key)) do
- if i % 2 == 1 then
- n = tonumber(v)
- else
- occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v
+ local existing_data = redis.call('HGETALL', occur_key)
+ if #existing_data > 0 then
+ for i = 1, #existing_data, 2 do
+ local bucket = tonumber(existing_data[i])
+ local count = tonumber(existing_data[i + 1])
+ if occur[cl] and occur[cl][bucket] then
+ occur[cl][bucket] = occur[cl][bucket] + count
+ elseif occur[cl] then
+ occur[cl][bucket] = count
+ end
end
end
- local str = ''
- if occur[cl][0] ~= nil then
- str = '0:' .. occur[cl][0] .. ','
- end
- for k,v in ipairs(occur[cl]) do
- if k == 20 then k = '>19' end
- str = str .. k .. ':' .. v .. ','
+ if occur[cl] and next(occur[cl]) then
+ local str = ''
+ if occur[cl][0] then
+ str = '0:' .. occur[cl][0] .. ','
+ end
+ for k = 1, 20 do
+ if occur[cl][k] then
+ local label = k == 20 and '>19' or tostring(k)
+ str = str .. label .. ':' .. occur[cl][k] .. ','
+ end
+ end
+ table.insert(occ_distr, cl .. '=' .. str)
+ else
+ table.insert(occ_distr, cl .. '=no_data')
end
- table.insert(occ_distr, str)
else
redis.call('DEL', occur_key)
end
- if next(occur[cl]) ~= nil then
+ if occur[cl] and next(occur[cl]) then
redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl])))
end
end
@@ -446,8 +512,8 @@ local function expire_step(cls, ev_base, worker)
'%s infrequent (%s %s), %s mean, %s std',
lutil.unpack(d))
if cycle then
- for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do
- logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i])
+ for _, distr_info in ipairs(occ_distr) do
+ logger.infox(rspamd_config, 'tokens occurrences: {%s}', distr_info)
end
end
end