aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua
diff options
context:
space:
mode:
authorAndrew Lewis <nerf@judo.za.org>2017-08-29 13:07:13 +0200
committerAndrew Lewis <nerf@judo.za.org>2017-08-29 14:29:15 +0200
commitd7363dfa0f54a0d05739936b607c6b1c8b750026 (patch)
treefffa5270d9e4784a7b7cb09b73f63a88aa412151 /src/plugins/lua
parent6ae3b5a35a778fe532ce1964d90f85c36f99ffc9 (diff)
downloadrspamd-d7363dfa0f54a0d05739936b607c6b1c8b750026.tar.gz
rspamd-d7363dfa0f54a0d05739936b607c6b1c8b750026.zip
[Feature] Ratelimit: support fetching limits from Redis
Diffstat (limited to 'src/plugins/lua')
-rw-r--r--src/plugins/lua/ratelimit.lua221
1 files changed, 140 insertions, 81 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index 2516d1844..b05ed0114 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -375,7 +375,86 @@ local function dynamic_rate_key(task, rtype)
end
end
-local function get_buckets(task)
+local function process_buckets(task, buckets)
+ if not buckets then return end
+ local function rl_redis_cb(err, data)
+ if err then
+ rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+ end
+ if not data then return end
+ if data[1] == 1 then
+ rspamd_logger.infox(task,
+ 'ratelimit "%s" exceeded',
+ data[2])
+ task:set_pre_result('soft reject',
+ message_func(task, data[2]))
+ end
+ end
+ local function rl_symbol_redis_cb(err, data)
+ if err then
+ rspamd_logger.infox(task, 'got error while setting limit: %1', err)
+ end
+ if not data then return end
+ for i, b in ipairs(data) do
+ task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2]))
+ end
+ end
+ local redis_cb = rl_redis_cb
+ if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end
+ local args = {redis_script_sha, #buckets}
+ for _, bucket in ipairs(buckets) do
+ table.insert(args, bucket[2])
+ end
+ for _, bucket in ipairs(buckets) do
+ if use_ip_score then
+ local asn_score,total_asn,
+ country_score,total_country,
+ ipnet_score,total_ipnet,
+ ip_score, total_ip = task:get_mempool():get_variable('ip_score',
+ 'double,double,double,double,double,double,double,double')
+ local key_keywords = rspamd_str_split(bucket[2], '_')
+ local has_asn, has_ip = false, false
+ for _, v in ipairs(key_keywords) do
+ if v == "asn" then has_asn = true end
+ if v == "ip" then has_ip = true end
+ if has_ip and has_asn then break end
+ end
+ if has_asn and not has_ip then
+ bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
+ elseif has_ip then
+ if total_ip and total_ip > ip_score_lower_bound then
+ bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
+ elseif total_ipnet and total_ipnet > ip_score_lower_bound then
+ bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2])
+ elseif total_asn and total_asn > ip_score_lower_bound then
+ bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
+ elseif total_country and total_country > ip_score_lower_bound then
+ bucket[1][2] = resize_element(country_score, total_country, bucket[1][2])
+ else
+ bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
+ end
+ end
+ end
+ table.insert(args, bucket[1][1])
+ table.insert(args, bucket[1][2])
+ end
+ table.insert(args, rspamd_util.get_time())
+ table.insert(args, task:get_queue_id() or task:get_uid())
+ local ret = rspamd_redis_make_request(task,
+ redis_params, -- connect params
+ nil, -- hash key
+ true, -- is write
+ redis_cb, --callback
+ 'evalsha', -- command
+ args -- arguments
+ )
+ if not ret then
+ rspamd_logger.errx(task, 'got error connecting to redis')
+ end
+end
+
+local function ratelimit_cb(task)
+ if rspamd_lua_utils.is_rspamc_or_controller(task) then return end
local args = {}
-- Get initial task data
local ip = task:get_from_ip()
@@ -412,6 +491,38 @@ local function get_buckets(task)
end
end
+ local redis_keys = {}
+ local redis_keys_rev = {}
+ local function collect_redis_keys()
+ local function collect_cb(err, data)
+ if err then
+ rspamd_logger.errx(task, 'redis error: %1', err)
+ else
+ for i, d in ipairs(data) do
+ if type(d) == 'string' then
+ local plim, size = parse_string_limit(d)
+ if plim then
+ table.insert(args, {{plim, size}, redis_keys_rev[i]})
+ end
+ end
+ end
+ return process_buckets(task, args)
+ end
+ end
+ local requested_keys = rspamd_redis_make_request(task,
+ redis_params, -- connect params
+ nil, -- hash key
+ true, -- is write
+ collect_cb, --callback
+ 'MGET', -- command
+ redis_keys -- arguments
+ )
+ if not requested_keys then
+ rspamd_logger.errx(task, 'got error connecting to redis')
+ return process_buckets(task, args)
+ end
+ end
+
local rate_key
for k in pairs(settings) do
rate_key = dynamic_rate_key(task, k)
@@ -426,6 +537,14 @@ local function get_buckets(task)
local plim, size = parse_string_limit(r)
if plim then
table.insert(args, {{plim, size}, rk})
+ else
+ local rkey = string.match(settings[k], 'redis:(.*)')
+ if rkey then
+ table.insert(redis_keys, rkey)
+ redis_keys_rev[#redis_keys] = rk
+ else
+ rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+ end
end
end
end
@@ -439,97 +558,37 @@ local function get_buckets(task)
local plim, size = parse_string_limit(r)
if plim then
table.insert(args, {{plim, size}, rate_key})
+ else
+ local rkey = string.match(settings[k], 'redis:(.*)')
+ if rkey then
+ table.insert(redis_keys, rkey)
+ redis_keys_rev[#redis_keys] = rate_key
+ else
+ rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+ end
end
end
elseif type(settings[k]) == 'table' then
for _, rl in ipairs(settings[k]) do
table.insert(args, {{rl[1], rl[2]}, rate_key})
end
+ elseif type(settings[k]) == 'string' then
+ local rkey = string.match(settings[k], 'redis:(.*)')
+ if rkey then
+ table.insert(redis_keys, rkey)
+ redis_keys_rev[#redis_keys] = rate_key
+ else
+ rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k])
+ end
end
end
end
end
- return args
-end
-
-local function ratelimit_cb(task)
- if rspamd_lua_utils.is_rspamc_or_controller(task) then return end
- local function rl_redis_cb(err, data)
- if err then
- rspamd_logger.infox(task, 'got error while setting limit: %1', err)
- end
- if not data then return end
- if data[1] == 1 then
- rspamd_logger.infox(task,
- 'ratelimit "%s" exceeded',
- data[2])
- task:set_pre_result('soft reject',
- message_func(task, data[2]))
- end
- end
- local function rl_symbol_redis_cb(err, data)
- if err then
- rspamd_logger.infox(task, 'got error while setting limit: %1', err)
- end
- if not data then return end
- for i, b in ipairs(data) do
- task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2]))
- end
- end
- local redis_cb = rl_redis_cb
- if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end
- local buckets = get_buckets(task)
- if not buckets then return end
- local args = {redis_script_sha, #buckets}
- for _, bucket in ipairs(buckets) do
- table.insert(args, bucket[2])
- end
- for _, bucket in ipairs(buckets) do
- if use_ip_score then
- local asn_score,total_asn,
- country_score,total_country,
- ipnet_score,total_ipnet,
- ip_score, total_ip = task:get_mempool():get_variable('ip_score',
- 'double,double,double,double,double,double,double,double')
- local key_keywords = rspamd_str_split(bucket[2], '_')
- local has_asn, has_ip = false, false
- for _, v in ipairs(key_keywords) do
- if v == "asn" then has_asn = true end
- if v == "ip" then has_ip = true end
- if has_ip and has_asn then break end
- end
- if has_asn and not has_ip then
- bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
- elseif has_ip then
- if total_ip and total_ip > ip_score_lower_bound then
- bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
- elseif total_ipnet and total_ipnet > ip_score_lower_bound then
- bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2])
- elseif total_asn and total_asn > ip_score_lower_bound then
- bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2])
- elseif total_country and total_country > ip_score_lower_bound then
- bucket[1][2] = resize_element(country_score, total_country, bucket[1][2])
- else
- bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2])
- end
- end
- end
- table.insert(args, bucket[1][1])
- table.insert(args, bucket[1][2])
- end
- table.insert(args, rspamd_util.get_time())
- table.insert(args, task:get_queue_id() or task:get_uid())
- local ret = rspamd_redis_make_request(task,
- redis_params, -- connect params
- nil, -- hash key
- true, -- is write
- redis_cb, --callback
- 'evalsha', -- command
- args -- arguments
- )
- if not ret then
- rspamd_logger.errx(task, 'got error connecting to redis')
+ if redis_keys[1] then
+ return collect_redis_keys()
+ else
+ return process_buckets(task, args)
end
end