]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Use new redis API in ratelimit plugin
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 22 Jun 2016 14:22:11 +0000 (15:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 22 Jun 2016 14:22:11 +0000 (15:22 +0100)
src/plugins/lua/ratelimit.lua

index 1032a35b9958cb7445edafc96e6288cce201a79c..4f8330cd3d6fb1e53d4ff95d969182e1d1a6503b 100644 (file)
@@ -39,7 +39,7 @@ local bounce_senders = {'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-da
 local whitelisted_rcpts = {'postmaster', 'mailer-daemon'}
 local whitelisted_ip
 local max_rcpt = 5
-local upstreams
+local redis_params
 local ratelimit_symbol
 -- Do not delay mail after 1 day
 local max_delay = 24 * 3600
@@ -48,7 +48,7 @@ local rspamd_logger = require "rspamd_logger"
 local rspamd_redis = require "rspamd_redis"
 local upstream_list = require "rspamd_upstream_list"
 local rspamd_util = require "rspamd_util"
-local _ = require "fun"
+local fun = require "fun"
 --local dumper = require 'pl.pretty'.dump
 
 --- Parse atime and bucket of limit
@@ -74,7 +74,7 @@ local function parse_limits(data)
     end
   end
 
-  return _.iter(data):map(function(e)
+  return fun.iter(data):map(function(e)
     if type(e) == 'string' then
       return parse_limit_elt(e)
     else
@@ -86,24 +86,23 @@ end
 local function generate_format_string(args, is_set)
   if is_set then
     return 'MSET'
-    --return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
+    --return fun.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
   end
   return 'MGET'
-  --return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
+  --return fun.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
 end
 
 --- Check specific limit inside redis
 local function check_limits(task, args)
 
-  local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
-  local upstream = upstreams:get_upstream_by_hash(key)
-  local addr = upstream:get_addr()
+  local key = fun.foldl(function(acc, k) return acc .. k[2] end, '', args)
+  local ret,upstream
   --- Called when value is got from server
   local function rate_get_cb(task, err, data)
     if data then
       local ntime = rspamd_util.get_time()
 
-      _.each(function(elt, limit)
+      fun.each(function(elt, limit)
         local bucket = elt[2]
         local rate = limit[2]
         local threshold = limit[1]
@@ -132,32 +131,34 @@ local function check_limits(task, args)
             end
           end
         end
-      end, _.zip(parse_limits(data), _.map(function(a) return a[1] end, args)))
+      end, fun.zip(parse_limits(data), fun.map(function(a) return a[1] end, args)))
     elseif err then
       rspamd_logger.infox(task, 'got error while getting limit: %1', err)
       upstream:fail()
     end
   end
 
-  if upstream then
-    local cmd = generate_format_string(args, false)
-
-    rspamd_redis.make_request(task, addr, rate_get_cb, cmd,
-      _.totable(_.map(function(l) return l[2] end, args)))
-  end
+  local cmd = generate_format_string(args, false)
+  ret,_,upstream = rspamd_redis_make_request(task,
+    redis_params, -- connect params
+    key, -- hash key
+    false, -- is write
+    rate_get_cb, --callback
+    cmd, -- command
+    fun.totable(fun.map(function(l) return l[2] end, args)) -- arguments
+  )
 end
 
 --- Set specific limit inside redis
 local function set_limits(task, args)
-  local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
-  local upstream = upstreams:get_upstream_by_hash(key)
-  local addr = upstream:get_addr()
+  local key = fun.foldl(function(acc, k) return acc .. k[2] end, '', args)
+  local ret, upstream
 
   local function rate_set_cb(task, err, data)
     if data then
       local ntime = rspamd_util.get_time()
       local values = {}
-      _.each(function(elt, limit)
+      fun.each(function(elt, limit)
         local bucket = elt[2]
         local rate = limit[1][2]
         local threshold = limit[1][1]
@@ -185,15 +186,15 @@ local function set_limits(task, args)
 
         local lstr = string.format('%.3f:%.3f:%.3f', ntime, bucket, ctime)
         table.insert(values, {limit[2], max_delay, lstr})
-      end, _.zip(parse_limits(data), _.iter(args)))
+      end, fun.zip(parse_limits(data), fun.iter(args)))
 
       local conn = rspamd_redis.connect({
         task = task,
-        host = addr
+        host = upstream:get_addr()
       })
 
       if conn then
-        _.each(function(v)
+        fun.each(function(v)
           conn:add_cmd('setex', v)
         end, values)
       else
@@ -205,12 +206,16 @@ local function set_limits(task, args)
       upstream:fail()
     end
   end
-  if upstream then
-    local cmd = generate_format_string(args, false)
 
-    rspamd_redis.make_request(task, addr, rate_set_cb, cmd,
-      _.totable(_.map(function(l) return l[2] end, args)))
-  end
+  local cmd = generate_format_string(args, false)
+  ret,_,upstream = rspamd_redis_make_request(task,
+    redis_params, -- connect params
+    key, -- hash key
+    false, -- is write
+    rate_set_cb, --callback
+    cmd, -- command
+    fun.totable(fun.map(function(l) return l[2] end, args)) -- arguments
+  )
 end
 
 --- Make rate key
@@ -230,7 +235,7 @@ end
 
 --- Check whether this addr is bounce
 local function check_bounce(from)
-  return _.any(function(b) return b == from end, bounce_senders)
+  return fun.any(function(b) return b == from end, bounce_senders)
 end
 
 --- Check or update ratelimit
@@ -249,9 +254,9 @@ local function rate_test_set(task, func)
   local rcpts = task:get_recipients()
   local rcpts_user = {}
   if rcpts then
-    _.each(function(r) table.insert(rcpts_user, r['user']) end, rcpts)
-    if _.any(function(r)
-      _.any(function(w) return r == w end, whitelisted_rcpts) end,
+    fun.each(function(r) table.insert(rcpts_user, r['user']) end, rcpts)
+    if fun.any(function(r)
+      fun.any(function(w) return r == w end, whitelisted_rcpts) end,
       rcpts_user) then
 
       rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
@@ -275,7 +280,7 @@ local function rate_test_set(task, func)
   local is_bounce = check_bounce(from_user)
 
   if rcpts and not auser then
-    _.each(function(r)
+    fun.each(function(r)
       if is_bounce then
         if settings['bounce_to'][1] > 0 then
           table.insert(args, { settings['bounce_to'], make_rate_key('<>', r['addr'], nil) })
@@ -348,14 +353,14 @@ local opts =  rspamd_config:get_all_opt('ratelimit')
 if opts then
   local rates = opts['limit']
   if rates and type(rates) == 'table' then
-    _.each(parse_limit, rates)
+    fun.each(parse_limit, rates)
   elseif rates and type(rates) == 'string' then
     parse_limit(rates)
   end
 
   if opts['rates'] and type(opts['rates']) == 'table' then
     -- new way of setting limits
-    _.each(function(t, lim)
+    fun.each(function(t, lim)
       if type(lim) == 'table' and settings[t] then
         settings[t] = lim
       else
@@ -364,9 +369,9 @@ if opts then
     end, opts['rates'])
   end
 
-  local enabled_limits = _.totable(_.map(function(t, lim)
+  local enabled_limits = fun.totable(fun.map(function(t, lim)
     return t
-  end, _.filter(function(t, lim) return lim[1] > 0 end, settings)))
+  end, fun.filter(function(t, lim) return lim[1] > 0 end, settings)))
   rspamd_logger.infox(rspamd_config, 'enabled rate buckets: %s', enabled_limits)
 
   if opts['whitelisted_rcpts'] and type(opts['whitelisted_rcpts']) == 'string' then
@@ -392,8 +397,8 @@ if opts then
     max_rcpt = tonumber(opts['max_delay'])
   end
 
-  upstreams = rspamd_parse_redis_server('ratelimit')
-  if not upstreams then
+  redis_params = rspamd_parse_redis_server('ratelimit')
+  if not redis_params then
     rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
   else
     if not ratelimit_symbol then