diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-02-25 17:45:38 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-02-25 17:46:12 +0000 |
commit | f185ab283bf7dcddd50cc6fccdea5cc7c0808627 (patch) | |
tree | 5503eded0b6de67918eae98affd591b6c994657d /src/plugins | |
parent | 61c20a44fe6763b1f2012c63847d5690b26706ba (diff) | |
download | rspamd-f185ab283bf7dcddd50cc6fccdea5cc7c0808627.tar.gz rspamd-f185ab283bf7dcddd50cc6fccdea5cc7c0808627.zip |
Rework and optimize ratelimit plugin.
Diffstat (limited to 'src/plugins')
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 175 |
1 files changed, 101 insertions, 74 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index e1601282d..f14141061 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -60,20 +60,39 @@ local upstream_list = require "rspamd_upstream_list" local _ = require "fun" --- Parse atime and bucket of limit -local function parse_limit_data(str) - local pos,_ = string.find(str, ':') - if not pos then - return 0, 0 - else - local atime = tonumber(string.sub(str, 1, pos - 1)) - local bucket = tonumber(string.sub(str, pos + 1)) - return atime,bucket +local function parse_limits(data) + local function parse_limit_elt(str) + local pos,_ = string.find(str, ':') + if not pos then + return {0, 0} + else + local atime = tonumber(string.sub(str, 1, pos - 1)) + local bucket = tonumber(string.sub(str, pos + 1)) + return {atime,bucket} + end end + + return _.map(function(e) + if type(e) == 'string' then + return parse_limit_elt + else + return {0, 0} + end + end, data) +end + +local function generate_format_string(args, is_set) + if is_set then + return _.foldl(function(acc, k) return acc .. ' %s %s' end, 'MSET', args) + end + return _.foldl(function(acc, k) return acc .. ' %s' end, 'MGET', args) end --- Check specific limit inside redis -local function check_specific_limit (task, limit, key) +local function check_limits(task, args) + local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args) + print(key) local upstream = upstreams:get_upstream_by_hash(key) local addr = upstream:get_addr() --- Called when value was set on server @@ -88,38 +107,38 @@ local function check_specific_limit (task, limit, key) --- Called when value is got from server local function rate_get_cb(task, err, data) if data then - local atime, bucket = parse_limit_data(data) local tv = task:get_timeval() local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec'] - -- Leak messages - bucket = bucket - limit[2] * (ntime - atime); - if bucket > 0 then - local lstr = string.format('%.7f:%.7f', ntime, bucket) - rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, - 'SET %b %b', key, lstr) - if bucket > limit[1] then - task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key) + local it = _.zip(_.map(function(a) return a[1] end, args), parse_limits(data)) + + _.each(function(elt) + local bucket = elt[2][2] + local limit = elt[1] + local atime = elt[2][1] + + bucket = bucket - limit[2] * (ntime - atime); + if bucket > 0 then + if bucket > limit[1] then + task:set_pre_result('soft reject', 'Ratelimit exceeded') + end end - else - rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, - 'DEL %b', key) - end - end - if err then + end, it) + elseif err then rspamd_logger.info('got error while getting limit: ' .. err) upstream:fail() end end + if upstream then - rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key) + local cmd = generate_format_string(args, false) + + rspamd_redis.make_request(task, addr, rate_get_cb, cmd, + _.totable(_.map(function(l) return l[2] end, args))) end end --- Set specific limit inside redis -local function set_specific_limit (task, limit, key) - local upstream = upstreams:get_upstream_by_hash(key) - local addr = upstream:get_addr() - --- Called when value was set on server +local function set_limits(task, args) local function rate_set_key_cb(task, err, data) if err then rspamd_logger.info('got error while setting limit: ' .. err) @@ -128,22 +147,32 @@ local function set_specific_limit (task, limit, key) upstream:ok() end end - --- Called when value is got from server + local key = _.foldl(function(acc, k) return acc .. k[2] end, '', args) + local upstream = upstreams:get_upstream_by_hash(key) + local addr = upstream:get_addr() + local function rate_set_cb(task, err, data) - if not err and not data then - --- Add new entry - local tv = task:get_timeval() - local atime = tv['tv_usec'] / 1000000. + tv['tv_sec'] - local lstr = string.format('%.7f:1', atime) - rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, - 'SET %b %b', key, lstr) - elseif data then - local atime, bucket = parse_limit_data(data) + if data then local tv = task:get_timeval() local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec'] - -- Leak messages - bucket = bucket - limit[2] * (ntime - atime) + 1; - local lstr = string.format('%.7f:%.7f', ntime, bucket) + local it = _.zip(args, parse_limits(data)) + local values = {} + _.each(function(elt) + local bucket = elt[2][2] + local limit = elt[1][1] + local atime = elt[2][1] + + if bucket > 0 then + bucket = bucket - limit[2] * (ntime - atime) + 1; + else + bucket = 1 + end + local lstr = string.format('%.7f:%.7f', ntime, bucket) + table.insert(values, elt[1][2], lstr) + end, it) + + local cmd = generate_format_string(values, true) + rspamd_redis.make_request(task, addr, rate_set_key_cb, cmd, values) rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 'SET %b %b', key, lstr) elseif err then @@ -152,7 +181,10 @@ local function set_specific_limit (task, limit, key) end end if upstream then - rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key) + local cmd = generate_format_string(args, false) + + rspamd_redis.make_request(task, addr, rate_set_cb, cmd, + _.totable(_.map(function(l) return l[2] end, args))) end end @@ -173,21 +205,18 @@ end --- Check whether this addr is bounce local function check_bounce(from) - for _,b in ipairs(whitelisted_rcpts) do - if b == from then - return true - end - end - return false + return _.any(function(b) return b == from end, bounce_senders) end --- Check or update ratelimit local function rate_test_set(task, func) + local args = {} -- Get initial task data local ip = task:get_from_ip() if ip and ip:is_valid() and whitelisted_ip then if whitelisted_ip:get_key(ip) then -- Do not check whitelisted ip + rspamd_logger.info('skip ratelimit for whitelisted IP') return end end @@ -195,14 +224,14 @@ local function rate_test_set(task, func) local rcpts = task:get_recipients() local rcpts_user = {} if rcpts then - if table.maxn(rcpts) > max_rcpt then - rspamd_logger.info(string.format('message <%s> contains %d recipients, maximum is %d', - task:get_message_id(), table.maxn(rcpts), max_rcpt)) + _.each(function(r) table.insert(rcpts_user, r['user']) end, rcpts) + if _.any(function(r) + _.any(function(w) return r == w end, whitelisted_rcpts) end, + rcpts_user) then + + rspamd_logger.info('skip ratelimit for whitelisted recipient') return end - for i,r in ipairs(rcpts) do - rcpts_user[i] = r['user'] - end end -- Parse from local from = task:get_from() @@ -215,39 +244,37 @@ local function rate_test_set(task, func) -- Get user (authuser) local auser = task:get_user() if auser then - func(task, settings['user'], make_rate_key (auser, '<auth>', nil)) - end - - if not rcpts_user[1] then - -- Nothing to check - return + table.insert(args, {settings['user'], make_rate_key (auser, '<auth>', nil)}) end local is_bounce = check_bounce(from_user) - for _,r in ipairs(rcpts) do - if is_bounce then - -- Bounce specific limit - func(task, settings['bounce_to'], make_rate_key ('<>', r['addr'], nil)) + if rcpts then + _.each(function(r) + if is_bounce then + table.insert(args, {settings['bounce_to'], make_rate_key ('<>', r['addr'], nil)}) + if ip then + table.insert(args, {settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip)}) + end + end + table.insert(args, {settings['to'], make_rate_key (nil, r['addr'], nil)}) if ip then - func(task, settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip)) + table.insert(args, {settings['to_ip'], make_rate_key (nil, r['addr'], ip)}) + table.insert(args, {settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip)}) end - end - func(task, settings['to'], make_rate_key (nil, r['addr'], nil)) - if ip then - func(task, settings['to_ip'], make_rate_key (nil, r['addr'], ip)) - func(task, settings['to_ip_from'], make_rate_key (from_addr, r['addr'], ip)) - end + end, rcpts) end + + func(task, args) end --- Check limit local function rate_test(task) - rate_test_set(task, check_specific_limit) + rate_test_set(task, check_limits) end --- Update limit local function rate_set(task) - rate_test_set(task, set_specific_limit) + rate_test_set(task, set_limits) end |