diff options
-rw-r--r-- | lualib/redis_scripts/ratelimit_cleanup_pending.lua | 29 | ||||
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 53 |
2 files changed, 79 insertions, 3 deletions
diff --git a/lualib/redis_scripts/ratelimit_cleanup_pending.lua b/lualib/redis_scripts/ratelimit_cleanup_pending.lua new file mode 100644 index 000000000..f51599b09 --- /dev/null +++ b/lualib/redis_scripts/ratelimit_cleanup_pending.lua @@ -0,0 +1,29 @@ +-- This script cleans up the pending requests in Redis. + +-- KEYS: Input parameters +-- KEYS[1] - prefix: The Redis key prefix used to store the bucket information. +-- KEYS[2] - now: The current time in milliseconds. +-- KEYS[3] - expire: The expiration time for the Redis key storing the bucket information, in seconds. +-- KEYS[4] - number_of_recipients: The number of requests to be allowed (or the increase rate). + +-- 1. Retrieve the last hit time and initialize variables +local prefix = KEYS[1] +local last = redis.call('HGET', prefix, 'l') +local nrcpt = tonumber(KEYS[4]) +if not last then + -- No bucket, no cleanup + return 0 +end + + +-- 2. Update the pending values based on the number of recipients (requests) +local pending = redis.call('HGET', prefix, 'p') +pending = tonumber(pending or '0') +if pending < nrcpt then pending = 0 else pending = pending - nrcpt end + +-- 3. Set the updated values back to Redis and update the expiration time for the bucket +redis.call('HMSET', prefix, tostring(pending), 'l', KEYS[2]) +redis.call('EXPIRE', prefix, KEYS[3]) + +-- 4. Return the updated pending value +return pending
\ No newline at end of file diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index 520efc99e..e2e4e6887 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -62,6 +62,9 @@ local bucket_check_id local bucket_update_script = "ratelimit_update.lua" local bucket_update_id +local bucket_cleanup_script = "ratelimit_cleanup_pending.lua" +local bucket_cleanup_id + -- message_func(task, limit_type, prefix, bucket, limit_key) local message_func = function(_, limit_type, _, _, _) return string.format('Ratelimit "%s" exceeded', limit_type) @@ -71,6 +74,7 @@ end local function load_scripts(_, _) bucket_check_id = lua_redis.load_redis_script_from_file(bucket_check_script, redis_params) bucket_update_id = lua_redis.load_redis_script_from_file(bucket_update_script, redis_params) + bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params) end local limit_parser @@ -464,6 +468,7 @@ local function ratelimit_cb(task) prefix, bucket.burst, bucket.rate, data[2], data[3], data[4], data[5]) + task:cache_set('ratelimit_bucket_touched', true) if data[1] == 1 then -- set symbol only and do NOT soft reject if bucket.symbol then @@ -530,15 +535,57 @@ local function ratelimit_cb(task) end end + +-- This function is used to clean up pending bucket when +-- the task is somehow being skipped (e.g. greylisting/ratelimit/whatever) +-- but the ratelimit buckets for this task are touched (e.g. pending has been increased) +-- See https://github.com/rspamd/rspamd/issues/4467 for more context +local function maybe_cleanup_pending(task) + if task:cache_get('ratelimit_bucket_touched') then + local prefixes = task:cache_get('ratelimit_prefixes') + if prefixes then + for k, v in pairs(prefixes) or E do + local bucket = v.bucket + local function cleanup_cb(err, data) + if err then + rspamd_logger.errx('cannot cleanup limit %s: %s %s', k, err, data) + else + lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data) + end + end + local _,nrcpt = task:has_recipients('smtp') + if not nrcpt or nrcpt <= 0 then + nrcpt = 1 + end + local bincr = nrcpt + if bucket.skip_recipients then bincr = 1 end + local now = task:get_timeval(true) + now = lua_util.round(now * 1000.0) -- Get milliseconds + lua_redis.exec_redis_script(bucket_cleanup_id, + {key = v.hash, task = task, is_write = true}, + cleanup_cb, + {v.hash, tostring(now), tostring(settings.expire), tostring(bincr)}) + end + end + end +end + local function ratelimit_update_cb(task) - if task:has_flag('skip') then return end - if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end + if task:has_flag('skip') then + maybe_cleanup_pending(task) + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + maybe_cleanup_pending(task) + end + local prefixes = task:cache_get('ratelimit_prefixes') if prefixes then if task:has_pre_result() then -- Already rate limited/greylisted, do nothing lua_util.debugm(N, task, 'pre-action has been set, do not update') + maybe_cleanup_pending(task) return end @@ -563,7 +610,7 @@ local function ratelimit_update_cb(task) data[1], data[2], data[3]) end end - local now = rspamd_util.get_time() + local now = task:get_timeval(true) now = lua_util.round(now * 1000.0) -- Get milliseconds local mult_burst = 1.0 local mult_rate = 1.0 |