]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Support friendly rate specification format in user-defined ratelimits
authorAndrew Lewis <nerf@judo.za.org>
Fri, 9 Jun 2017 15:58:26 +0000 (17:58 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Fri, 9 Jun 2017 15:58:26 +0000 (17:58 +0200)
src/plugins/lua/ratelimit.lua

index 7e1043caf818c6c1d0a3165005614ff646127d19..02e9d4d6f18ccaccb93bf20320f622337149fe33 100644 (file)
@@ -52,6 +52,68 @@ local fun = require "fun"
 
 local user_keywords = {'user'}
 
+local limit_parser
+local function parse_string_limit(lim)
+  local function parse_time_suffix(s)
+    if s == 's' then
+      return 1
+    elseif s == 'm' then
+      return 60
+    elseif s == 'h' then
+      return 3600
+    elseif s == 'd' then
+      return 86400
+    end
+  end
+  local function parse_num_suffix(s)
+    if s == '' then
+      return 1
+    elseif s == 'k' then
+      return 1000
+    elseif s == 'm' then
+      return 1000000
+    elseif s == 'g' then
+      return 1000000000
+    end
+  end
+  local lpeg = require "lpeg"
+
+  if not limit_parser then
+    local digit = lpeg.R("09")
+    limit_parser = {}
+    limit_parser.integer =
+    (lpeg.S("+-") ^ -1) *
+            (digit   ^  1)
+    limit_parser.fractional =
+    (lpeg.P(".")   ) *
+            (digit ^ 1)
+    limit_parser.number =
+    (limit_parser.integer *
+            (limit_parser.fractional ^ -1)) +
+            (lpeg.S("+-") * limit_parser.fractional)
+    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
+            (limit_parser.number / tonumber) *
+            ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
+      function (acc, val) return acc * val end)
+    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
+            (limit_parser.number / tonumber) *
+            ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
+      function (acc, val) return acc * val end)
+    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
+            (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
+            limit_parser.time)
+  end
+  local t = lpeg.match(limit_parser.limit, lim)
+
+  if t and t[1] and t[2] and t[2] ~= 0 then
+    return t[1] / t[2], t[1]
+  end
+
+  rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
+
+  return nil
+end
+
 --- Parse atime and bucket of limit
 local function parse_limits(data)
   local function parse_limit_elt(str)
@@ -446,7 +508,15 @@ local function rate_test_set(task, func)
             table.insert(args, {settings[k], rk})
           elseif type(settings[k]) == 'string' and
               (custom_keywords[settings[k]] and type(custom_keywords[settings[k]]['get_limit']) == 'function') then
-            table.insert(args, {custom_keywords[settings[k]]['get_limit'](), rate_key})
+            local res = custom_keywords[settings[k]]['get_limit']()
+            if type(res) == 'table' then
+              table.insert(args, {res, rate_key})
+            elseif type(res) == 'string' then
+              local plim, size = parse_string_limit(res)
+              if plim then
+                table.insert(args, {{size, plim, 1}, rate_key})
+              end
+            end
           end
         end
       else
@@ -454,7 +524,15 @@ local function rate_test_set(task, func)
           table.insert(args, {settings[k], rate_key})
         elseif type(settings[k]) == 'string' and
             (custom_keywords[settings[k]] and type(custom_keywords[settings[k]]['get_limit']) == 'function') then
-          table.insert(args, {custom_keywords[settings[k]]['get_limit'](), rate_key})
+          local res = custom_keywords[settings[k]]['get_limit']()
+          if type(res) == 'table' then
+            table.insert(args, {res, rate_key})
+          elseif type(res) == 'string' then
+            local plim, size = parse_string_limit(res)
+            if plim then
+              table.insert(args, {{size, plim, 1}, rate_key})
+            end
+          end
         end
       end
     end
@@ -504,68 +582,6 @@ local function parse_limit(str)
   end
 end
 
-local limit_parser
-local function parse_string_limit(lim)
-  local function parse_time_suffix(s)
-    if s == 's' then
-      return 1
-    elseif s == 'm' then
-      return 60
-    elseif s == 'h' then
-      return 3600
-    elseif s == 'd' then
-      return 86400
-    end
-  end
-  local function parse_num_suffix(s)
-    if s == '' then
-      return 1
-    elseif s == 'k' then
-      return 1000
-    elseif s == 'm' then
-      return 1000000
-    elseif s == 'g' then
-      return 1000000000
-    end
-  end
-  local lpeg = require "lpeg"
-
-  if not limit_parser then
-    local digit = lpeg.R("09")
-    limit_parser = {}
-    limit_parser.integer =
-    (lpeg.S("+-") ^ -1) *
-            (digit   ^  1)
-    limit_parser.fractional =
-    (lpeg.P(".")   ) *
-            (digit ^ 1)
-    limit_parser.number =
-    (limit_parser.integer *
-            (limit_parser.fractional ^ -1)) +
-            (lpeg.S("+-") * limit_parser.fractional)
-    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
-            (limit_parser.number / tonumber) *
-            ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
-      function (acc, val) return acc * val end)
-    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
-            (limit_parser.number / tonumber) *
-            ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
-      function (acc, val) return acc * val end)
-    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
-            (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
-            limit_parser.time)
-  end
-  local t = lpeg.match(limit_parser.limit, lim)
-
-  if t and t[1] and t[2] and t[2] ~= 0 then
-    return t[1] / t[2], t[1]
-  end
-
-  rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
-
-  return nil
-end
-
 local opts = rspamd_config:get_all_opt('ratelimit')
 if opts then
   local rates = opts['limit']
@@ -604,7 +620,7 @@ if opts then
         (type(lim) == 'table' and type(lim[1]) == 'number' and lim[1] > 0)
         or (type(lim) == 'table' and (lim[3]))
   end, settings)))
-  rspamd_logger.infox(rspamd_config, 'enabled rate buckets: %s', enabled_limits)
+  rspamd_logger.infox(rspamd_config, 'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))
 
   if opts['whitelisted_rcpts'] and type(opts['whitelisted_rcpts']) == 'string' then
     whitelisted_rcpts = rspamd_str_split(opts['whitelisted_rcpts'], ',')