diff options
-rw-r--r-- | src/lua/lua_ip.c | 23 | ||||
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 20 |
2 files changed, 34 insertions, 9 deletions
diff --git a/src/lua/lua_ip.c b/src/lua/lua_ip.c index ea4862e93..b541727b2 100644 --- a/src/lua/lua_ip.c +++ b/src/lua/lua_ip.c @@ -163,6 +163,13 @@ LUA_FUNCTION_DEF (ip, equal); */ LUA_FUNCTION_DEF (ip, copy); +/** + * @method ip:get_port() + * Returns associated port for this IP address + * @return {number} port number or nil + */ +LUA_FUNCTION_DEF (ip, get_port); + static const struct luaL_reg iplib_m[] = { LUA_INTERFACE_DEF (ip, to_string), LUA_INTERFACE_DEF (ip, to_table), @@ -170,6 +177,7 @@ static const struct luaL_reg iplib_m[] = { LUA_INTERFACE_DEF (ip, str_octets), LUA_INTERFACE_DEF (ip, inversed_str_octets), LUA_INTERFACE_DEF (ip, get_version), + LUA_INTERFACE_DEF (ip, get_port), LUA_INTERFACE_DEF (ip, is_valid), LUA_INTERFACE_DEF (ip, apply_mask), LUA_INTERFACE_DEF (ip, copy), @@ -350,6 +358,21 @@ lua_ip_to_string (lua_State *L) } static gint +lua_ip_get_port (lua_State *L) +{ + struct rspamd_lua_ip *ip = lua_check_ip (L, 1); + + if (ip != NULL && ip->is_valid) { + lua_pushnumber (L, rspamd_inet_address_get_port (&ip->addr)); + } + else { + lua_pushnil (L); + } + + return 1; +} + +static gint lua_ip_from_string (lua_State *L) { struct rspamd_lua_ip *ip; diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index 1d9f15a6d..e79f68420 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -47,7 +47,8 @@ end --- Check specific limit inside redis local function check_specific_limit (task, limit, key) - local upstream = upstreams:get_upstream_by_hash(key, task:get_date()) + local upstream = upstreams:get_upstream_by_hash(key) + local addr = upstream:get_addr() --- Called when value was set on server local function rate_set_key_cb(task, err, data) if err then @@ -67,13 +68,13 @@ local function check_specific_limit (task, limit, key) bucket = bucket - limit[2] * (ntime - atime); if bucket > 0 then local lstr = string.format('%.7f:%.7f', ntime, bucket) - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 'SET %b %b', key, lstr) if bucket > limit[1] then task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key) end else - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 'DEL %b', key) end end @@ -83,13 +84,14 @@ local function check_specific_limit (task, limit, key) end end if upstream then - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_get_cb, 'GET %b', key) + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key) end end --- Set specific limit inside redis local function set_specific_limit (task, limit, key) - local upstream = upstreams:get_upstream_by_hash(key, task:get_date()) + local upstream = upstreams:get_upstream_by_hash(key) + local addr = upstream:get_addr() --- Called when value was set on server local function rate_set_key_cb(task, err, data) if err then @@ -105,8 +107,8 @@ local function set_specific_limit (task, limit, key) --- Add new entry local tv = task:get_timeval() local atime = tv['tv_usec'] / 1000000. + tv['tv_sec'] - local lstr = string.format('%.7f:1', atime) - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, + local lstr = string.format('%.7f:1', atime) + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 'SET %b %b', key, lstr) elseif data then local atime, bucket = parse_limit_data(data) @@ -115,7 +117,7 @@ local function set_specific_limit (task, limit, key) -- Leak messages bucket = bucket - limit[2] * (ntime - atime) + 1; local lstr = string.format('%.7f:%.7f', ntime, bucket) - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 'SET %b %b', key, lstr) elseif err then rspamd_logger.info('got error while setting limit: ' .. err) @@ -123,7 +125,7 @@ local function set_specific_limit (task, limit, key) end end if upstream then - rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_cb, 'GET %b', key) + rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key) end end |