diff options
Diffstat (limited to 'src/plugins/lua/ratelimit.lua')
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 178 |
1 files changed, 104 insertions, 74 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index b225d0650..9f8292d6b 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -39,7 +39,7 @@ local redis_params -- Senders that are considered as bounce local settings = { bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' }, --- Do not check ratelimits for these recipients + -- Do not check ratelimits for these recipients whitelisted_rcpts = { 'postmaster', 'mailer-daemon' }, prefix = 'RL', ham_factor_rate = 1.01, @@ -54,11 +54,9 @@ local settings = { prefilter = true, } - local bucket_check_script = "ratelimit_check.lua" local bucket_check_id - local bucket_update_script = "ratelimit_update.lua" local bucket_update_id @@ -70,7 +68,6 @@ local message_func = function(_, limit_type, _, _, _) return string.format('Ratelimit "%s" exceeded', limit_type) end - local function load_scripts(_, _) bucket_check_id = lua_redis.load_redis_script_from_file(bucket_check_script, redis_params) bucket_update_id = lua_redis.load_redis_script_from_file(bucket_update_script, redis_params) @@ -106,27 +103,28 @@ local function parse_string_limit(lim, no_error) if not limit_parser then local digit = lpeg.R("09") limit_parser = {} - limit_parser.integer = - (lpeg.S("+-") ^ -1) * - (digit ^ 1) - limit_parser.fractional = - (lpeg.P(".") ) * - (digit ^ 1) - limit_parser.number = - (limit_parser.integer * - (limit_parser.fractional ^ -1)) + - (lpeg.S("+-") * limit_parser.fractional) + limit_parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + limit_parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + limit_parser.number = (limit_parser.integer * + (limit_parser.fractional ^ -1)) + + (lpeg.S("+-") * limit_parser.fractional) limit_parser.time = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("smhd") / parse_time_suffix) ^ -1), - function (acc, val) return acc * val end) + (limit_parser.number / tonumber) * + ((lpeg.S("smhd") / parse_time_suffix) ^ -1), + function(acc, val) + return acc * val + end) limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("kmg") / parse_num_suffix) ^ -1), - function (acc, val) return acc * val end) + (limit_parser.number / tonumber) * + ((lpeg.S("kmg") / parse_num_suffix) ^ -1), + function(acc, val) + return acc * val + end) limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number * - (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * - limit_parser.time) + (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * + limit_parser.time) end local t = lpeg.match(limit_parser.limit, lim) @@ -142,7 +140,7 @@ local function parse_string_limit(lim, no_error) end local function str_to_rate(str) - local divider,divisor = parse_string_limit(str, false) + local divider, divisor = parse_string_limit(str, false) if not divisor then rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str) @@ -153,7 +151,7 @@ local function str_to_rate(str) return divisor / divider end -local bucket_schema = ts.shape{ +local bucket_schema = ts.shape { burst = ts.number + ts.string / lua_util.dehumanize_number, rate = ts.number + ts.string / str_to_rate, skip_recipients = ts.boolean:is_optional(), @@ -184,7 +182,7 @@ local function parse_limit(name, data) return nil else - local parsed_bucket,err = bucket_schema:transform(data) + local parsed_bucket, err = bucket_schema:transform(data) if not parsed_bucket or err then rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s', @@ -210,21 +208,27 @@ end --- Check whether this addr is bounce local function check_bounce(from) - return fun.any(function(b) return b == from end, settings.bounce_senders) + return fun.any(function(b) + return b == from + end, settings.bounce_senders) end local keywords = { ['ip'] = { ['get_value'] = function(task) local ip = task:get_ip() - if ip and ip:is_valid() then return tostring(ip) end + if ip and ip:is_valid() then + return tostring(ip) + end return nil end, }, ['rip'] = { ['get_value'] = function(task) local ip = task:get_ip() - if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end + if ip and ip:is_valid() and not ip:is_local() then + return tostring(ip) + end return nil end, }, @@ -243,7 +247,11 @@ local keywords = { if not ((from or E)[1] or E).user then return '_' end - if check_bounce(from[1]['user']) then return '_' else return nil end + if check_bounce(from[1]['user']) then + return '_' + else + return nil + end end, }, ['asn'] = { @@ -281,7 +289,7 @@ local keywords = { local parts = task:get_parts() or E local digests = {} - for _,p in ipairs(parts) do + for _, p in ipairs(parts) do if p:get_filename() then table.insert(digests, p:get_digest()) end @@ -299,7 +307,7 @@ local keywords = { local parts = task:get_parts() or E local files = {} - for _,p in ipairs(parts) do + for _, p in ipairs(parts) do local fname = p:get_filename() if fname then table.insert(files, fname) @@ -316,7 +324,7 @@ local keywords = { } local function gen_rate_key(task, rtype, bucket) - local key_t = {tostring(lua_util.round(100000.0 / bucket.burst))} + local key_t = { tostring(lua_util.round(100000.0 / bucket.burst)) } local key_keywords = lua_util.str_split(rtype, '_') local have_user = false @@ -326,9 +334,15 @@ local function gen_rate_key(task, rtype, bucket) if keywords[v] and type(keywords[v]['get_value']) == 'function' then ret = keywords[v]['get_value'](task) end - if not ret then return nil end - if v == 'user' then have_user = true end - if type(ret) ~= 'string' then ret = tostring(ret) end + if not ret then + return nil + end + if v == 'user' then + have_user = true + end + if type(ret) ~= 'string' then + ret = tostring(ret) + end table.insert(key_t, ret) end @@ -341,7 +355,9 @@ end local function make_prefix(redis_key, name, bucket) local hash_len = 24 - if hash_len > #redis_key then hash_len = #redis_key end + if hash_len > #redis_key then + hash_len = #redis_key + end local hash = settings.prefix .. string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len) -- Fill defaults @@ -367,7 +383,7 @@ end local function limit_to_prefixes(task, k, v, prefixes) local n = 0 - for _,bucket in ipairs(v.buckets) do + for _, bucket in ipairs(v.buckets) do if v.selector then local selectors = lua_selectors.process_selectors(task, v.selector) if selectors then @@ -403,7 +419,9 @@ end local function ratelimit_cb(task) if not settings.allow_local and - rspamd_lua_utils.is_rspamc_or_controller(task) then return end + rspamd_lua_utils.is_rspamc_or_controller(task) then + return + end -- Get initial task data local ip = task:get_from_ip() @@ -419,10 +437,14 @@ local function ratelimit_cb(task) local rcpts_user = {} if rcpts then fun.each(function(r) - fun.each(function(type) table.insert(rcpts_user, r[type]) end, {'user', 'addr'}) + fun.each(function(type) + table.insert(rcpts_user, r[type]) + end, { 'user', 'addr' }) end, rcpts) - if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then + if fun.any(function(r) + return settings.whitelisted_rcpts:get_key(r) + end, rcpts_user) then rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient') return end @@ -439,7 +461,7 @@ local function ratelimit_cb(task) local prefixes = {} local nprefixes = 0 - for k,v in pairs(settings.limits) do + for k, v in pairs(settings.limits) do nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes) end @@ -504,9 +526,11 @@ local function ratelimit_cb(task) end -- Don't do anything if pre-result has been already set - if task:has_pre_result() then return end + if task:has_pre_result() then + return + end - local _,nrcpt = task:has_recipients('smtp') + local _, nrcpt = task:has_recipients('smtp') if not nrcpt or nrcpt <= 0 then nrcpt = 1 end @@ -518,19 +542,21 @@ local function ratelimit_cb(task) now = lua_util.round(now * 1000.0) -- Get milliseconds -- Now call check script for all defined prefixes - for pr,value in pairs(prefixes) do + for pr, value in pairs(prefixes) do local bucket = value.bucket local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms local bincr = nrcpt - if bucket.skip_recipients then bincr = 1 end + if bucket.skip_recipients then + bincr = 1 + end lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)", value.name, pr, value.hash, bucket.burst, bucket.rate) lua_redis.exec_redis_script(bucket_check_id, - {key = value.hash, task = task, is_write = true}, + { key = value.hash, task = task, is_write = true }, gen_check_cb(pr, bucket, value.name, value.hash), - {value.hash, tostring(now), tostring(rate), tostring(bucket.burst), - tostring(settings.expire), tostring(bincr)}) + { value.hash, tostring(now), tostring(rate), tostring(bucket.burst), + tostring(settings.expire), tostring(bincr) }) end end end @@ -553,18 +579,20 @@ local function maybe_cleanup_pending(task) lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data) end end - local _,nrcpt = task:has_recipients('smtp') + local _, nrcpt = task:has_recipients('smtp') if not nrcpt or nrcpt <= 0 then nrcpt = 1 end local bincr = nrcpt - if bucket.skip_recipients then bincr = 1 end + if bucket.skip_recipients then + bincr = 1 + end local now = task:get_timeval(true) now = lua_util.round(now * 1000.0) -- Get milliseconds lua_redis.exec_redis_script(bucket_cleanup_id, - {key = v.hash, task = task, is_write = true}, + { key = v.hash, task = task, is_write = true }, cleanup_cb, - {v.hash, tostring(now), tostring(settings.expire), tostring(bincr)}) + { v.hash, tostring(now), tostring(settings.expire), tostring(bincr) }) end end end @@ -590,7 +618,7 @@ local function ratelimit_update_cb(task) end local verdict = lua_verdict.get_specific_verdict(N, task) - local _,nrcpt = task:has_recipients('smtp') + local _, nrcpt = task:has_recipients('smtp') if not nrcpt or nrcpt <= 0 then nrcpt = 1 end @@ -624,14 +652,16 @@ local function ratelimit_update_cb(task) end local bincr = nrcpt - if bucket.skip_recipients then bincr = 1 end + if bucket.skip_recipients then + bincr = 1 + end lua_redis.exec_redis_script(bucket_update_id, - {key = v.hash, task = task, is_write = true}, + { key = v.hash, task = task, is_write = true }, update_bucket_cb, - {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst), - tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult), - tostring(settings.expire), tostring(bincr)}) + { v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst), + tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult), + tostring(settings.expire), tostring(bincr) }) end end end @@ -653,7 +683,7 @@ if opts then if type(lim) == 'table' and lim.bucket then if lim.bucket[1] then - for _,bucket in ipairs(lim.bucket) do + for _, bucket in ipairs(lim.bucket) do local b = parse_limit(t, bucket) if not b then @@ -673,7 +703,7 @@ if opts then return end - buckets = {bucket} + buckets = { bucket } end settings.limits[t] = { @@ -696,7 +726,7 @@ if opts then buckets = parse_limit(t, lim) if buckets then settings.limits[t] = { - buckets = {buckets} + buckets = { buckets } } end end @@ -706,7 +736,7 @@ if opts then -- Display what's enabled fun.each(function(s) rspamd_logger.infox(rspamd_config, 'enabled ratelimit: %s', s) - end, fun.map(function(n,d) + end, fun.map(function(n, d) return string.format('%s [%s]', n, table.concat(fun.totable(fun.map(function(v) return string.format('symbol: %s, %s msgs burst, %s msgs/sec rate', @@ -723,14 +753,14 @@ if opts then if type(wrcpts) == 'string' then if string.find(wrcpts, ',') then settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( - lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts') + lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts') else settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', - 'Ratelimit whitelisted rcpts') + 'Ratelimit whitelisted rcpts') end elseif type(opts['whitelisted_rcpts']) == 'table' then settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', - 'Ratelimit whitelisted rcpts') + 'Ratelimit whitelisted rcpts') else -- Stupid default... settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( @@ -739,12 +769,12 @@ if opts then if opts['whitelisted_ip'] then settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix', - 'Ratelimit whitelist ip map') + 'Ratelimit whitelist ip map') end if opts['whitelisted_user'] then settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set', - 'Ratelimit whitelist user map') + 'Ratelimit whitelist user map') end settings.custom_keywords = {} @@ -754,7 +784,7 @@ if opts then if ret then opts['custom_keywords'] = {} if type(res_or_err) == 'table' then - for k,hdl in pairs(res_or_err) do + for k, hdl in pairs(res_or_err) do settings['custom_keywords'][k] = hdl end elseif type(res_or_err) == 'function' then @@ -783,7 +813,7 @@ if opts then priority = lua_util.symbols_priorities.medium, callback = ratelimit_cb, flags = 'empty,nostat', - augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)}, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, } local id = rspamd_config:register_symbol(s) @@ -792,9 +822,9 @@ if opts then -- Display what's enabled fun.each(function(set, lim) if type(lim.buckets) == 'table' then - for _,b in ipairs(lim.buckets) do + for _, b in ipairs(lim.buckets) do if b.symbol then - rspamd_config:register_symbol{ + rspamd_config:register_symbol { type = 'virtual', name = b.symbol, score = 0.0, @@ -806,7 +836,7 @@ if opts then end, settings.limits) if settings.info_symbol then - rspamd_config:register_symbol{ + rspamd_config:register_symbol { type = 'virtual', name = settings.info_symbol, score = 0.0, @@ -814,7 +844,7 @@ if opts then } end if settings.symbol then - rspamd_config:register_symbol{ + rspamd_config:register_symbol { type = 'virtual', name = settings.symbol, score = 0.0, -- Might be overridden if needed @@ -827,7 +857,7 @@ if opts then name = 'RATELIMIT_UPDATE', flags = 'explicit_disable,ignore_passthrough', callback = ratelimit_update_cb, - augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)}, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, } end end |