]> source.dussan.org Git - rspamd.git/commitdiff
* Add ratelimit plugin
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 6 Apr 2012 16:46:55 +0000 (20:46 +0400)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 6 Apr 2012 16:46:55 +0000 (20:46 +0400)
Some polishing of lua task api.

src/lua/lua_task.c
src/lua/lua_upstream.c
src/plugins/lua/ratelimit.lua [new file with mode: 0644]

index d205016e08eb803645f25c7eb011ed2de6529d4c..7ca1c58df01d49c6a17da8505fe2a639d18b25dd 100644 (file)
@@ -69,6 +69,7 @@ LUA_FUNCTION_DEF (task, get_helo);
 LUA_FUNCTION_DEF (task, get_images);
 LUA_FUNCTION_DEF (task, get_symbol);
 LUA_FUNCTION_DEF (task, get_date);
+LUA_FUNCTION_DEF (task, get_timeval);
 LUA_FUNCTION_DEF (task, get_metric_score);
 LUA_FUNCTION_DEF (task, get_metric_action);
 LUA_FUNCTION_DEF (task, learn_statfile);
@@ -100,6 +101,7 @@ static const struct luaL_reg    tasklib_m[] = {
        LUA_INTERFACE_DEF (task, get_images),
        LUA_INTERFACE_DEF (task, get_symbol),
        LUA_INTERFACE_DEF (task, get_date),
+       LUA_INTERFACE_DEF (task, get_timeval),
        LUA_INTERFACE_DEF (task, get_metric_score),
        LUA_INTERFACE_DEF (task, get_metric_action),
        LUA_INTERFACE_DEF (task, learn_statfile),
@@ -1168,6 +1170,28 @@ lua_task_get_date (lua_State *L)
        return 1;
 }
 
+static gint
+lua_task_get_timeval (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+
+       if (task != NULL) {
+               lua_newtable (L);
+               lua_pushstring (L, "tv_sec");
+               lua_pushnumber (L, (lua_Number)task->tv.tv_sec);
+               lua_settable (L, -3);
+               lua_pushstring (L, "tv_usec");
+               lua_pushnumber (L, (lua_Number)task->tv.tv_usec);
+               lua_settable (L, -3);
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+}
+
+
 static gint
 lua_task_learn_statfile (lua_State *L)
 {
index 5c72f7949e58c7f794ae568561508e0df32e47ba..f9e74e02708b6bc12c72c534ea0d4cfa289aaaef 100644 (file)
@@ -72,8 +72,8 @@ static const struct luaL_reg  upstream_m[] = {
        LUA_INTERFACE_DEF (upstream, get_ip_string),
        LUA_INTERFACE_DEF (upstream, get_port),
        LUA_INTERFACE_DEF (upstream, get_priority),
+       LUA_INTERFACE_DEF (upstream, destroy),
        {"__tostring", lua_class_tostring},
-       {"__gc", lua_upstream_destroy},
        {NULL, NULL}
 };
 static const struct luaL_reg  upstream_f[] = {
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
new file mode 100644 (file)
index 0000000..7b1ea3e
--- /dev/null
@@ -0,0 +1,315 @@
+-- A plugin that implements ratelimits using redis or kvstorage server
+
+-- Default port for redis upstreams
+local default_port = 6379
+-- Default settings for limits, 1-st member is burst, second is rate and the third is numeric type 
+local settings = {
+       -- Limit for all mail per recipient (burst 100, rate 2 per minute)
+       to = {[1] = 100, [2] = 0.033333333, [3] = 1},
+       -- Limit for all mail per one source ip (burst 30, rate 1.5 per minute)  
+       to_ip = {[1] = 30, [2] = 0.025, [3] = 2},
+       -- Limit for all mail per one source ip and from address (burst 20, rate 1 per minute)
+       to_ip_from = {[1] = 20, [2] = 0.01666666667, [3] = 3}, 
+       
+       -- Limit for all bounce mail (burst 10, rate 2 per hour)
+       bounce_to = {[1] = 10, [2] = 0.000555556, [3] = 4}, 
+       -- Limit for bounce mail per one source ip (burst 5, rate 1 per hour)
+       bounce_to_ip = {[1] = 5 , [2] = 0.000277778, [3] = 5} 
+}
+-- Senders that are considered as bounce
+local bounce_senders = {'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon'}
+-- Do not check ratelimits for these senders
+local whitelisted_rcpts = {'postmaster', 'mailer-daemon'}
+local whitelisted_ip = nil
+local upstreams = nil
+
+--- Parse atime and bucket of limit
+local function parse_limit_data(str)
+       local pos,_ = string.find(str, ':')
+       if not pos then
+               return 0, 0
+       else
+               local atime = tonumber(string.sub(str, 1, pos - 1))
+               local bucket = tonumber(string.sub(str, pos + 1))
+               return atime,bucket
+       end
+end
+
+--- Check specific limit inside redis
+local function check_specific_limit (task, limit, key)
+
+       local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+       --- Called when value was set on server
+       local function rate_set_key_cb(task, err, data)
+               if err then
+                       upstream:fail()
+               else
+                       upstream:ok()
+               end
+       end
+       --- Called when value is got from server
+       local function rate_get_cb(task, err, data)
+               if data then
+                       local atime, bucket = parse_limit_data(data)
+                       local tv = task:get_timeval()
+                       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+                       -- Leak messages
+                       bucket = bucket - limit[2] * (ntime - atime);
+                       if bucket > 0 then
+                               local lstr = string.format('%.7f:%.7f', ntime, bucket)
+                               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                                                       'SET %b %b', key, lstr)
+                               if bucket > limit[1] then
+                                       task:set_pre_result(rspamd_actions['soft reject'], 'Ratelimit exceeded')
+                               end
+                       else
+                               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                                                       'DEL %b', key)
+                       end
+               end
+               if err then
+                       upstream:fail()
+               end     
+       end
+       if upstream then
+               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_get_cb, 'GET %b', key)
+       end
+end
+
+--- Set specific limit inside redis
+local function set_specific_limit (task, limit, key)
+       local upstream = upstreams:get_upstream_by_hash(key,  task:get_date())
+       --- Called when value was set on server
+       local function rate_set_key_cb(task, err, data)
+               if err then
+                       upstream:fail()
+               else
+                       upstream:ok()
+               end
+       end
+       --- Called when value is got from server
+       local function rate_set_cb(task, err, data)
+               if not err and not data then
+                       --- Add new entry
+                       local tv = task:get_timeval()
+                       local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+                       local lstr = string.format('%.7f:1', atime)
+                       rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                                                       'SET %b %b', key, lstr)
+               elseif data then
+                       local atime, bucket = parse_limit_data(data)
+                       local tv = task:get_timeval()
+                       local ntime = tv['tv_usec'] / 1000000. + tv['tv_sec']
+                       -- Leak messages
+                       bucket = bucket - limit[2] * (ntime - atime) + 1;
+                       local lstr = string.format('%.7f:%.7f', ntime, bucket)
+                       rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                                                       'SET %b %b', key, lstr)
+               elseif err then
+                       upstream:fail()
+               end
+       end
+       if upstream then
+               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_cb, 'GET %b', key)
+       end
+end
+
+--- Make rate key
+local function make_rate_key(from, to, ip)
+       if from and ip then
+               return string.format('%s:%s:%s', from, to, ip)
+       elseif from then
+               return string.format('%s:%s', from, to)
+       elseif ip then
+               return string.format('%s:%s', to, ip)
+       elseif to then
+               return to
+       else
+               return nil
+       end
+end
+
+--- Check whether this addr is bounce
+local function check_bounce(from)
+       for _,b in ipairs(whitelisted_rcpts) do
+               if b == from then
+                       return true
+               end
+       end
+       return false
+end
+
+--- Check or update ratelimit
+local function rate_test_set(task, func)
+       
+       -- Returns local part component of address
+       local function get_local_part(str)
+               pos,_ = string.find(str, '@', 0, true)
+               if not pos then
+                       return str
+               else
+                       return string.sub(str, 1, pos - 1)
+               end
+       end
+       
+       -- Get initial task data
+       local ip = task:get_from_ip()
+       if ip and whitelisted_ip then
+               if whitelisted_ip:get_key(ip) then
+                       -- Do not check whitelisted ip
+                       return
+               end
+       end
+       -- Parse all rcpts 
+       local rcpts = task:get_recipients()
+       local rcpts_user = {}
+       if not rcpts then
+               rcpts = task:get_recipients_headers()
+       end
+       if rcpts then
+               for i,r in ipairs(rcpts) do
+                       rcpts_user[i] = get_local_part(r['addr'])
+               end
+       end
+       -- Parse from
+       local from = task:get_from()
+       local from_user = ''
+       if not from then
+               from = task:get_from_headers()
+       end
+       if from then
+               from_user = get_local_part(from[1]['addr'])
+       end
+       
+       if not from_user or not rcpts_user[1] then
+               -- Nothing to check
+               return
+       end
+       
+       local is_bounce = check_bounce(from_user)
+       
+       for _,r in ipairs(rcpts) do
+               if is_bounce then
+                       -- Bounce specific limit
+                       func(task, settings['bounce_to'], make_rate_key ('<>', r['addr'], nil))
+                       if ip then
+                               func(task, settings['bounce_to_ip'], make_rate_key ('<>', r['addr'], ip))
+                       end
+               end
+               func(task, settings['to'], make_rate_key (nil, r['addr'], nil))
+               if ip then
+                       func(task, settings['to_ip'], make_rate_key (nil, r['addr'], ip))
+                       func(task, settings['to_ip_from'], make_rate_key (from[1]['addr'], r['addr'], ip))
+               end
+       end
+end
+
+--- Check limit
+local function rate_test(task)
+       rate_test_set(task, check_specific_limit)
+end
+--- Update limit
+local function rate_set(task)
+       rate_test_set(task, set_specific_limit)
+end
+
+
+--- Utility function for split string to table
+local function split(str, delim, maxNb)
+       -- Eliminate bad cases...
+       if string.find(str, delim) == nil then
+               return { str }
+       end
+       if maxNb == nil or maxNb < 1 then
+               maxNb = 0    -- No limit
+       end
+       local result = {}
+       local pat = "(.-)" .. delim .. "()"
+       local nb = 0
+       local lastPos
+       for part, pos in string.gfind(str, pat) do
+               nb = nb + 1
+               result[nb] = part
+               lastPos = pos
+               if nb == maxNb then break end
+       end
+       -- Handle the last field
+       if nb ~= maxNb then
+               result[nb + 1] = string.sub(str, lastPos)
+       end
+       return result
+end
+
+--- Parse a single limit description
+local function parse_limit(str)
+       local params = split(str, ':', 0)
+       
+       local function set_limit(limit, burst, rate)
+               limit[1] = tonumber(burst)
+               limit[2] = tonumber(rate)
+       end
+       
+       if table.maxn(params) ~= 3 then
+               rspamd_logger.err('invalid limit definition: ' .. str)
+               return
+       end
+       
+       if params[1] == 'to' then
+               set_limit(settings['to'], params[2], params[3])
+       elseif params[1] == 'to_ip' then
+               set_limit(settings['to_ip'], params[2], params[3])
+       elseif params[1] == 'to_ip_from' then
+               set_limit(settings['to_ip_from'], params[2], params[3])
+       elseif params[1] == 'bounce_to' then
+               set_limit(settings['bounce_to'], params[2], params[3])
+       elseif params[1] == 'bounce_to_ip' then
+               set_limit(settings['bounce_to_ip'], params[2], params[3])
+       else
+               rspamd_logger.err('invalid limit type: ' .. params[1])
+       end
+end
+
+-- Registration
+if rspamd_config:get_api_version() >= 9 then
+       rspamd_config:register_module_option('ratelimit', 'servers', 'string')
+       rspamd_config:register_module_option('ratelimit', 'bounce_senders', 'string')
+       rspamd_config:register_module_option('ratelimit', 'whitelisted_rcpts', 'string')
+       rspamd_config:register_module_option('ratelimit', 'whitelisted_ip', 'map')
+       rspamd_config:register_module_option('ratelimit', 'limit', 'string')
+end
+
+local function parse_whitelisted_rcpts(str)
+       
+end
+
+local opts =  rspamd_config:get_all_opt('ratelimit')
+if opts then
+       local rates = opts['limit']
+       if rates and type(rates) == 'table' then
+               for _,r in ipairs(rates) do
+                       parse_limit(r)
+               end
+       elseif rates and type(rates) == 'string' then
+               parse_limit(rates)
+       end
+       
+       if opts['whitelisted_rcpts'] and type(opts['whitelisted_rcpts']) == 'string' then
+               whitelisted_rcpts = split(opts['whitelisted_rcpts'], ',')
+       end
+       
+       if opts['whitelisted_ip'] then
+               whitelisted_ip = rspamd_config:add_hash_map (opts['whitelisted_ip'])
+       end
+       
+       if not opts['servers'] then
+               rspamd_logger.err('no servers are specified')
+       else
+               upstreams = upstream_list.create(opts['servers'], default_port)
+               if not upstreams then
+                       rspamd_logger.err('no servers are specified')
+               else
+                       rspamd_config:register_pre_filter(rate_test)
+                       rspamd_config:register_post_filter(rate_set)
+               end
+       end
+end