]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Rework ratelimits configuration
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 13 Jul 2018 14:22:55 +0000 (15:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 13 Jul 2018 16:24:42 +0000 (17:24 +0100)
src/plugins/lua/ratelimit.lua

index 654c300cac7259b8e9844a138876556024bcef4b..c09ebc148ee603705d970fab0273580ecddb15de 100644 (file)
@@ -227,6 +227,50 @@ local function parse_string_limit(lim, no_error)
   return nil
 end
 
+local function parse_limit(name, data)
+  local buckets = {}
+  if type(data) == 'table' then
+    -- 3 cases here:
+    --  * old limit in format [burst, rate]
+    --  * vector of strings in Andrew's string format
+    --  * proper bucket table
+    if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
+      -- Old style ratelimit
+      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
+      if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
+        table.insert(buckets, {
+          burst = data[1],
+          rate = data[2]
+        })
+      elseif data[1] ~= 0 then
+        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
+      else
+        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
+      end
+    else
+      -- Recursively map parse_limit and flatten the list
+      fun.each(function(l)
+        -- Flatten list
+        for _,b in ipairs(l) do table.insert(buckets, b) end
+      end, fun.map(function(d) return parse_limit(d, name) end, data))
+    end
+  elseif type(data) == 'string' then
+    local rep_rate, burst = parse_string_limit(data)
+
+    if rep_rate and burst then
+      table.insert(buckets, {
+        burst = burst,
+        rate = 1.0 / rep_rate -- reciprocal
+      })
+    end
+  end
+
+  -- Filter valid
+  return fun.totable(fun.filter(function(val)
+    return type(val.bucket) == 'number' and type(val.rate) == 'number'
+  end, buckets))
+end
+
 --- Check whether this addr is bounce
 local function check_bounce(from)
   return fun.any(function(b) return b == from end, settings.bounce_senders)
@@ -316,6 +360,46 @@ local function gen_rate_key(task, rtype, bucket)
   return table.concat(key_t, ":")
 end
 
+local function make_prefix(redis_key, name, bucket)
+  local hash_len = 24
+  if hash_len > #redis_key then hash_len = #redis_key end
+  local hash = settings.prefix ..
+      string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len)
+  -- Fill defaults
+  if not bucket.spam_factor_rate then
+    bucket.spam_factor_rate = settings.spam_factor_rate
+  end
+  if not bucket.ham_factor_rate then
+    bucket.ham_factor_rate = settings.ham_factor_rate
+  end
+  if not bucket.spam_factor_burst then
+    bucket.spam_factor_burst = settings.spam_factor_burst
+  end
+  if not bucket.ham_factor_burst then
+    bucket.ham_factor_burst = settings.ham_factor_burst
+  end
+
+  return {
+    bucket = bucket,
+    name = name,
+    hash = hash
+  }
+end
+
+local function limit_to_prefixes(task, k, v, prefixes)
+  local n = 0
+  for _,bucket in ipairs(v) do
+    local prefix = gen_rate_key(task, k, bucket)
+
+    if prefix then
+      prefixes[prefix] = make_prefix(prefix, k, bucket)
+      n = n + 1
+    end
+  end
+
+  return n
+end
+
 local function ratelimit_cb(task)
   if not settings.allow_local and
           rspamd_lua_utils.is_rspamc_or_controller(task) then return end
@@ -355,22 +439,7 @@ local function ratelimit_cb(task)
   local nprefixes = 0
 
   for k,v in pairs(settings.limits) do
-    for _,bucket in ipairs(v) do
-      local prefix = gen_rate_key(task, k, bucket)
-
-      if prefix then
-        local hash_len = 24
-        if hash_len > #prefix then hash_len = #prefix end
-        local hash = settings.prefix ..
-                string.sub(rspamd_hash.create(prefix):base32(), 1, hash_len)
-        prefixes[prefix] = {
-          bucket = bucket,
-          name = k,
-          hash = hash
-        }
-        nprefixes = nprefixes + 1
-      end
-    end
+    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
   end
 
   local function gen_check_cb(prefix, bucket, lim_name)
@@ -414,7 +483,7 @@ local function ratelimit_cb(task)
 
     for pr,value in pairs(prefixes) do
       local bucket = value.bucket
-      local rate = (1.0 / bucket[1]) / 1000.0 -- Leak rate in messages/ms
+      local rate = (bucket[1]) / 1000.0 -- Leak rate in messages/ms
       rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
           value.name, pr, value.hash, bucket[2], bucket[1])
       lua_redis.exec_redis_script(bucket_check_id,
@@ -489,31 +558,10 @@ if opts then
   if opts['rates'] and type(opts['rates']) == 'table' then
     -- new way of setting limits
     fun.each(function(t, lim)
-      if type(lim) == 'table' then
-        settings.limits[t] = {}
-        if #lim == 2 and tonumber(lim[1]) and tonumber(lim[2]) then
-          -- Old style ratelimit
-          rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', t)
-          if tonumber(lim[1]) > 0 and tonumber(lim[2]) > 0 then
-            table.insert(settings.limits[t], {1.0/lim[2], lim[1]})
-          elseif lim[1] ~= 0 then
-            rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', t)
-          else
-            rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', t)
-          end
-        else
-          fun.each(function(l)
-            local plim, size = parse_string_limit(l)
-            if plim then
-              table.insert(settings.limits[t], {plim, size})
-            end
-          end, lim)
-        end
-      elseif type(lim) == 'string' then
-        local plim, size = parse_string_limit(lim)
-        if plim then
-          settings.limits[t] = { {plim, size} }
-        end
+      local buckets = parse_limit(t, lim)
+
+      if buckets and #buckets > 0 then
+        settings.limits[t] = buckets
       end
     end, opts['rates'])
   end