diff options
Diffstat (limited to 'src/plugins/lua/ratelimit.lua')
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 143 |
1 files changed, 5 insertions, 138 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index f3331e850..168d8d63a 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -29,8 +29,7 @@ local lua_util = require "lua_util" local lua_verdict = require "lua_verdict" local rspamd_hash = require "rspamd_cryptobox_hash" local lua_selectors = require "lua_selectors" -local ts = require("tableshape").types - +local ratelimit_common = require "plugins/ratelimit" -- A plugin that implements ratelimits using redis local E = {} @@ -76,138 +75,6 @@ local function load_scripts(_, _) bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params) end -local limit_parser -local function parse_string_limit(lim, no_error) - local function parse_time_suffix(s) - if s == 's' then - return 1 - elseif s == 'm' then - return 60 - elseif s == 'h' then - return 3600 - elseif s == 'd' then - return 86400 - end - end - local function parse_num_suffix(s) - if s == '' then - return 1 - elseif s == 'k' then - return 1000 - elseif s == 'm' then - return 1000000 - elseif s == 'g' then - return 1000000000 - end - end - local lpeg = require "lpeg" - - if not limit_parser then - local digit = lpeg.R("09") - limit_parser = {} - limit_parser.integer = (lpeg.S("+-") ^ -1) * - (digit ^ 1) - limit_parser.fractional = (lpeg.P(".")) * - (digit ^ 1) - limit_parser.number = (limit_parser.integer * - (limit_parser.fractional ^ -1)) + - (lpeg.S("+-") * limit_parser.fractional) - limit_parser.time = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("smhd") / parse_time_suffix) ^ -1), - function(acc, val) - return acc * val - end) - limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("kmg") / parse_num_suffix) ^ -1), - function(acc, val) - return acc * val - end) - limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number * - (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * - limit_parser.time) - end - local t = lpeg.match(limit_parser.limit, lim) - - if t and t[1] and t[2] and t[2] ~= 0 then - return t[2], t[1] - end - - if not no_error then - rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim) - end - - return nil -end - -local function str_to_rate(str) - local divider, divisor = parse_string_limit(str, false) - - if not divisor then - rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str) - - return nil - end - - return divisor / divider -end - -local bucket_schema = ts.shape { - burst = ts.number + ts.string / lua_util.dehumanize_number, - rate = ts.number + ts.string / str_to_rate, - skip_recipients = ts.boolean:is_optional(), - symbol = ts.string:is_optional(), - message = ts.string:is_optional(), - skip_soft_reject = ts.boolean:is_optional(), -} - -local function parse_limit(name, data) - if type(data) == 'table' then - -- 2 cases here: - -- * old limit in format [burst, rate] - -- * vector of strings in Andrew's string format (removed from 1.8.2) - -- * proper bucket table - if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then - -- Old style ratelimit - rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name) - if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then - return { - burst = data[1], - rate = data[2] - } - elseif data[1] ~= 0 then - rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name) - else - rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name) - end - - return nil - else - local parsed_bucket, err = bucket_schema:transform(data) - - if not parsed_bucket or err then - rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s', - name, err, data) - else - return parsed_bucket - end - end - elseif type(data) == 'string' then - local rep_rate, burst = parse_string_limit(data) - rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s', - name, data) - if rep_rate and burst then - return { - burst = burst, - rate = burst / rep_rate -- reciprocal - } - end - end - - return nil -end - --- Check whether this addr is bounce local function check_bounce(from) return fun.any(function(b) @@ -490,7 +357,7 @@ local function ratelimit_cb(task) local ret, redis_key, bd = pcall(hdl, task) if ret then - local bucket = parse_limit(k, bd) + local bucket = ratelimit_common.parse_limit(k, bd) if bucket then prefixes[redis_key] = make_prefix(redis_key, k, bucket) end @@ -718,7 +585,7 @@ if opts then if lim.bucket[1] then for _, bucket in ipairs(lim.bucket) do - local b = parse_limit(t, bucket) + local b = ratelimit_common.parse_limit(t, bucket) if not b then rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', @@ -729,7 +596,7 @@ if opts then table.insert(buckets, b) end else - local bucket = parse_limit(t, lim.bucket) + local bucket = ratelimit_common.parse_limit(t, lim.bucket) if not bucket then rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', @@ -757,7 +624,7 @@ if opts then end else rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim) - buckets = parse_limit(t, lim) + buckets = ratelimit_common.parse_limit(t, lim) if buckets then settings.limits[t] = { buckets = { buckets } |