]> source.dussan.org Git - rspamd.git/commitdiff
Fix upstreams in ratelimit.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 30 Oct 2014 22:29:28 +0000 (22:29 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 30 Oct 2014 22:29:28 +0000 (22:29 +0000)
src/lua/lua_ip.c
src/plugins/lua/ratelimit.lua

index ea4862e93e333c2ac0aa5b52d7a2628833fad0c7..b541727b2b672dc8d6af16a13255f1672ffc0e28 100644 (file)
@@ -163,6 +163,13 @@ LUA_FUNCTION_DEF (ip, equal);
  */
 LUA_FUNCTION_DEF (ip, copy);
 
+/**
+ * @method ip:get_port()
+ * Returns associated port for this IP address
+ * @return {number} port number or nil
+ */
+LUA_FUNCTION_DEF (ip, get_port);
+
 static const struct luaL_reg iplib_m[] = {
        LUA_INTERFACE_DEF (ip, to_string),
        LUA_INTERFACE_DEF (ip, to_table),
@@ -170,6 +177,7 @@ static const struct luaL_reg iplib_m[] = {
        LUA_INTERFACE_DEF (ip, str_octets),
        LUA_INTERFACE_DEF (ip, inversed_str_octets),
        LUA_INTERFACE_DEF (ip, get_version),
+       LUA_INTERFACE_DEF (ip, get_port),
        LUA_INTERFACE_DEF (ip, is_valid),
        LUA_INTERFACE_DEF (ip, apply_mask),
        LUA_INTERFACE_DEF (ip, copy),
@@ -349,6 +357,21 @@ lua_ip_to_string (lua_State *L)
        return 1;
 }
 
+static gint
+lua_ip_get_port (lua_State *L)
+{
+       struct rspamd_lua_ip *ip = lua_check_ip (L, 1);
+
+       if (ip != NULL && ip->is_valid) {
+               lua_pushnumber (L, rspamd_inet_address_get_port (&ip->addr));
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+}
+
 static gint
 lua_ip_from_string (lua_State *L)
 {
index 1d9f15a6d2e0ee94264a4db1bf82055380a40efd..e79f6842035b7c3150d7f84a558e42ec0201051f 100644 (file)
@@ -47,7 +47,8 @@ end
 --- Check specific limit inside redis
 local function check_specific_limit (task, limit, key)
 
-       local upstream = upstreams:get_upstream_by_hash(key, task:get_date())
+       local upstream = upstreams:get_upstream_by_hash(key)
+       local addr = upstream:get_addr()
        --- Called when value was set on server
        local function rate_set_key_cb(task, err, data)
                if err then
@@ -67,13 +68,13 @@ local function check_specific_limit (task, limit, key)
                        bucket = bucket - limit[2] * (ntime - atime);
                        if bucket > 0 then
                                local lstr = string.format('%.7f:%.7f', ntime, bucket)
-                               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                               rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 
                                                        'SET %b %b', key, lstr)
                                if bucket > limit[1] then
                                        task:set_pre_result('soft reject', 'Ratelimit exceeded: ' .. key)
                                end
                        else
-                               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                               rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 
                                                        'DEL %b', key)
                        end
                end
@@ -83,13 +84,14 @@ local function check_specific_limit (task, limit, key)
                end     
        end
        if upstream then
-               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_get_cb, 'GET %b', key)
+               rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_get_cb, 'GET %b', key)
        end
 end
 
 --- Set specific limit inside redis
 local function set_specific_limit (task, limit, key)
-       local upstream = upstreams:get_upstream_by_hash(key,  task:get_date())
+       local upstream = upstreams:get_upstream_by_hash(key)
+       local addr = upstream:get_addr()
        --- Called when value was set on server
        local function rate_set_key_cb(task, err, data)
                if err then
@@ -105,8 +107,8 @@ local function set_specific_limit (task, limit, key)
                        --- Add new entry
                        local tv = task:get_timeval()
                        local atime = tv['tv_usec'] / 1000000. + tv['tv_sec']
-                       local lstr = string.format('%.7f:1', atime)
-                       rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                       local lstr = string.format('%.7f:1', atime)     
+                       rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 
                                                        'SET %b %b', key, lstr)
                elseif data then
                        local atime, bucket = parse_limit_data(data)
@@ -115,7 +117,7 @@ local function set_specific_limit (task, limit, key)
                        -- Leak messages
                        bucket = bucket - limit[2] * (ntime - atime) + 1;
                        local lstr = string.format('%.7f:%.7f', ntime, bucket)
-                       rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_key_cb, 
+                       rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_key_cb, 
                                                        'SET %b %b', key, lstr)
                elseif err then
                  rspamd_logger.info('got error while setting limit: ' .. err)
@@ -123,7 +125,7 @@ local function set_specific_limit (task, limit, key)
                end
        end
        if upstream then
-               rspamd_redis.make_request(task, upstream:get_ip_string(), upstream:get_port(), rate_set_cb, 'GET %b', key)
+               rspamd_redis.make_request(task, addr:to_string(), addr:get_port(), rate_set_cb, 'GET %b', key)
        end
 end