]> source.dussan.org Git - rspamd.git/commitdiff
Rework and optimize ratelimit plugin.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 25 Feb 2015 17:45:38 +0000 (17:45 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 25 Feb 2015 17:46:12 +0000 (17:46 +0000)
src/plugins/lua/ratelimit.lua

index e1601282d8c2256d58600f02c6c764d7f12b57a5..f1414106151aa5e393451c9b59803cfd5efd6225 100644 (file)
@@ -60,20 +60,39 @@ local upstream_list = require "rspamd_upstream_list"
 local _ = require "fun"
 
 --- Parse atime and bucket of limit
-local function parse_limit_data(str)
-  local pos,_ = string.find(str, ':')
-  if not pos then
-    return 0, 0
-  else
-    local atime = tonumber(string.sub(str, 1, pos - 1))
-    local bucket = tonumber(string.sub(str, pos + 1))
-    return atime,bucket
+local function parse_limits(data)
+  local function parse_limit_elt(str)
+    local pos,_ = string.find(str, ':')
+    if not pos then
+      return {0, 0}
+    else
+      local atime = tonumber(string.sub(str, 1, pos - 1))
+      local bucket = tonumber(string.sub(str, pos + 1))
+      return {atime,bucket}
+    end
   end
+  
+  return _.map(function(e) 
+    if type(e) == 'string' then 
+      return parse_limit_elt
+    else
+      return {0, 0}
+    end
+    end, data)
+end
+
+local function generate_format_string(args, is_set)
+  if is_set then
+    return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
+  end
+  return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
 end
 
 --- Check specific limit inside redis
-local function check_specific_limit (task, limit, key)
+local function check_limits(task, args)
 
+  local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
+  print(key)
   local upstream = upstreams:get_upstream_by_hash(key)
   local addr = upstream:get_addr()
   --- Called when value was set on server
@@ -88,38 +107,38 @@ local function check_specific_limit (task, limit, key)
   --- Called when value is got from server
   local function rate_get_cb(task, err, data)
     if data then
-      local atime, bucket = parse_limit_data(data)
       local tv = task:get_timeval()
       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-      -- Leak messages
-      bucket = bucket - limit[2] * (ntime - atime);
-      if bucket > 0 then
-        local lstr = string.format('%.7f:%.7f', ntime, bucket)
-        rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
-          'SET %b %b', key, lstr)
-        if bucket > limit[1] then
-          task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key)
+      local it = _.zip(_.map(function(a) return a[1] end, args), parse_limits(data))
+      
+      _.each(function(elt)
+        local bucket = elt[2][2]
+        local limit = elt[1]
+        local atime = elt[2][1]
+        
+        bucket = bucket - limit[2] * (ntime - atime);
+        if bucket > 0 then
+          if bucket > limit[1] then
+            task:set_pre_result('soft reject', 'Ratelimit exceeded')
+          end
         end
-      else
-        rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
-          'DEL %b', key)
-      end
-    end
-    if err then
+      end, it)
+    elseif err then
       rspamd_logger.info('got error while getting limit: ' .. err)
       upstream:fail()
     end
   end
+  
   if upstream then
-    rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key)
+    local cmd = generate_format_string(args, false)
+    
+    rspamd_redis.make_request(task, addr, rate_get_cb, cmd, 
+      _.totable(_.map(function(l) return l[2] end, args)))
   end
 end
 
 --- Set specific limit inside redis
-local function set_specific_limit (task, limit, key)
-  local upstream = upstreams:get_upstream_by_hash(key)
-  local addr = upstream:get_addr()
-  --- Called when value was set on server
+local function set_limits(task, args)
   local function rate_set_key_cb(task, err, data)
     if err then
       rspamd_logger.info('got error while setting limit: ' .. err)
@@ -128,22 +147,32 @@ local function set_specific_limit (task, limit, key)
       upstream:ok()
     end
   end
-  --- Called when value is got from server
+  local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
+  local upstream = upstreams:get_upstream_by_hash(key)
+  local addr = upstream:get_addr()
   local function rate_set_cb(task, err, data)
-    if not err and not data then
-      --- Add new entry
-      local tv = task:get_timeval()
-      local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-      local lstr = string.format('%.7f:1', atime)
-      rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
-        'SET %b %b', key, lstr)
-    elseif data then
-      local atime, bucket = parse_limit_data(data)
+    if data then
       local tv = task:get_timeval()
       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-      -- Leak messages
-      bucket = bucket - limit[2] * (ntime - atime) + 1;
-      local lstr = string.format('%.7f:%.7f', ntime, bucket)
+      local it = _.zip(args, parse_limits(data))
+      local values = {}
+      _.each(function(elt)
+        local bucket = elt[2][2]
+        local limit = elt[1][1]
+        local atime = elt[2][1]
+        
+        if bucket > 0 then
+          bucket = bucket - limit[2] * (ntime - atime) + 1;
+        else
+          bucket = 1
+        end
+        local lstr = string.format('%.7f:%.7f', ntime, bucket)
+        table.insert(values, elt[1][2], lstr)
+      end, it)
+      
+      local cmd = generate_format_string(values, true)
+      rspamd_redis.make_request(task, addr, rate_set_key_cb, cmd, values)
       rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
         'SET %b %b', key, lstr)
     elseif err then
@@ -152,7 +181,10 @@ local function set_specific_limit (task, limit, key)
     end
   end
   if upstream then
-    rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key)
+    local cmd = generate_format_string(args, false)
+    
+    rspamd_redis.make_request(task, addr, rate_set_cb, cmd,
+      _.totable(_.map(function(l) return l[2] end, args)))
   end
 end
 
@@ -173,21 +205,18 @@ end
 
 --- Check whether this addr is bounce
 local function check_bounce(from)
-  for _,b in ipairs(whitelisted_rcpts) do
-    if b == from then
-      return true
-    end
-  end
-  return false
+  return _.any(function(b) return b == from end, bounce_senders)
 end
 
 --- Check or update ratelimit
 local function rate_test_set(task, func)
+  local args = {}
   -- Get initial task data
   local ip = task:get_from_ip()
   if ip and ip:is_valid() and whitelisted_ip then
     if whitelisted_ip:get_key(ip) then
       -- Do not check whitelisted ip
+      rspamd_logger.info('skip ratelimit for whitelisted IP')
       return
     end
   end
@@ -195,14 +224,14 @@ local function rate_test_set(task, func)
   local rcpts = task:get_recipients()
   local rcpts_user = {}
   if rcpts then
-    if table.maxn(rcpts) > max_rcpt then
-      rspamd_logger.info(string.format('message <%s> contains %d recipients, maximum is %d',
-        task:get_message_id(), table.maxn(rcpts), max_rcpt))
+    _.each(function(r) table.insert(rcpts_user, r['user']) end, rcpts)
+    if _.any(function(r) 
+      _.any(function(w) return r == w end, whitelisted_rcpts) end, 
+      rcpts_user) then
+      
+      rspamd_logger.info('skip ratelimit for whitelisted recipient')
       return
     end
-    for i,r in ipairs(rcpts) do
-      rcpts_user[i] = r['user']
-    end
   end
   -- Parse from
   local from = task:get_from()
@@ -215,39 +244,37 @@ local function rate_test_set(task, func)
   -- Get user (authuser)
   local auser = task:get_user()
   if auser then
-    func(task, settings['user'], make_rate_key (auser, '<auth>', nil))
-  end
-
-  if not rcpts_user[1] then
-    -- Nothing to check
-    return
+    table.insert(args, {settings['user'], make_rate_key (auser, '<auth>', nil)})
   end
 
   local is_bounce = check_bounce(from_user)
 
-  for _,r in ipairs(rcpts) do
-    if is_bounce then
-      -- Bounce specific limit
-      func(task, settings['bounce_to'], make_rate_key ('<>', r['addr'], nil))
+  if rcpts then
+    _.each(function(r)
+      if is_bounce then
+        table.insert(args, {settings['bounce_to'], make_rate_key ('<>', r['addr'], nil)})
+        if ip then
+          table.insert(args, {settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip)})
+        end
+      end
+      table.insert(args, {settings['to'], make_rate_key (nil, r['addr'], nil)})
       if ip then
-        func(task, settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip))
+        table.insert(args, {settings['to_ip'], make_rate_key (nil, r['addr'], ip)})
+        table.insert(args, {settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip)})
       end
-    end
-    func(task, settings['to'], make_rate_key (nil, r['addr'], nil))
-    if ip then
-      func(task, settings['to_ip'], make_rate_key (nil, r['addr'], ip))
-      func(task, settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip))
-    end
+    end, rcpts)
   end
+  
+  func(task, args)
 end
 
 --- Check limit
 local function rate_test(task)
-  rate_test_set(task, check_specific_limit)
+  rate_test_set(task, check_limits)
 end
 --- Update limit
 local function rate_set(task)
-  rate_test_set(task, set_specific_limit)
+  rate_test_set(task, set_limits)
 end