aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rambler-co.ru>2012-04-06 20:46:55 +0400
committerVsevolod Stakhov <vsevolod@rambler-co.ru>2012-04-06 20:46:55 +0400
commitefe165bc3d0f8b225be40bb8bd0bfebf7e972f04 (patch)
treea0c0b28fe9fecb53dd73281bea3efc337209ae07 /src
parente5d0c7f8f6cda246eddfcab82b056650be753fe7 (diff)
downloadrspamd-efe165bc3d0f8b225be40bb8bd0bfebf7e972f04.tar.gz
rspamd-efe165bc3d0f8b225be40bb8bd0bfebf7e972f04.zip
* Add ratelimit plugin
Some polishing of lua task api.
Diffstat (limited to 'src')
-rw-r--r--src/lua/lua_task.c24
-rw-r--r--src/lua/lua_upstream.c2
-rw-r--r--src/plugins/lua/ratelimit.lua315
3 files changed, 340 insertions, 1 deletions
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index d205016e0..7ca1c58df 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -69,6 +69,7 @@ LUA_FUNCTION_DEF (task, get_helo);
LUA_FUNCTION_DEF (task, get_images);
LUA_FUNCTION_DEF (task, get_symbol);
LUA_FUNCTION_DEF (task, get_date);
+LUA_FUNCTION_DEF (task, get_timeval);
LUA_FUNCTION_DEF (task, get_metric_score);
LUA_FUNCTION_DEF (task, get_metric_action);
LUA_FUNCTION_DEF (task, learn_statfile);
@@ -100,6 +101,7 @@ static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_images),
LUA_INTERFACE_DEF (task, get_symbol),
LUA_INTERFACE_DEF (task, get_date),
+ LUA_INTERFACE_DEF (task, get_timeval),
LUA_INTERFACE_DEF (task, get_metric_score),
LUA_INTERFACE_DEF (task, get_metric_action),
LUA_INTERFACE_DEF (task, learn_statfile),
@@ -1169,6 +1171,28 @@ lua_task_get_date (lua_State *L)
}
static gint
+lua_task_get_timeval (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+
+ if (task != NULL) {
+ lua_newtable (L);
+ lua_pushstring (L, "tv_sec");
+ lua_pushnumber (L, (lua_Number)task->tv.tv_sec);
+ lua_settable (L, -3);
+ lua_pushstring (L, "tv_usec");
+ lua_pushnumber (L, (lua_Number)task->tv.tv_usec);
+ lua_settable (L, -3);
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+}
+
+
+static gint
lua_task_learn_statfile (lua_State *L)
{
struct worker_task *task = lua_check_task (L);
diff --git a/src/lua/lua_upstream.c b/src/lua/lua_upstream.c
index 5c72f7949..f9e74e027 100644
--- a/src/lua/lua_upstream.c
+++ b/src/lua/lua_upstream.c
@@ -72,8 +72,8 @@ static const struct luaL_reg upstream_m[] = {
LUA_INTERFACE_DEF (upstream, get_ip_string),
LUA_INTERFACE_DEF (upstream, get_port),
LUA_INTERFACE_DEF (upstream, get_priority),
+ LUA_INTERFACE_DEF (upstream, destroy),
{"__tostring", lua_class_tostring},
- {"__gc", lua_upstream_destroy},
{NULL, NULL}
};
static const struct luaL_reg upstream_f[] = {
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
new file mode 100644
index 000000000..7b1ea3eff
--- /dev/null
+++ b/src/plugins/lua/ratelimit.lua
@@ -0,0 +1,315 @@
+-- A plugin that implements ratelimits using redis or kvstorage server
+
+-- Default port for redis upstreams
+local default_port = 6379
+-- Default settings for limits, 1-st member is burst, second is rate and the third is numeric type
+local settings = {
+ -- Limit for all mail per recipient (burst 100, rate 2 per minute)
+ to = {[1] = 100, [2] = 0.033333333, [3] = 1},
+ -- Limit for all mail per one source ip (burst 30, rate 1.5 per minute)
+ to_ip = {[1] = 30, [2] = 0.025, [3] = 2},
+ -- Limit for all mail per one source ip and from address (burst 20, rate 1 per minute)
+ to_ip_from = {[1] = 20, [2] = 0.01666666667, [3] = 3},
+
+ -- Limit for all bounce mail (burst 10, rate 2 per hour)
+ bounce_to = {[1] = 10, [2] = 0.000555556, [3] = 4},
+ -- Limit for bounce mail per one source ip (burst 5, rate 1 per hour)
+ bounce_to_ip = {[1] = 5 , [2] = 0.000277778, [3] = 5}
+}
+-- Senders that are considered as bounce
+local bounce_senders = {'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon'}
+-- Do not check ratelimits for these senders
+local whitelisted_rcpts = {'postmaster', 'mailer-daemon'}
+local whitelisted_ip = nil
+local upstreams = nil
+
+--- Parse atime and bucket of limit
+local function parse_limit_data(str)
+ local pos,_ = string.find(str, ':')
+ if not pos then
+ return 0, 0
+ else
+ local atime = tonumber(string.sub(str, 1, pos - 1))
+ local bucket = tonumber(string.sub(str, pos + 1))
+ return atime,bucket
+ end
+end
+
+--- Check specific limit inside redis
+local function check_specific_limit (task, limit, key)
+
+ local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+ --- Called when value was set on server
+ local function rate_set_key_cb(task, err, data)
+ if err then
+ upstream:fail()
+ else
+ upstream:ok()
+ end
+ end
+ --- Called when value is got from server
+ local function rate_get_cb(task, err, data)
+ if data then
+ local atime, bucket = parse_limit_data(data)
+ local tv = task:get_timeval()
+ local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+ -- Leak messages
+ bucket = bucket - limit[2] * (ntime - atime);
+ if bucket > 0 then
+ local lstr = string.format('%.7f:%.7f', ntime, bucket)
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ 'SET %b %b', key, lstr)
+ if bucket > limit[1] then
+ task:set_pre_result(rspamd_actions['soft reject'], 'Ratelimit exceeded')
+ end
+ else
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ 'DEL %b', key)
+ end
+ end
+ if err then
+ upstream:fail()
+ end
+ end
+ if upstream then
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_get_cb, 'GET %b', key)
+ end
+end
+
+--- Set specific limit inside redis
+local function set_specific_limit (task, limit, key)
+ local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+ --- Called when value was set on server
+ local function rate_set_key_cb(task, err, data)
+ if err then
+ upstream:fail()
+ else
+ upstream:ok()
+ end
+ end
+ --- Called when value is got from server
+ local function rate_set_cb(task, err, data)
+ if not err and not data then
+ --- Add new entry
+ local tv = task:get_timeval()
+ local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+ local lstr = string.format('%.7f:1', atime)
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ 'SET %b %b', key, lstr)
+ elseif data then
+ local atime, bucket = parse_limit_data(data)
+ local tv = task:get_timeval()
+ local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+ -- Leak messages
+ bucket = bucket - limit[2] * (ntime - atime) + 1;
+ local lstr = string.format('%.7f:%.7f', ntime, bucket)
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ 'SET %b %b', key, lstr)
+ elseif err then
+ upstream:fail()
+ end
+ end
+ if upstream then
+ rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_cb, 'GET %b', key)
+ end
+end
+
+--- Make rate key
+local function make_rate_key(from, to, ip)
+ if from and ip then
+ return string.format('%s:%s:%s', from, to, ip)
+ elseif from then
+ return string.format('%s:%s', from, to)
+ elseif ip then
+ return string.format('%s:%s', to, ip)
+ elseif to then
+ return to
+ else
+ return nil
+ end
+end
+
+--- Check whether this addr is bounce
+local function check_bounce(from)
+ for _,b in ipairs(whitelisted_rcpts) do
+ if b == from then
+ return true
+ end
+ end
+ return false
+end
+
+--- Check or update ratelimit
+local function rate_test_set(task, func)
+
+ -- Returns local part component of address
+ local function get_local_part(str)
+ pos,_ = string.find(str, '@', 0, true)
+ if not pos then
+ return str
+ else
+ return string.sub(str, 1, pos - 1)
+ end
+ end
+
+ -- Get initial task data
+ local ip = task:get_from_ip()
+ if ip and whitelisted_ip then
+ if whitelisted_ip:get_key(ip) then
+ -- Do not check whitelisted ip
+ return
+ end
+ end
+ -- Parse all rcpts
+ local rcpts = task:get_recipients()
+ local rcpts_user = {}
+ if not rcpts then
+ rcpts = task:get_recipients_headers()
+ end
+ if rcpts then
+ for i,r in ipairs(rcpts) do
+ rcpts_user[i] = get_local_part(r['addr'])
+ end
+ end
+ -- Parse from
+ local from = task:get_from()
+ local from_user = ''
+ if not from then
+ from = task:get_from_headers()
+ end
+ if from then
+ from_user = get_local_part(from[1]['addr'])
+ end
+
+ if not from_user or not rcpts_user[1] then
+ -- Nothing to check
+ return
+ end
+
+ local is_bounce = check_bounce(from_user)
+
+ for _,r in ipairs(rcpts) do
+ if is_bounce then
+ -- Bounce specific limit
+ func(task, settings['bounce_to'], make_rate_key ('<>', r['addr'], nil))
+ if ip then
+ func(task, settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip))
+ end
+ end
+ func(task, settings['to'], make_rate_key (nil, r['addr'], nil))
+ if ip then
+ func(task, settings['to_ip'], make_rate_key (nil, r['addr'], ip))
+ func(task, settings['to_ip_from'], make_rate_key (from[1]['addr'], r['addr'], ip))
+ end
+ end
+end
+
+--- Check limit
+local function rate_test(task)
+ rate_test_set(task, check_specific_limit)
+end
+--- Update limit
+local function rate_set(task)
+ rate_test_set(task, set_specific_limit)
+end
+
+
+--- Utility function for split string to table
+local function split(str, delim, maxNb)
+ -- Eliminate bad cases...
+ if string.find(str, delim) == nil then
+ return { str }
+ end
+ if maxNb == nil or maxNb < 1 then
+ maxNb = 0 -- No limit
+ end
+ local result = {}
+ local pat = "(.-)" .. delim .. "()"
+ local nb = 0
+ local lastPos
+ for part, pos in string.gfind(str, pat) do
+ nb = nb + 1
+ result[nb] = part
+ lastPos = pos
+ if nb == maxNb then break end
+ end
+ -- Handle the last field
+ if nb ~= maxNb then
+ result[nb + 1] = string.sub(str, lastPos)
+ end
+ return result
+end
+
+--- Parse a single limit description
+local function parse_limit(str)
+ local params = split(str, ':', 0)
+
+ local function set_limit(limit, burst, rate)
+ limit[1] = tonumber(burst)
+ limit[2] = tonumber(rate)
+ end
+
+ if table.maxn(params) ~= 3 then
+ rspamd_logger.err('invalid limit definition: ' .. str)
+ return
+ end
+
+ if params[1] == 'to' then
+ set_limit(settings['to'], params[2], params[3])
+ elseif params[1] == 'to_ip' then
+ set_limit(settings['to_ip'], params[2], params[3])
+ elseif params[1] == 'to_ip_from' then
+ set_limit(settings['to_ip_from'], params[2], params[3])
+ elseif params[1] == 'bounce_to' then
+ set_limit(settings['bounce_to'], params[2], params[3])
+ elseif params[1] == 'bounce_to_ip' then
+ set_limit(settings['bounce_to_ip'], params[2], params[3])
+ else
+ rspamd_logger.err('invalid limit type: ' .. params[1])
+ end
+end
+
+-- Registration
+if rspamd_config:get_api_version() >= 9 then
+ rspamd_config:register_module_option('ratelimit', 'servers', 'string')
+ rspamd_config:register_module_option('ratelimit', 'bounce_senders', 'string')
+ rspamd_config:register_module_option('ratelimit', 'whitelisted_rcpts', 'string')
+ rspamd_config:register_module_option('ratelimit', 'whitelisted_ip', 'map')
+ rspamd_config:register_module_option('ratelimit', 'limit', 'string')
+end
+
+local function parse_whitelisted_rcpts(str)
+
+end
+
+local opts = rspamd_config:get_all_opt('ratelimit')
+if opts then
+ local rates = opts['limit']
+ if rates and type(rates) == 'table' then
+ for _,r in ipairs(rates) do
+ parse_limit(r)
+ end
+ elseif rates and type(rates) == 'string' then
+ parse_limit(rates)
+ end
+
+ if opts['whitelisted_rcpts'] and type(opts['whitelisted_rcpts']) == 'string' then
+ whitelisted_rcpts = split(opts['whitelisted_rcpts'], ',')
+ end
+
+ if opts['whitelisted_ip'] then
+ whitelisted_ip = rspamd_config:add_hash_map (opts['whitelisted_ip'])
+ end
+
+ if not opts['servers'] then
+ rspamd_logger.err('no servers are specified')
+ else
+ upstreams = upstream_list.create(opts['servers'], default_port)
+ if not upstreams then
+ rspamd_logger.err('no servers are specified')
+ else
+ rspamd_config:register_pre_filter(rate_test)
+ rspamd_config:register_post_filter(rate_set)
+ end
+ end
+end