]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Adaptive ratelimits 897/head
authorAndrew Lewis <nerf@judo.za.org>
Fri, 26 Aug 2016 08:41:26 +0000 (10:41 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Fri, 26 Aug 2016 08:41:26 +0000 (10:41 +0200)
 - Also per-IP and per-ASN ratelimits
 - Minor rework of some parts

src/plugins/lua/asn.lua
src/plugins/lua/ip_score.lua
src/plugins/lua/ratelimit.lua

index 28cd431b9153d879e386979070deb799751350f6..25b684dc47ebd34c9e1aa878549b32c260c135be 100644 (file)
@@ -96,5 +96,6 @@ if configure_asn_module() then
     name = 'ASN_CHECK',
     type = 'prefilter',
     callback = asn_check,
+    priority = 10,
   })
 end
index 677ed12a5cab45e3501a5c35a63ce12a1db8378c..93d01f6c75b95eb53ec463bf04336f3ee53a2942 100644 (file)
@@ -50,7 +50,7 @@ local options = {
   metric = 'default',
   min_score = nil,
   max_score = nil,
-  score_divisor = nil
+  score_divisor = 1,
 }
 
 local asn_re = rspamd_regexp.create_cached("[\\|\\s]")
@@ -138,11 +138,7 @@ local ip_score_set = function(task)
     score_mult = 0
   end
 
-  if options['score_divisor'] then
-    score = score_mult * rspamd_util.tanh (2.718281 * (score/options['score_divisor']))
-  else
-    score = score_mult * rspamd_util.tanh (2.718281 * score)
-  end
+  score = score_mult * rspamd_util.tanh (2.718281 * (score/options['score_divisor']))
 
   local hkey = ip_score_hash_key(asn, country, ipnet, ip)
   local upstream,ret
@@ -341,6 +337,6 @@ if redis_params then
   })
   rspamd_config:register_symbol({
     name = options['symbol'],
-    callback = ip_score_check
+    callback = ip_score_check,
   })
 end
index 1b0ec4f7ade47df20c262e569c179211fbb071ee..e261a927562e9bac52cf32f7be7337aab4c76ad9 100644 (file)
@@ -18,24 +18,28 @@ limitations under the License.
 
 -- Default settings for limits, 1-st member is burst, second is rate and the third is numeric type
 local settings = {
-  -- Limit for all mail per recipient (burst 100, rate 2 per minute)
+  -- Limit mail per ASN (rate 12 per minute)
+  asn = {0, 0.199999998},
+  -- Limit mail per source IP (rate 6 per minute)
+  ip = {0, 0.099999999},
+  -- Limit for all mail per recipient (rate 2 per minute)
   to = {0, 0.033333333},
-  -- Limit for all mail per one source ip (burst 30, rate 1.5 per minute)
+  -- Limit for all mail to a recipient per source ip (rate 1.5 per minute)
   to_ip = {0, 0.025},
-  -- Limit for all mail per one source ip and from address (burst 20, rate 1 per minute)
+  -- Limit for all mail per recipient/sender/source ip triplet (rate 1 per minute)
   to_ip_from = {0, 0.01666666667},
 
-  -- Limit for all bounce mail (burst 10, rate 2 per hour)
+  -- Limit for all bounce mail (rate 2 per hour)
   bounce_to = {0, 0.000555556},
-  -- Limit for bounce mail per one source ip (burst 5, rate 1 per hour)
+  -- Limit for bounce mail per one source ip (rate 1 per hour)
   bounce_to_ip = {0, 0.000277778},
 
-  -- Limit for all mail per user (authuser) (burst 20, rate 1 per minute)
+  -- Limit for all mail per user (authuser) (rate 1 per minute)
   user = {0, 0.01666666667}
 }
 -- Senders that are considered as bounce
 local bounce_senders = {'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon'}
--- Do not check ratelimits for these senders
+-- Do not check ratelimits for these recipients
 local whitelisted_rcpts = {'postmaster', 'mailer-daemon'}
 local whitelisted_ip
 local max_rcpt = 5
@@ -43,13 +47,17 @@ local redis_params
 local ratelimit_symbol
 -- Do not delay mail after 1 day
 local max_delay = 24 * 3600
+local use_ip_score = false
+local rl_prefix = 'rl'
+local ip_score_lower_bound = 10
+local ip_score_ham_multiplier = 1.1
+local ip_score_spam_divisor = 1.1
 
 local rspamd_logger = require "rspamd_logger"
 local rspamd_redis = require "rspamd_redis"
 local upstream_list = require "rspamd_upstream_list"
 local rspamd_util = require "rspamd_util"
 local fun = require "fun"
---local dumper = require 'pl.pretty'.dump
 
 --- Parse atime and bucket of limit
 local function parse_limits(data)
@@ -83,13 +91,21 @@ local function parse_limits(data)
   end):totable()
 end
 
-local function generate_format_string(args, is_set)
-  if is_set then
-    return 'MSET'
-    --return fun.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
+local function resize_element(x_score, x_total, element)
+  local x_ip_score
+  if x_total < ip_score_lower_bound or x_total <= 0 then
+    x_score = 1
+  else
+    x_score = x_score / x_total
+  end
+  if x_score > 0 then
+    x_ip_score = x_score / ip_score_spam_divisor
+    element = element * rspamd_util.tanh(2.718281 * x_ip_score)
+  elseif x_score < 0 then
+    x_ip_score = (1 + ((x_score / x_total) * -1)) * ip_score_ham_multiplier
+    element = element * x_ip_score
   end
-  return 'MGET'
-  --return fun.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
+  return element
 end
 
 --- Check specific limit inside redis
@@ -99,52 +115,88 @@ local function check_limits(task, args)
   local ret,upstream
   --- Called when value is got from server
   local function rate_get_cb(task, err, data)
-    if data then
-      local ntime = rspamd_util.get_time()
-
-      fun.each(function(elt, limit)
-        local bucket = elt[2]
-        local rate = limit[2]
-        local threshold = limit[1]
-        local atime = elt[1]
-        local ctime = elt[3]
+    if err then
+      rspamd_logger.infox(task, 'got error while getting limit: %1', err)
+      upstream:fail()
+    end
+    if not data then return end
+    local ntime = rspamd_util.get_time()
+    local asn_score,total_asn,
+      country_score,total_country,
+      ipnet_score,total_ipnet,
+      ip_score, total_ip
+    if use_ip_score then
+      asn_score,total_asn,
+        country_score,total_country,
+        ipnet_score,total_ipnet,
+        ip_score, total_ip = task:get_mempool():get_variable('ip_score',
+        'double,double,double,double,double,double,double,double')
+    end
 
-        if atime == 0 then return end
+    fun.each(function(elt, limit, rtype)
+      local bucket = elt[2]
+      local rate = limit[2]
+      local threshold = limit[1]
+      local atime = elt[1]
+      local ctime = elt[3]
+
+      if atime == 0 then return end
+
+      if use_ip_score then
+        if rtype == 'asn' then
+          bucket = resize_element(asn_score, total_asn, bucket)
+          rate = resize_element(asn_score, total_asn, rate)
+        elseif rtype == 'ip' or rtype == 'to_ip' or rtype == 'to_ip_from'
+          or rtype == 'bounce_to_ip' then
+          if total_ip > ip_score_lower_bound then
+            bucket = resize_element(ip_score, total_ip, bucket)
+            rate = resize_element(ip_score, total_ip, rate)
+          elseif total_ipnet > ip_score_lower_bound then
+            bucket = resize_element(ipnet_score, total_ipnet, bucket)
+            rate = resize_element(ipnet_score, total_ipnet, rate)
+          elseif total_asn > ip_score_lower_bound then
+            bucket = resize_element(asn_score, total_asn, bucket)
+            rate = resize_element(asn_score, total_asn, rate)
+          elseif total_country > ip_score_lower_bound then
+            bucket = resize_element(country_score, total_country, bucket)
+            rate = resize_element(country_score, total_country, rate)
+          else
+            bucket = resize_element(ip_score, total_ip, bucket)
+            rate = resize_element(ip_score, total_ip, rate)
+          end
+        end
+      end
 
-        if atime - ctime > max_delay then
-          rspamd_logger.infox(task, 'limit is too old: %1 seconds; ignore it',
-            atime - ctime)
-        else
-          bucket = bucket - rate * (ntime - atime);
-          if bucket > 0 then
-            if ratelimit_symbol then
-              local mult = 2 * rspamd_util.tanh(bucket / (threshold * 2))
-
-              if mult > 0.5 then
-                task:insert_result(ratelimit_symbol, mult,
-                  tostring(mult))
-              end
-            else
-              if bucket > threshold then
-                task:set_pre_result('soft reject', 'Ratelimit exceeded')
-              end
+      if atime - ctime > max_delay then
+        rspamd_logger.infox(task, 'limit is too old: %1 seconds; ignore it',
+          atime - ctime)
+      else
+        bucket = bucket - rate * (ntime - atime);
+        if bucket > 0 then
+          if ratelimit_symbol then
+            local mult = 2 * rspamd_util.tanh(bucket / (threshold * 2))
+
+            if mult > 0.5 then
+              task:insert_result(ratelimit_symbol, mult,
+                tostring(mult))
+            end
+          else
+            if bucket > threshold then
+              task:set_pre_result('soft reject', 'Ratelimit exceeded')
             end
           end
         end
-      end, fun.zip(parse_limits(data), fun.map(function(a) return a[1] end, args)))
-    elseif err then
-      rspamd_logger.infox(task, 'got error while getting limit: %1', err)
-      upstream:fail()
-    end
+      end
+    end, fun.zip(parse_limits(data), fun.map(function(a) return a[1] end, args),
+      fun.map(function(a) return rspamd_str_split(a[2], ":")[2] end, args)))
   end
 
-  local cmd = generate_format_string(args, false)
   ret,_,upstream = rspamd_redis_make_request(task,
     redis_params, -- connect params
     key, -- hash key
     false, -- is write
     rate_get_cb, --callback
-    cmd, -- command
+    'mget', -- command
     fun.totable(fun.map(function(l) return l[2] end, args)) -- arguments
   )
 end
@@ -163,86 +215,91 @@ local function set_limits(task, args)
     end
   end
   local function rate_get_cb(task, err, data)
-    if data then
-      local ntime = rspamd_util.get_time()
-      local values = {}
-      fun.each(function(elt, limit)
-        local bucket = elt[2]
-        local rate = limit[1][2]
-        local threshold = limit[1][1]
-        local atime = elt[1]
-        local ctime = elt[3]
-
-        if atime - ctime > max_delay then
-          rspamd_logger.infox(task, 'limit is too old: %1 seconds; start it over',
-            atime - ctime)
-          bucket = 1
-          ctime = ntime
-          atime = ntime
-        else
-          if bucket > 0 then
-            bucket = bucket - rate * (ntime - atime) + 1;
-            if bucket < 0 then
-              bucket = 1
-            end
-          else
+    if err then
+      rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+      upstream:fail()
+    end
+    if not data then return end
+    local ntime = rspamd_util.get_time()
+    local values = {}
+    fun.each(function(elt, limit)
+      local bucket = elt[2]
+      local rate = limit[1][2]
+      local threshold = limit[1][1]
+      local atime = elt[1]
+      local ctime = elt[3]
+
+      if atime - ctime > max_delay then
+        rspamd_logger.infox(task, 'limit is too old: %1 seconds; start it over',
+          atime - ctime)
+        bucket = 1
+        ctime = ntime
+        atime = ntime
+      else
+        if bucket > 0 then
+          bucket = bucket - rate * (ntime - atime) + 1;
+          if bucket < 0 then
             bucket = 1
           end
-        end
-
-        if ctime == 0 then ctime = ntime end
-
-        local lstr = string.format('%.3f:%.3f:%.3f', ntime, bucket, ctime)
-        table.insert(values, {limit[2], max_delay, lstr})
-      end, fun.zip(parse_limits(data), fun.iter(args)))
-
-      if #values > 0 then
-        local conn
-        ret,conn,upstream = rspamd_redis_make_request(task,
-          redis_params, -- connect params
-          key, -- hash key
-          true, -- is write
-          rate_set_cb, --callback
-          'setex', -- command
-          values[1] -- arguments
-        )
-
-        if conn then
-          fun.each(function(v)
-            conn:add_cmd('setex', v)
-          end, fun.drop_n(1, values))
         else
-          rspamd_logger.infox(task, 'got error while connecting to redis: %1', addr)
-          upstream:fail()
+          bucket = 1
         end
-      elseif err then
-        rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+      end
+
+      if ctime == 0 then ctime = ntime end
+
+      local lstr = string.format('%.3f:%.3f:%.3f', ntime, bucket, ctime)
+      table.insert(values, {limit[2], max_delay, lstr})
+    end, fun.zip(parse_limits(data), fun.iter(args)))
+
+    if #values > 0 then
+      local conn
+      ret,conn,upstream = rspamd_redis_make_request(task,
+        redis_params, -- connect params
+        key, -- hash key
+        true, -- is write
+        rate_set_cb, --callback
+        'setex', -- command
+        values[1] -- arguments
+      )
+
+      if conn then
+        fun.each(function(v)
+          conn:add_cmd('setex', v)
+        end, fun.drop_n(1, values))
+      else
+        rspamd_logger.infox(task, 'got error while connecting to redis: %1', addr)
         upstream:fail()
       end
     end
   end
 
-  local cmd = generate_format_string(args, false)
   ret,_,upstream = rspamd_redis_make_request(task,
     redis_params, -- connect params
     key, -- hash key
     false, -- is write
     rate_get_cb, --callback
-    cmd, -- command
+    'mget', -- command
     fun.totable(fun.map(function(l) return l[2] end, args)) -- arguments
   )
 end
 
 --- Make rate key
-local function make_rate_key(from, to, ip)
-  if from and ip and ip:is_valid() then
-    return string.format('%s:%s:%s', from, to, ip:to_string())
-  elseif from then
-    return string.format('%s:%s', from, to)
-  elseif ip and ip:is_valid() then
-    return string.format('%s:%s', to, ip:to_string())
-  elseif to then
-    return to
+local function make_rate_key(rtype, args)
+  if rtype == 'to_ip_from' and args['from'] and args['to'] and args['ip'] and args['ip']:is_valid() then
+    return string.format('%s:%s:%s:%s:%s', rl_prefix, rtype, args['from'], args['to'], args['ip']:to_string())
+  elseif rtype == 'to_ip' and args['to'] and args['ip'] and args['ip']:is_valid() then
+    return string.format('%s:%s:%s:%s', rl_prefix, rtype, args['to'], args['ip']:to_string())
+  elseif rtype == 'to' and args['to'] then
+    return string.format('%s:%s:%s', rl_prefix, rtype, args['to'])
+  elseif rtype == 'bounce_to' and args['to'] then
+    return string.format('%s:%s:%s', rl_prefix, rtype, args['to'])
+  elseif rtype == 'bounce_to_ip' and args['to'] and args['ip'] and args['ip']:is_valid() then
+    return string.format('%s:%s:%s:%s', rl_prefix, rtype, args['to'], args['ip']:to_string())
+  elseif rtype == 'asn' and args['asn'] then
+    return string.format('%s:%s:%s', rl_prefix, rtype, args['asn'])
+  elseif rtype == 'user' and args['user'] then
+    return string.format('%s:%s:%s', rl_prefix, rtype, args['user'])
   else
     return nil
   end
@@ -289,7 +346,11 @@ local function rate_test_set(task, func)
   -- Get user (authuser)
   local auser = task:get_user()
   if auser and settings['user'][1] > 0 then
-    table.insert(args, {settings['user'], make_rate_key (auser, '<auth>', nil)})
+    table.insert(args, {settings['user'], make_rate_key ('user', {['user'] = auser}) })
+  end
+  local asn
+  if settings['asn'][1] > 0 then
+    asn = task:get_mempool():get_variable('asn')
   end
 
   local is_bounce = check_bounce(from_user)
@@ -298,21 +359,27 @@ local function rate_test_set(task, func)
     fun.each(function(r)
       if is_bounce then
         if settings['bounce_to'][1] > 0 then
-          table.insert(args, { settings['bounce_to'], make_rate_key('<>', r['addr'], nil) })
+          table.insert(args, { settings['bounce_to'], make_rate_key('bounce_to', {['to'] = r['addr']}) })
         end
         if ip and settings['bounce_to_ip'][1] > 0 then
-          table.insert(args, { settings['bounce_to_ip'], make_rate_key('<>', r['addr'], ip) })
+          table.insert(args, { settings['bounce_to_ip'], make_rate_key('bounce_to_ip', {['to'] = r['addr'], ['ip'] = ip}) })
         end
       end
       if settings['to'][1] > 0 then
-        table.insert(args, { settings['to'], make_rate_key(nil, r['addr'], nil) })
+        table.insert(args, { settings['to'], make_rate_key('to', {['to'] = r['addr']}) })
       end
       if ip then
         if settings['to_ip'][1] > 0 then
-          table.insert(args, { settings['to_ip'], make_rate_key(nil, r['addr'], ip) })
+          table.insert(args, { settings['to_ip'], make_rate_key('to_ip', {['to'] = r['addr'], ['ip'] = ip}) })
         end
         if settings['to_ip_from'][1] > 0 then
-          table.insert(args, { settings['to_ip_from'], make_rate_key(from_addr, r['addr'], ip) })
+          table.insert(args, { settings['to_ip_from'], make_rate_key('to_ip_from', {['from'] = from_addr, ['to'] = r['addr'], ['ip'] = ip}) })
+        end
+        if settings['ip'][1] > 0 then
+          table.insert(args, { settings['ip'], make_rate_key('ip', {['ip'] = ip}) })
+        end
+        if asn and settings['asn'][1] > 0 then
+          table.insert(args, { settings['asn'], make_rate_key('asn', {['asn'] = asn}) })
         end
       end
     end, rcpts)
@@ -359,12 +426,16 @@ local function parse_limit(str)
     set_limit(settings['bounce_to_ip'], params[2], params[3])
   elseif params[1] == 'user' then
     set_limit(settings['user'], params[2], params[3])
+  elseif params[1] == 'ip' then
+    set_limit(settings['ip'], params[2], params[3])
+  elseif params[1] == 'asn' then
+    set_limit(settings['asn'], params[2], params[3])
   else
     rspamd_logger.errx(rspamd_config, 'invalid limit type: ' .. params[1])
   end
 end
 
-local opts =  rspamd_config:get_all_opt('ratelimit')
+local opts = rspamd_config:get_all_opt('ratelimit')
 if opts then
   local rates = opts['limit']
   if rates and type(rates) == 'table' then
@@ -412,24 +483,39 @@ if opts then
     max_rcpt = tonumber(opts['max_delay'])
   end
 
+  if opts['use_ip_score'] then
+    use_ip_score = true
+    local ip_score_opts = rspamd_config:get_all_opt('ip_score')
+    if ip_score_opts and ip_score_opts['lower_bound'] then
+      ip_score_lower_bound = ip_score_opts['lower_bound']
+    end
+  end
+
   redis_params = rspamd_parse_redis_server('ratelimit')
   if not redis_params then
     rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
   else
-    if not ratelimit_symbol then
-       rspamd_config:register_symbol({
+    if not ratelimit_symbol and not use_ip_score then
+      rspamd_config:register_symbol({
         name = 'RATELIMIT_CHECK',
-        type = 'prefilter',
         callback = rate_test,
+        type = 'prefilter',
+        priority = 10,
       })
     else
-      rspamd_config:register_symbol({
+      if not ratelimit_symbol then
+        symbol = 'RATELIMIT_CHECK'
+      else
+        symbol = ratelimit_symbol
+      end
+      local id = rspamd_config:register_symbol({
         name = ratelimit_symbol,
         callback = rate_test,
-        flags = 'empty'
       })
+      if use_ip_score then
+        rspamd_config:register_dependency(id, 'IP_SCORE')
+      end
     end
-
     rspamd_config:register_symbol({
       name = 'RATELIMIT_SET',
       type = 'postfilter',