diff options
-rw-r--r-- | src/plugins/lua/reputation.lua | 58 |
1 files changed, 31 insertions, 27 deletions
diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua index f1062dbaa..ebfe2c04e 100644 --- a/src/plugins/lua/reputation.lua +++ b/src/plugins/lua/reputation.lua @@ -37,21 +37,34 @@ local redis_params = nil local default_expiry = 864000 -- 10 day by default local default_prefix = 'RR:' -- Rspamd Reputation +local tanh = math.tanh or rspamd_util.tanh + +local reject_threshold = rspamd_config:get_action('reject') or 10.0 -- Get reputation from ham/spam/probable hits -local function generic_reputation_calc(token, rule, mult) +local function generic_reputation_calc(token, rule, mult, task) local cfg = rule.selector.config or E if cfg.score_calc_func then return cfg.score_calc_func(rule, token, mult) end - if token[1] < cfg.lower_bound then return 0 end + if tonumber(token[1]) < cfg.lower_bound then + lua_util.debugm(N, task, "not enough matches %s < %s for rule %s", + token[1], cfg.lower_bound, rule.symbol) + return 0 + end - local score = fun.foldl(function(acc, v) + -- Get average score + local avg_score = fun.foldl(function(acc, v) return acc + v end, 0.0, fun.map(tonumber, token[2])) / #token[2] + -- Apply function tanh(x / reject_score * atanh(0.95) - atanh(0.5)) + -- 1.83178 0.5493 + local score = tanh(avg_score / reject_threshold * 1.83178 - 0.5493) * mult + lua_util.debugm(N, task, "got generic average score %s -> %s for rule %s", + avg_score, score, rule.symbol) return score end @@ -158,9 +171,9 @@ local function dkim_reputation_filter(task, rule) if nchecked == #requests then for k,v in pairs(results) do if requests[k] == 'a' then - rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0) + rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task) elseif requests[k] == 'r' then - rep_rejected = rep_rejected + generic_reputation_calc(v, rule, 1.0) + rep_rejected = rep_rejected + generic_reputation_calc(v, rule, 1.0, task) end end @@ -293,7 +306,8 @@ local function url_reputation_filter(task, rule) if mhits > 0 then local score = 0 for k,v in pairs(results) do - score = score + generic_reputation_calc(v, rule, requests[k][2] / mhits) + score = score + generic_reputation_calc(v, rule, + requests[k][2] / mhits, task) end if math.abs(score) > 1e-3 then @@ -384,17 +398,19 @@ local function ip_reputation_filter(task, rule) local description_t = {} if asn_stats then - local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn) + local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task) score = score + asn_score table.insert(description_t, string.format('asn: %s(%.2f)', asn, asn_score)) end if country_stats then - local country_score = generic_reputation_calc(country_stats, rule, cfg.scores.country) + local country_score = generic_reputation_calc(country_stats, rule, + cfg.scores.country, task) score = score + country_score table.insert(description_t, string.format('country: %s(%.2f)', country, country_score)) end if ip_stats then - local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip) + local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip, + task) score = score + ip_score table.insert(description_t, string.format('ip: %s(%.2f)', ip, ip_score)) end @@ -518,7 +534,7 @@ local function spf_reputation_filter(task, rule) local function tokens_cb(err, token, values) if values then - local score = generic_reputation_calc(values, rule, 1.0) + local score = generic_reputation_calc(values, rule, 1.0, task) if math.abs(score) > 1e-3 then -- TODO: add description @@ -541,7 +557,7 @@ local function spf_reputation_idempotent(task, rule) local hkey = cr.create(spf_record):base32():sub(1, 32) lua_util.debugm(N, task, 'set spf record %s -> %s = %s', - spf_record, hkey, token) + spf_record, hkey, sc) rule.backend.set_token(task, rule, hkey, sc) end @@ -599,7 +615,7 @@ local function generic_reputation_filter(task, rule) local function tokens_cb(err, token, values) if values then - local score = generic_reputation_calc(values, rule, 1.0) + local score = generic_reputation_calc(values, rule, 1.0, task) if math.abs(score) > 1e-3 then -- TODO: add description @@ -819,7 +835,7 @@ local function reputation_redis_init(rule, cfg, ev_base, worker) {% endfor %} end - return results,cnt + return {cnt,results} ]] local get_script = lua_util.jinja_template(redis_get_script_tpl, @@ -884,16 +900,9 @@ local function reputation_redis_get_token(task, rule, token, continuation_cb) local function redis_get_cb(err, data) if data then if type(data) == 'table' then - local values = {} - for i=1,#data,2 do - local ndata = tonumber(data[i + 1]) - if ndata then - values[data[i]] = ndata - end - end lua_util.debugm(N, task, 'rule %s - got values for key %s -> %s', - rule['symbol'], key, values) - continuation_cb(nil, key, values) + rule['symbol'], key, data) + continuation_cb(nil, key, data) else rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s', rule['symbol'], key, type(data)) @@ -973,11 +982,6 @@ local backends = { prefix = default_prefix, buckets = { { - time = 60 * 60, - name = '1h', - mult = 1.5, - }, - { time = 60 * 60 * 24 * 30, name = '1m', mult = 1.0, |