]> source.dussan.org Git - rspamd.git/commitdiff
Fix ratelimit plugin.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 26 Feb 2015 14:26:57 +0000 (14:26 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 26 Feb 2015 14:26:57 +0000 (14:26 +0000)
src/lua/lua_redis.c
src/plugins/lua/ratelimit.lua

index e2b19f8eecbeb14c84fe1b49afb166df307fef60..2d27fd6e1e412d9e0cf5e90c5e28668383a9d49a 100644 (file)
@@ -260,10 +260,10 @@ lua_redis_make_request (lua_State *L)
                                        args = g_alloca ((top + 1) * sizeof (gchar *));
                                        lua_pushnil (L);
                                        args[0] = cmd;
-                                       top = 0;
+                                       top = 1;
 
                                        while (lua_next (L, -2) != 0) {
-                                               args[++top] = lua_tostring (L, -1);
+                                               args[top++] = lua_tostring (L, -1);
                                                lua_pop (L, 1);
                                        }
 
index 80c74aa2639ec348e691e30ab83c0dd1fbbc2e00..0dd9b84049da89b753e0a1c0fb554accf06f8f32 100644 (file)
@@ -58,6 +58,7 @@ local rspamd_logger = require "rspamd_logger"
 local rspamd_redis = require "rspamd_redis"
 local upstream_list = require "rspamd_upstream_list"
 local _ = require "fun"
+--local dumper = require 'pl.pretty'.dump
 
 --- Parse atime and bucket of limit
 local function parse_limits(data)
@@ -72,27 +73,28 @@ local function parse_limits(data)
     end
   end
   
-  return _.map(function(e) 
+  return _.iter(data):map(function(e) 
     if type(e) == 'string' then 
-      return parse_limit_elt
+      return parse_limit_elt(e)
     else
       return {0, 0}
     end
-    end, data)
+    end):totable()
 end
 
 local function generate_format_string(args, is_set)
   if is_set then
-    return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
+    return 'MSET'
+    --return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
   end
-  return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
+  return 'MGET'
+  --return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
 end
 
 --- Check specific limit inside redis
 local function check_limits(task, args)
 
   local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
-  print(key)
   local upstream = upstreams:get_upstream_by_hash(key)
   local addr = upstream:get_addr()
   --- Called when value was set on server
@@ -109,20 +111,20 @@ local function check_limits(task, args)
     if data then
       local tv = task:get_timeval()
       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-      local it = _.zip(_.map(function(a) return a[1] end, args), parse_limits(data))
       
-      _.each(function(elt)
-        local bucket = elt[2][2]
-        local limit = elt[1]
-        local atime = elt[2][1]
+      _.each(function(elt, limit)
+        local bucket = elt[2]
+        local rate = limit[2]
+        local threshold = limit[1]
+        local atime = elt[1]
         
-        bucket = bucket - limit[2] * (ntime - atime);
+        bucket = bucket - rate * (ntime - atime);
         if bucket > 0 then
-          if bucket > limit[1] then
+          if bucket > threshold then
             task:set_pre_result('soft reject', 'Ratelimit exceeded')
           end
         end
-      end, it)
+      end, _.zip(parse_limits(data), _.map(function(a) return a[1] end, args)))
     elseif err then
       rspamd_logger.info('got error while getting limit: ' .. err)
       upstream:fail()
@@ -155,21 +157,22 @@ local function set_limits(task, args)
     if data then
       local tv = task:get_timeval()
       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-      local it = _.zip(args, parse_limits(data))
       local values = {}
-      _.each(function(elt)
-        local bucket = elt[2][2]
-        local limit = elt[1][1]
-        local atime = elt[2][1]
+      _.each(function(elt, limit)
+        local bucket = elt[2]
+        local rate = limit[1][2]
+        local threshold = limit[1][1]
+        local atime = elt[1]
         
         if bucket > 0 then
-          bucket = bucket - limit[2] * (ntime - atime) + 1;
+          bucket = bucket - rate * (ntime - atime) + 1;
         else
           bucket = 1
         end
-        local lstr = string.format('%.7f:%.7f', ntime, bucket)
-        table.insert(values, elt[1][2], lstr)
-      end, it)
+        local lstr = string.format('%.3f:%.3f', ntime, bucket)
+        table.insert(values, limit[2])
+        table.insert(values, lstr)
+      end, _.zip(parse_limits(data), _.iter(args)))
       
       local cmd = generate_format_string(values, true)
       rspamd_redis.make_request(task, addr, rate_set_key_cb, cmd, values)