summaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2015-02-25 17:45:38 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2015-02-25 17:46:12 +0000
commitf185ab283bf7dcddd50cc6fccdea5cc7c0808627 (patch)
tree5503eded0b6de67918eae98affd591b6c994657d /src/plugins
parent61c20a44fe6763b1f2012c63847d5690b26706ba (diff)
downloadrspamd-f185ab283bf7dcddd50cc6fccdea5cc7c0808627.tar.gz
rspamd-f185ab283bf7dcddd50cc6fccdea5cc7c0808627.zip
Rework and optimize ratelimit plugin.
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/ratelimit.lua175
1 files changed, 101 insertions, 74 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index e1601282d..f14141061 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -60,20 +60,39 @@ local upstream_list = require "rspamd_upstream_list"
local _ = require "fun"
--- Parse atime and bucket of limit
-local function parse_limit_data(str)
- local pos,_ = string.find(str, ':')
- if not pos then
- return 0, 0
- else
- local atime = tonumber(string.sub(str, 1, pos - 1))
- local bucket = tonumber(string.sub(str, pos + 1))
- return atime,bucket
+local function parse_limits(data)
+ local function parse_limit_elt(str)
+ local pos,_ = string.find(str, ':')
+ if not pos then
+ return {0, 0}
+ else
+ local atime = tonumber(string.sub(str, 1, pos - 1))
+ local bucket = tonumber(string.sub(str, pos + 1))
+ return {atime,bucket}
+ end
end
+
+ return _.map(function(e)
+ if type(e) == 'string' then
+ return parse_limit_elt
+ else
+ return {0, 0}
+ end
+ end, data)
+end
+
+local function generate_format_string(args, is_set)
+ if is_set then
+ return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args)
+ end
+ return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args)
end
--- Check specific limit inside redis
-local function check_specific_limit (task, limit, key)
+local function check_limits(task, args)
+ local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
+ print(key)
local upstream = upstreams:get_upstream_by_hash(key)
local addr = upstream:get_addr()
--- Called when value was set on server
@@ -88,38 +107,38 @@ local function check_specific_limit (task, limit, key)
--- Called when value is got from server
local function rate_get_cb(task, err, data)
if data then
- local atime, bucket = parse_limit_data(data)
local tv = task:get_timeval()
local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
- -- Leak messages
- bucket = bucket - limit[2] * (ntime - atime);
- if bucket > 0 then
- local lstr = string.format('%.7f:%.7f', ntime, bucket)
- rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
- 'SET %b %b', key, lstr)
- if bucket > limit[1] then
- task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key)
+ local it = _.zip(_.map(function(a) return a[1] end, args), parse_limits(data))
+
+ _.each(function(elt)
+ local bucket = elt[2][2]
+ local limit = elt[1]
+ local atime = elt[2][1]
+
+ bucket = bucket - limit[2] * (ntime - atime);
+ if bucket > 0 then
+ if bucket > limit[1] then
+ task:set_pre_result('soft reject', 'Ratelimit exceeded')
+ end
end
- else
- rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
- 'DEL %b', key)
- end
- end
- if err then
+ end, it)
+ elseif err then
rspamd_logger.info('got error while getting limit: ' .. err)
upstream:fail()
end
end
+
if upstream then
- rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key)
+ local cmd = generate_format_string(args, false)
+
+ rspamd_redis.make_request(task, addr, rate_get_cb, cmd,
+ _.totable(_.map(function(l) return l[2] end, args)))
end
end
--- Set specific limit inside redis
-local function set_specific_limit (task, limit, key)
- local upstream = upstreams:get_upstream_by_hash(key)
- local addr = upstream:get_addr()
- --- Called when value was set on server
+local function set_limits(task, args)
local function rate_set_key_cb(task, err, data)
if err then
rspamd_logger.info('got error while setting limit: ' .. err)
@@ -128,22 +147,32 @@ local function set_specific_limit (task, limit, key)
upstream:ok()
end
end
- --- Called when value is got from server
+ local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args)
+ local upstream = upstreams:get_upstream_by_hash(key)
+ local addr = upstream:get_addr()
+
local function rate_set_cb(task, err, data)
- if not err and not data then
- --- Add new entry
- local tv = task:get_timeval()
- local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
- local lstr = string.format('%.7f:1', atime)
- rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
- 'SET %b %b', key, lstr)
- elseif data then
- local atime, bucket = parse_limit_data(data)
+ if data then
local tv = task:get_timeval()
local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
- -- Leak messages
- bucket = bucket - limit[2] * (ntime - atime) + 1;
- local lstr = string.format('%.7f:%.7f', ntime, bucket)
+ local it = _.zip(args, parse_limits(data))
+ local values = {}
+ _.each(function(elt)
+ local bucket = elt[2][2]
+ local limit = elt[1][1]
+ local atime = elt[2][1]
+
+ if bucket > 0 then
+ bucket = bucket - limit[2] * (ntime - atime) + 1;
+ else
+ bucket = 1
+ end
+ local lstr = string.format('%.7f:%.7f', ntime, bucket)
+ table.insert(values, elt[1][2], lstr)
+ end, it)
+
+ local cmd = generate_format_string(values, true)
+ rspamd_redis.make_request(task, addr, rate_set_key_cb, cmd, values)
rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
'SET %b %b', key, lstr)
elseif err then
@@ -152,7 +181,10 @@ local function set_specific_limit (task, limit, key)
end
end
if upstream then
- rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key)
+ local cmd = generate_format_string(args, false)
+
+ rspamd_redis.make_request(task, addr, rate_set_cb, cmd,
+ _.totable(_.map(function(l) return l[2] end, args)))
end
end
@@ -173,21 +205,18 @@ end
--- Check whether this addr is bounce
local function check_bounce(from)
- for _,b in ipairs(whitelisted_rcpts) do
- if b == from then
- return true
- end
- end
- return false
+ return _.any(function(b) return b == from end, bounce_senders)
end
--- Check or update ratelimit
local function rate_test_set(task, func)
+ local args = {}
-- Get initial task data
local ip = task:get_from_ip()
if ip and ip:is_valid() and whitelisted_ip then
if whitelisted_ip:get_key(ip) then
-- Do not check whitelisted ip
+ rspamd_logger.info('skip ratelimit for whitelisted IP')
return
end
end
@@ -195,14 +224,14 @@ local function rate_test_set(task, func)
local rcpts = task:get_recipients()
local rcpts_user = {}
if rcpts then
- if table.maxn(rcpts) > max_rcpt then
- rspamd_logger.info(string.format('message <%s> contains %d recipients, maximum is %d',
- task:get_message_id(), table.maxn(rcpts), max_rcpt))
+ _.each(function(r) table.insert(rcpts_user, r['user']) end, rcpts)
+ if _.any(function(r)
+ _.any(function(w) return r == w end, whitelisted_rcpts) end,
+ rcpts_user) then
+
+ rspamd_logger.info('skip ratelimit for whitelisted recipient')
return
end
- for i,r in ipairs(rcpts) do
- rcpts_user[i] = r['user']
- end
end
-- Parse from
local from = task:get_from()
@@ -215,39 +244,37 @@ local function rate_test_set(task, func)
-- Get user (authuser)
local auser = task:get_user()
if auser then
- func(task, settings['user'], make_rate_key (auser, '<auth>', nil))
- end
-
- if not rcpts_user[1] then
- -- Nothing to check
- return
+ table.insert(args, {settings['user'], make_rate_key (auser, '<auth>', nil)})
end
local is_bounce = check_bounce(from_user)
- for _,r in ipairs(rcpts) do
- if is_bounce then
- -- Bounce specific limit
- func(task, settings['bounce_to'], make_rate_key ('<>', r['addr'], nil))
+ if rcpts then
+ _.each(function(r)
+ if is_bounce then
+ table.insert(args, {settings['bounce_to'], make_rate_key ('<>', r['addr'], nil)})
+ if ip then
+ table.insert(args, {settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip)})
+ end
+ end
+ table.insert(args, {settings['to'], make_rate_key (nil, r['addr'], nil)})
if ip then
- func(task, settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip))
+ table.insert(args, {settings['to_ip'], make_rate_key (nil, r['addr'], ip)})
+ table.insert(args, {settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip)})
end
- end
- func(task, settings['to'], make_rate_key (nil, r['addr'], nil))
- if ip then
- func(task, settings['to_ip'], make_rate_key (nil, r['addr'], ip))
- func(task, settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip))
- end
+ end, rcpts)
end
+
+ func(task, args)
end
--- Check limit
local function rate_test(task)
- rate_test_set(task, check_specific_limit)
+ rate_test_set(task, check_limits)
end
--- Update limit
local function rate_set(task)
- rate_test_set(task, set_specific_limit)
+ rate_test_set(task, set_limits)
end