aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2014-10-30 22:29:28 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2014-10-30 22:29:28 +0000
commitff8a6010e8798bf1242246705b88176538e626ff (patch)
tree057826d8877e77798ade77326298a323dd6403b7 /src
parent08d2fa41695932bfaf199362a0dac27002ffa922 (diff)
downloadrspamd-ff8a6010e8798bf1242246705b88176538e626ff.tar.gz
rspamd-ff8a6010e8798bf1242246705b88176538e626ff.zip
Fix upstreams in ratelimit.
Diffstat (limited to 'src')
-rw-r--r--src/lua/lua_ip.c23
-rw-r--r--src/plugins/lua/ratelimit.lua20
2 files changed, 34 insertions, 9 deletions
diff --git a/src/lua/lua_ip.c b/src/lua/lua_ip.c
index ea4862e93..b541727b2 100644
--- a/src/lua/lua_ip.c
+++ b/src/lua/lua_ip.c
@@ -163,6 +163,13 @@ LUA_FUNCTION_DEF (ip, equal);
*/
LUA_FUNCTION_DEF (ip, copy);
+/**
+ * @method ip:get_port()
+ * Returns associated port for this IP address
+ * @return {number} port number or nil
+ */
+LUA_FUNCTION_DEF (ip, get_port);
+
static const struct luaL_reg iplib_m[] = {
LUA_INTERFACE_DEF (ip, to_string),
LUA_INTERFACE_DEF (ip, to_table),
@@ -170,6 +177,7 @@ static const struct luaL_reg iplib_m[] = {
LUA_INTERFACE_DEF (ip, str_octets),
LUA_INTERFACE_DEF (ip, inversed_str_octets),
LUA_INTERFACE_DEF (ip, get_version),
+ LUA_INTERFACE_DEF (ip, get_port),
LUA_INTERFACE_DEF (ip, is_valid),
LUA_INTERFACE_DEF (ip, apply_mask),
LUA_INTERFACE_DEF (ip, copy),
@@ -350,6 +358,21 @@ lua_ip_to_string (lua_State *L)
}
static gint
+lua_ip_get_port (lua_State *L)
+{
+ struct rspamd_lua_ip *ip = lua_check_ip (L, 1);
+
+ if (ip != NULL && ip->is_valid) {
+ lua_pushnumber (L, rspamd_inet_address_get_port (&ip->addr));
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+}
+
+static gint
lua_ip_from_string (lua_State *L)
{
struct rspamd_lua_ip *ip;
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index 1d9f15a6d..e79f68420 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -47,7 +47,8 @@ end
--- Check specific limit inside redis
local function check_specific_limit (task, limit, key)
- local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+ local upstream = upstreams:get_upstream_by_hash(key)
+ local addr = upstream:get_addr()
--- Called when value was set on server
local function rate_set_key_cb(task, err, data)
if err then
@@ -67,13 +68,13 @@ local function check_specific_limit (task, limit, key)
bucket = bucket - limit[2] * (ntime - atime);
if bucket > 0 then
local lstr = string.format('%.7f:%.7f', ntime, bucket)
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
'SET %b %b', key, lstr)
if bucket > limit[1] then
task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key)
end
else
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
'DEL %b', key)
end
end
@@ -83,13 +84,14 @@ local function check_specific_limit (task, limit, key)
end
end
if upstream then
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_get_cb, 'GET %b', key)
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key)
end
end
--- Set specific limit inside redis
local function set_specific_limit (task, limit, key)
- local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+ local upstream = upstreams:get_upstream_by_hash(key)
+ local addr = upstream:get_addr()
--- Called when value was set on server
local function rate_set_key_cb(task, err, data)
if err then
@@ -105,8 +107,8 @@ local function set_specific_limit (task, limit, key)
--- Add new entry
local tv = task:get_timeval()
local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
- local lstr = string.format('%.7f:1', atime)
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ local lstr = string.format('%.7f:1', atime)
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
'SET %b %b', key, lstr)
elseif data then
local atime, bucket = parse_limit_data(data)
@@ -115,7 +117,7 @@ local function set_specific_limit (task, limit, key)
-- Leak messages
bucket = bucket - limit[2] * (ntime - atime) + 1;
local lstr = string.format('%.7f:%.7f', ntime, bucket)
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb,
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb,
'SET %b %b', key, lstr)
elseif err then
rspamd_logger.info('got error while setting limit: ' .. err)
@@ -123,7 +125,7 @@ local function set_specific_limit (task, limit, key)
end
end
if upstream then
- rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_cb, 'GET %b', key)
+ rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key)
end
end