From d7363dfa0f54a0d05739936b607c6b1c8b750026 Mon Sep 17 00:00:00 2001 From: Andrew Lewis Date: Tue, 29 Aug 2017 13:07:13 +0200 Subject: [PATCH] [Feature] Ratelimit: support fetching limits from Redis --- src/plugins/lua/ratelimit.lua | 221 +++++++++++++++++++++------------- 1 file changed, 140 insertions(+), 81 deletions(-) diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index 2516d1844..b05ed0114 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -375,7 +375,86 @@ local function dynamic_rate_key(task, rtype) end end -local function get_buckets(task) +local function process_buckets(task, buckets) + if not buckets then return end + local function rl_redis_cb(err, data) + if err then + rspamd_logger.infox(task, 'got error while setting limit: %1', err) + end + if not data then return end + if data[1] == 1 then + rspamd_logger.infox(task, + 'ratelimit "%s" exceeded', + data[2]) + task:set_pre_result('soft reject', + message_func(task, data[2])) + end + end + local function rl_symbol_redis_cb(err, data) + if err then + rspamd_logger.infox(task, 'got error while setting limit: %1', err) + end + if not data then return end + for i, b in ipairs(data) do + task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2])) + end + end + local redis_cb = rl_redis_cb + if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end + local args = {redis_script_sha, #buckets} + for _, bucket in ipairs(buckets) do + table.insert(args, bucket[2]) + end + for _, bucket in ipairs(buckets) do + if use_ip_score then + local asn_score,total_asn, + country_score,total_country, + ipnet_score,total_ipnet, + ip_score, total_ip = task:get_mempool():get_variable('ip_score', + 'double,double,double,double,double,double,double,double') + local key_keywords = rspamd_str_split(bucket[2], '_') + local has_asn, has_ip = false, false + for _, v in ipairs(key_keywords) do + if v == "asn" then has_asn = true end + if v == "ip" then has_ip = true end + if has_ip and has_asn then break end + end + if has_asn and not has_ip then + bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) + elseif has_ip then + if total_ip and total_ip > ip_score_lower_bound then + bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) + elseif total_ipnet and total_ipnet > ip_score_lower_bound then + bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2]) + elseif total_asn and total_asn > ip_score_lower_bound then + bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) + elseif total_country and total_country > ip_score_lower_bound then + bucket[1][2] = resize_element(country_score, total_country, bucket[1][2]) + else + bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) + end + end + end + table.insert(args, bucket[1][1]) + table.insert(args, bucket[1][2]) + end + table.insert(args, rspamd_util.get_time()) + table.insert(args, task:get_queue_id() or task:get_uid()) + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + true, -- is write + redis_cb, --callback + 'evalsha', -- command + args -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'got error connecting to redis') + end +end + +local function ratelimit_cb(task) + if rspamd_lua_utils.is_rspamc_or_controller(task) then return end local args = {} -- Get initial task data local ip = task:get_from_ip() @@ -412,6 +491,38 @@ local function get_buckets(task) end end + local redis_keys = {} + local redis_keys_rev = {} + local function collect_redis_keys() + local function collect_cb(err, data) + if err then + rspamd_logger.errx(task, 'redis error: %1', err) + else + for i, d in ipairs(data) do + if type(d) == 'string' then + local plim, size = parse_string_limit(d) + if plim then + table.insert(args, {{plim, size}, redis_keys_rev[i]}) + end + end + end + return process_buckets(task, args) + end + end + local requested_keys = rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + true, -- is write + collect_cb, --callback + 'MGET', -- command + redis_keys -- arguments + ) + if not requested_keys then + rspamd_logger.errx(task, 'got error connecting to redis') + return process_buckets(task, args) + end + end + local rate_key for k in pairs(settings) do rate_key = dynamic_rate_key(task, k) @@ -426,6 +537,14 @@ local function get_buckets(task) local plim, size = parse_string_limit(r) if plim then table.insert(args, {{plim, size}, rk}) + else + local rkey = string.match(settings[k], 'redis:(.*)') + if rkey then + table.insert(redis_keys, rkey) + redis_keys_rev[#redis_keys] = rk + else + rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) + end end end end @@ -439,97 +558,37 @@ local function get_buckets(task) local plim, size = parse_string_limit(r) if plim then table.insert(args, {{plim, size}, rate_key}) + else + local rkey = string.match(settings[k], 'redis:(.*)') + if rkey then + table.insert(redis_keys, rkey) + redis_keys_rev[#redis_keys] = rate_key + else + rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) + end end end elseif type(settings[k]) == 'table' then for _, rl in ipairs(settings[k]) do table.insert(args, {{rl[1], rl[2]}, rate_key}) end + elseif type(settings[k]) == 'string' then + local rkey = string.match(settings[k], 'redis:(.*)') + if rkey then + table.insert(redis_keys, rkey) + redis_keys_rev[#redis_keys] = rate_key + else + rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) + end end end end end - return args -end - -local function ratelimit_cb(task) - if rspamd_lua_utils.is_rspamc_or_controller(task) then return end - local function rl_redis_cb(err, data) - if err then - rspamd_logger.infox(task, 'got error while setting limit: %1', err) - end - if not data then return end - if data[1] == 1 then - rspamd_logger.infox(task, - 'ratelimit "%s" exceeded', - data[2]) - task:set_pre_result('soft reject', - message_func(task, data[2])) - end - end - local function rl_symbol_redis_cb(err, data) - if err then - rspamd_logger.infox(task, 'got error while setting limit: %1', err) - end - if not data then return end - for i, b in ipairs(data) do - task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2])) - end - end - local redis_cb = rl_redis_cb - if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end - local buckets = get_buckets(task) - if not buckets then return end - local args = {redis_script_sha, #buckets} - for _, bucket in ipairs(buckets) do - table.insert(args, bucket[2]) - end - for _, bucket in ipairs(buckets) do - if use_ip_score then - local asn_score,total_asn, - country_score,total_country, - ipnet_score,total_ipnet, - ip_score, total_ip = task:get_mempool():get_variable('ip_score', - 'double,double,double,double,double,double,double,double') - local key_keywords = rspamd_str_split(bucket[2], '_') - local has_asn, has_ip = false, false - for _, v in ipairs(key_keywords) do - if v == "asn" then has_asn = true end - if v == "ip" then has_ip = true end - if has_ip and has_asn then break end - end - if has_asn and not has_ip then - bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) - elseif has_ip then - if total_ip and total_ip > ip_score_lower_bound then - bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) - elseif total_ipnet and total_ipnet > ip_score_lower_bound then - bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2]) - elseif total_asn and total_asn > ip_score_lower_bound then - bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) - elseif total_country and total_country > ip_score_lower_bound then - bucket[1][2] = resize_element(country_score, total_country, bucket[1][2]) - else - bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) - end - end - end - table.insert(args, bucket[1][1]) - table.insert(args, bucket[1][2]) - end - table.insert(args, rspamd_util.get_time()) - table.insert(args, task:get_queue_id() or task:get_uid()) - local ret = rspamd_redis_make_request(task, - redis_params, -- connect params - nil, -- hash key - true, -- is write - redis_cb, --callback - 'evalsha', -- command - args -- arguments - ) - if not ret then - rspamd_logger.errx(task, 'got error connecting to redis') + if redis_keys[1] then + return collect_redis_keys() + else + return process_buckets(task, args) end end -- 2.39.5