]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Ratelimit: support fetching limits from Redis 1818/head
authorAndrew Lewis <nerf@judo.za.org>
Tue, 29 Aug 2017 11:07:13 +0000 (13:07 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Tue, 29 Aug 2017 12:29:15 +0000 (14:29 +0200)
src/plugins/lua/ratelimit.lua

index 2516d18444d206498ea6a88fd32e6db3f571a8f0..b05ed0114f7da91ebbbf85479f0a89fa536560e8 100644 (file)
@@ -375,7 +375,86 @@ local function dynamic_rate_key(task, rtype)
   end
 end
 
-local function get_buckets(task)
+local function process_buckets(task, buckets)
+  if not buckets then return end
+  local function rl_redis_cb(err, data)
+    if err then
+      rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+    end
+    if not data then return end
+    if data[1] == 1 then
+      rspamd_logger.infox(task,
+        'ratelimit "%s" exceeded',
+        data[2])
+      task:set_pre_result('soft reject',
+        message_func(task, data[2]))
+    end
+  end
+  local function rl_symbol_redis_cb(err, data)
+    if err then
+      rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+    end
+    if not data then return end
+    for i, b in ipairs(data) do
+      task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2]))
+    end
+  end
+  local redis_cb = rl_redis_cb
+  if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end
+  local args = {redis_script_sha, #buckets}
+  for _, bucket in ipairs(buckets) do
+    table.insert(args, bucket[2])
+  end
+  for _, bucket in ipairs(buckets) do
+    if use_ip_score then
+      local asn_score,total_asn,
+        country_score,total_country,
+        ipnet_score,total_ipnet,
+        ip_score, total_ip = task:get_mempool():get_variable('ip_score',
+        'double,double,double,double,double,double,double,double')
+      local key_keywords = rspamd_str_split(bucket[2], '_')
+      local has_asn, has_ip = false, false
+      for _, v in ipairs(key_keywords) do
+        if v == "asn" then has_asn = true end
+        if v == "ip" then has_ip = true end
+        if has_ip and has_asn then break end
+      end
+      if has_asn and not has_ip then
+        bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
+      elseif has_ip then
+        if total_ip and total_ip > ip_score_lower_bound then
+          bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
+        elseif total_ipnet and total_ipnet > ip_score_lower_bound then
+          bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2])
+        elseif total_asn and total_asn > ip_score_lower_bound then
+          bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
+        elseif total_country and total_country > ip_score_lower_bound then
+          bucket[1][2] = resize_element(country_score, total_country, bucket[1][2])
+        else
+          bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
+        end
+      end
+    end
+    table.insert(args, bucket[1][1])
+    table.insert(args, bucket[1][2])
+  end
+  table.insert(args, rspamd_util.get_time())
+  table.insert(args, task:get_queue_id() or task:get_uid())
+  local ret = rspamd_redis_make_request(task,
+    redis_params, -- connect params
+    nil, -- hash key
+    true, -- is write
+    redis_cb, --callback
+    'evalsha', -- command
+    args -- arguments
+  )
+  if not ret then
+    rspamd_logger.errx(task, 'got error connecting to redis')
+  end
+end
+
+local function ratelimit_cb(task)
+  if rspamd_lua_utils.is_rspamc_or_controller(task) then return end
   local args = {}
   -- Get initial task data
   local ip = task:get_from_ip()
@@ -412,6 +491,38 @@ local function get_buckets(task)
     end
   end
 
+  local redis_keys = {}
+  local redis_keys_rev = {}
+  local function collect_redis_keys()
+    local function collect_cb(err, data)
+      if err then
+        rspamd_logger.errx(task, 'redis error: %1', err)
+      else
+        for i, d in ipairs(data) do
+          if type(d) == 'string' then
+            local plim, size = parse_string_limit(d)
+            if plim then
+              table.insert(args, {{plim, size}, redis_keys_rev[i]})
+            end
+          end
+        end
+        return process_buckets(task, args)
+      end
+    end
+    local requested_keys = rspamd_redis_make_request(task,
+      redis_params, -- connect params
+      nil, -- hash key
+      true, -- is write
+      collect_cb, --callback
+      'MGET', -- command
+      redis_keys -- arguments
+    )
+    if not requested_keys then
+      rspamd_logger.errx(task, 'got error connecting to redis')
+      return process_buckets(task, args)
+    end
+  end
+
   local rate_key
   for k in pairs(settings) do
     rate_key = dynamic_rate_key(task, k)
@@ -426,6 +537,14 @@ local function get_buckets(task)
               local plim, size = parse_string_limit(r)
               if plim then
                 table.insert(args, {{plim, size}, rk})
+              else
+                local rkey = string.match(settings[k], 'redis:(.*)')
+                if rkey then
+                  table.insert(redis_keys, rkey)
+                  redis_keys_rev[#redis_keys] = rk
+                else
+                  rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+                end
               end
             end
           end
@@ -439,97 +558,37 @@ local function get_buckets(task)
             local plim, size = parse_string_limit(r)
             if plim then
               table.insert(args, {{plim, size}, rate_key})
+            else
+              local rkey = string.match(settings[k], 'redis:(.*)')
+              if rkey then
+                table.insert(redis_keys, rkey)
+                redis_keys_rev[#redis_keys] = rate_key
+              else
+                rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+              end
             end
           end
         elseif type(settings[k]) == 'table' then
           for _, rl in ipairs(settings[k]) do
             table.insert(args, {{rl[1], rl[2]}, rate_key})
           end
+        elseif type(settings[k]) == 'string' then
+          local rkey = string.match(settings[k], 'redis:(.*)')
+          if rkey then
+            table.insert(redis_keys, rkey)
+            redis_keys_rev[#redis_keys] = rate_key
+          else
+            rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+          end
         end
       end
     end
   end
 
-  return args
-end
-
-local function ratelimit_cb(task)
-  if rspamd_lua_utils.is_rspamc_or_controller(task) then return end
-  local function rl_redis_cb(err, data)
-    if err then
-      rspamd_logger.infox(task, 'got error while setting limit: %1', err)
-    end
-    if not data then return end
-    if data[1] == 1 then
-      rspamd_logger.infox(task,
-        'ratelimit "%s" exceeded',
-        data[2])
-      task:set_pre_result('soft reject',
-        message_func(task, data[2]))
-    end
-  end
-  local function rl_symbol_redis_cb(err, data)
-    if err then
-      rspamd_logger.infox(task, 'got error while setting limit: %1', err)
-    end
-    if not data then return end
-    for i, b in ipairs(data) do
-      task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2]))
-    end
-  end
-  local redis_cb = rl_redis_cb
-  if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end
-  local buckets = get_buckets(task)
-  if not buckets then return end
-  local args = {redis_script_sha, #buckets}
-  for _, bucket in ipairs(buckets) do
-    table.insert(args, bucket[2])
-  end
-  for _, bucket in ipairs(buckets) do
-    if use_ip_score then
-      local asn_score,total_asn,
-        country_score,total_country,
-        ipnet_score,total_ipnet,
-        ip_score, total_ip = task:get_mempool():get_variable('ip_score',
-        'double,double,double,double,double,double,double,double')
-      local key_keywords = rspamd_str_split(bucket[2], '_')
-      local has_asn, has_ip = false, false
-      for _, v in ipairs(key_keywords) do
-        if v == "asn" then has_asn = true end
-        if v == "ip" then has_ip = true end
-        if has_ip and has_asn then break end
-      end
-      if has_asn and not has_ip then
-        bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
-      elseif has_ip then
-        if total_ip and total_ip > ip_score_lower_bound then
-          bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
-        elseif total_ipnet and total_ipnet > ip_score_lower_bound then
-          bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2])
-        elseif total_asn and total_asn > ip_score_lower_bound then
-          bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
-        elseif total_country and total_country > ip_score_lower_bound then
-          bucket[1][2] = resize_element(country_score, total_country, bucket[1][2])
-        else
-          bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
-        end
-      end
-    end
-    table.insert(args, bucket[1][1])
-    table.insert(args, bucket[1][2])
-  end
-  table.insert(args, rspamd_util.get_time())
-  table.insert(args, task:get_queue_id() or task:get_uid())
-  local ret = rspamd_redis_make_request(task,
-    redis_params, -- connect params
-    nil, -- hash key
-    true, -- is write
-    redis_cb, --callback
-    'evalsha', -- command
-    args -- arguments
-  )
-  if not ret then
-    rspamd_logger.errx(task, 'got error connecting to redis')
+  if redis_keys[1] then
+    return collect_redis_keys()
+  else
+    return process_buckets(task, args)
   end
 end