]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Reputation: Rework get function and scores calculations
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 15 May 2019 14:07:16 +0000 (15:07 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 15 May 2019 14:07:16 +0000 (15:07 +0100)
src/plugins/lua/reputation.lua

index f1062dbaa4c2af88d58768bf188d262c70f8c9c4..ebfe2c04e44d240de8c95a40ee2012cf36f27ebe 100644 (file)
@@ -37,21 +37,34 @@ local redis_params = nil
 local default_expiry = 864000 -- 10 day by default
 local default_prefix = 'RR:' -- Rspamd Reputation
 
+local tanh = math.tanh or rspamd_util.tanh
+
+local reject_threshold = rspamd_config:get_action('reject') or 10.0
 
 -- Get reputation from ham/spam/probable hits
-local function generic_reputation_calc(token, rule, mult)
+local function generic_reputation_calc(token, rule, mult, task)
   local cfg = rule.selector.config or E
 
   if cfg.score_calc_func then
     return cfg.score_calc_func(rule, token, mult)
   end
 
-  if token[1] < cfg.lower_bound then return 0 end
+  if tonumber(token[1]) < cfg.lower_bound then
+    lua_util.debugm(N, task, "not enough matches %s < %s for rule %s",
+        token[1], cfg.lower_bound, rule.symbol)
+    return 0
+  end
 
-  local score = fun.foldl(function(acc, v)
+  -- Get average score
+  local avg_score = fun.foldl(function(acc, v)
     return acc + v
   end, 0.0, fun.map(tonumber, token[2])) / #token[2]
 
+  -- Apply function tanh(x / reject_score * atanh(0.95) - atanh(0.5))
+  --                                        1.83178       0.5493
+  local score = tanh(avg_score / reject_threshold * 1.83178 - 0.5493) * mult
+  lua_util.debugm(N, task, "got generic average score %s -> %s for rule %s",
+      avg_score, score, rule.symbol)
   return score
 end
 
@@ -158,9 +171,9 @@ local function dkim_reputation_filter(task, rule)
     if nchecked == #requests then
       for k,v in pairs(results) do
         if requests[k] == 'a' then
-          rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0)
+          rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task)
         elseif requests[k] == 'r' then
-          rep_rejected = rep_rejected + generic_reputation_calc(v, rule, 1.0)
+          rep_rejected = rep_rejected + generic_reputation_calc(v, rule, 1.0, task)
         end
       end
 
@@ -293,7 +306,8 @@ local function url_reputation_filter(task, rule)
       if mhits > 0 then
         local score = 0
         for k,v in pairs(results) do
-          score = score + generic_reputation_calc(v, rule, requests[k][2] / mhits)
+          score = score + generic_reputation_calc(v, rule,
+              requests[k][2] / mhits, task)
         end
 
         if math.abs(score) > 1e-3 then
@@ -384,17 +398,19 @@ local function ip_reputation_filter(task, rule)
     local description_t = {}
 
     if asn_stats then
-      local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn)
+      local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task)
       score = score + asn_score
       table.insert(description_t, string.format('asn: %s(%.2f)', asn, asn_score))
     end
     if country_stats then
-      local country_score = generic_reputation_calc(country_stats, rule, cfg.scores.country)
+      local country_score = generic_reputation_calc(country_stats, rule,
+          cfg.scores.country, task)
       score = score + country_score
       table.insert(description_t, string.format('country: %s(%.2f)', country, country_score))
     end
     if ip_stats then
-      local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip)
+      local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip,
+        task)
       score = score + ip_score
       table.insert(description_t, string.format('ip: %s(%.2f)', ip, ip_score))
     end
@@ -518,7 +534,7 @@ local function spf_reputation_filter(task, rule)
 
   local function tokens_cb(err, token, values)
     if values then
-      local score = generic_reputation_calc(values, rule, 1.0)
+      local score = generic_reputation_calc(values, rule, 1.0, task)
 
       if math.abs(score) > 1e-3 then
         -- TODO: add description
@@ -541,7 +557,7 @@ local function spf_reputation_idempotent(task, rule)
   local hkey = cr.create(spf_record):base32():sub(1, 32)
 
   lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
-      spf_record, hkey, token)
+      spf_record, hkey, sc)
   rule.backend.set_token(task, rule, hkey, sc)
 end
 
@@ -599,7 +615,7 @@ local function generic_reputation_filter(task, rule)
 
   local function tokens_cb(err, token, values)
     if values then
-      local score = generic_reputation_calc(values, rule, 1.0)
+      local score = generic_reputation_calc(values, rule, 1.0, task)
 
       if math.abs(score) > 1e-3 then
         -- TODO: add description
@@ -819,7 +835,7 @@ local function reputation_redis_init(rule, cfg, ev_base, worker)
   {% endfor %}
   end
 
-  return results,cnt
+  return {cnt,results}
   ]]
 
   local get_script = lua_util.jinja_template(redis_get_script_tpl,
@@ -884,16 +900,9 @@ local function reputation_redis_get_token(task, rule, token, continuation_cb)
   local function redis_get_cb(err, data)
     if data then
       if type(data) == 'table' then
-        local values = {}
-        for i=1,#data,2 do
-          local ndata = tonumber(data[i + 1])
-          if ndata then
-            values[data[i]] = ndata
-          end
-        end
         lua_util.debugm(N, task, 'rule %s - got values for key %s -> %s',
-            rule['symbol'], key, values)
-        continuation_cb(nil, key, values)
+            rule['symbol'], key, data)
+        continuation_cb(nil, key, data)
       else
         rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s',
           rule['symbol'], key, type(data))
@@ -972,11 +981,6 @@ local backends = {
       expiry = default_expiry,
       prefix = default_prefix,
       buckets = {
-        {
-          time = 60 * 60,
-          name = '1h',
-          mult = 1.5,
-        },
         {
           time = 60 * 60 * 24 * 30,
           name = '1m',