summaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/ratelimit.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/ratelimit.lua')
-rw-r--r--src/plugins/lua/ratelimit.lua178
1 files changed, 104 insertions, 74 deletions
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index b225d0650..9f8292d6b 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -39,7 +39,7 @@ local redis_params
-- Senders that are considered as bounce
local settings = {
bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' },
--- Do not check ratelimits for these recipients
+ -- Do not check ratelimits for these recipients
whitelisted_rcpts = { 'postmaster', 'mailer-daemon' },
prefix = 'RL',
ham_factor_rate = 1.01,
@@ -54,11 +54,9 @@ local settings = {
prefilter = true,
}
-
local bucket_check_script = "ratelimit_check.lua"
local bucket_check_id
-
local bucket_update_script = "ratelimit_update.lua"
local bucket_update_id
@@ -70,7 +68,6 @@ local message_func = function(_, limit_type, _, _, _)
return string.format('Ratelimit "%s" exceeded', limit_type)
end
-
local function load_scripts(_, _)
bucket_check_id = lua_redis.load_redis_script_from_file(bucket_check_script, redis_params)
bucket_update_id = lua_redis.load_redis_script_from_file(bucket_update_script, redis_params)
@@ -106,27 +103,28 @@ local function parse_string_limit(lim, no_error)
if not limit_parser then
local digit = lpeg.R("09")
limit_parser = {}
- limit_parser.integer =
- (lpeg.S("+-") ^ -1) *
- (digit ^ 1)
- limit_parser.fractional =
- (lpeg.P(".") ) *
- (digit ^ 1)
- limit_parser.number =
- (limit_parser.integer *
- (limit_parser.fractional ^ -1)) +
- (lpeg.S("+-") * limit_parser.fractional)
+ limit_parser.integer = (lpeg.S("+-") ^ -1) *
+ (digit ^ 1)
+ limit_parser.fractional = (lpeg.P(".")) *
+ (digit ^ 1)
+ limit_parser.number = (limit_parser.integer *
+ (limit_parser.fractional ^ -1)) +
+ (lpeg.S("+-") * limit_parser.fractional)
limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
- (limit_parser.number / tonumber) *
- ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
- function (acc, val) return acc * val end)
+ (limit_parser.number / tonumber) *
+ ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
- (limit_parser.number / tonumber) *
- ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
- function (acc, val) return acc * val end)
+ (limit_parser.number / tonumber) *
+ ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
- (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
- limit_parser.time)
+ (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
+ limit_parser.time)
end
local t = lpeg.match(limit_parser.limit, lim)
@@ -142,7 +140,7 @@ local function parse_string_limit(lim, no_error)
end
local function str_to_rate(str)
- local divider,divisor = parse_string_limit(str, false)
+ local divider, divisor = parse_string_limit(str, false)
if not divisor then
rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
@@ -153,7 +151,7 @@ local function str_to_rate(str)
return divisor / divider
end
-local bucket_schema = ts.shape{
+local bucket_schema = ts.shape {
burst = ts.number + ts.string / lua_util.dehumanize_number,
rate = ts.number + ts.string / str_to_rate,
skip_recipients = ts.boolean:is_optional(),
@@ -184,7 +182,7 @@ local function parse_limit(name, data)
return nil
else
- local parsed_bucket,err = bucket_schema:transform(data)
+ local parsed_bucket, err = bucket_schema:transform(data)
if not parsed_bucket or err then
rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
@@ -210,21 +208,27 @@ end
--- Check whether this addr is bounce
local function check_bounce(from)
- return fun.any(function(b) return b == from end, settings.bounce_senders)
+ return fun.any(function(b)
+ return b == from
+ end, settings.bounce_senders)
end
local keywords = {
['ip'] = {
['get_value'] = function(task)
local ip = task:get_ip()
- if ip and ip:is_valid() then return tostring(ip) end
+ if ip and ip:is_valid() then
+ return tostring(ip)
+ end
return nil
end,
},
['rip'] = {
['get_value'] = function(task)
local ip = task:get_ip()
- if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end
+ if ip and ip:is_valid() and not ip:is_local() then
+ return tostring(ip)
+ end
return nil
end,
},
@@ -243,7 +247,11 @@ local keywords = {
if not ((from or E)[1] or E).user then
return '_'
end
- if check_bounce(from[1]['user']) then return '_' else return nil end
+ if check_bounce(from[1]['user']) then
+ return '_'
+ else
+ return nil
+ end
end,
},
['asn'] = {
@@ -281,7 +289,7 @@ local keywords = {
local parts = task:get_parts() or E
local digests = {}
- for _,p in ipairs(parts) do
+ for _, p in ipairs(parts) do
if p:get_filename() then
table.insert(digests, p:get_digest())
end
@@ -299,7 +307,7 @@ local keywords = {
local parts = task:get_parts() or E
local files = {}
- for _,p in ipairs(parts) do
+ for _, p in ipairs(parts) do
local fname = p:get_filename()
if fname then
table.insert(files, fname)
@@ -316,7 +324,7 @@ local keywords = {
}
local function gen_rate_key(task, rtype, bucket)
- local key_t = {tostring(lua_util.round(100000.0 / bucket.burst))}
+ local key_t = { tostring(lua_util.round(100000.0 / bucket.burst)) }
local key_keywords = lua_util.str_split(rtype, '_')
local have_user = false
@@ -326,9 +334,15 @@ local function gen_rate_key(task, rtype, bucket)
if keywords[v] and type(keywords[v]['get_value']) == 'function' then
ret = keywords[v]['get_value'](task)
end
- if not ret then return nil end
- if v == 'user' then have_user = true end
- if type(ret) ~= 'string' then ret = tostring(ret) end
+ if not ret then
+ return nil
+ end
+ if v == 'user' then
+ have_user = true
+ end
+ if type(ret) ~= 'string' then
+ ret = tostring(ret)
+ end
table.insert(key_t, ret)
end
@@ -341,7 +355,9 @@ end
local function make_prefix(redis_key, name, bucket)
local hash_len = 24
- if hash_len > #redis_key then hash_len = #redis_key end
+ if hash_len > #redis_key then
+ hash_len = #redis_key
+ end
local hash = settings.prefix ..
string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len)
-- Fill defaults
@@ -367,7 +383,7 @@ end
local function limit_to_prefixes(task, k, v, prefixes)
local n = 0
- for _,bucket in ipairs(v.buckets) do
+ for _, bucket in ipairs(v.buckets) do
if v.selector then
local selectors = lua_selectors.process_selectors(task, v.selector)
if selectors then
@@ -403,7 +419,9 @@ end
local function ratelimit_cb(task)
if not settings.allow_local and
- rspamd_lua_utils.is_rspamc_or_controller(task) then return end
+ rspamd_lua_utils.is_rspamc_or_controller(task) then
+ return
+ end
-- Get initial task data
local ip = task:get_from_ip()
@@ -419,10 +437,14 @@ local function ratelimit_cb(task)
local rcpts_user = {}
if rcpts then
fun.each(function(r)
- fun.each(function(type) table.insert(rcpts_user, r[type]) end, {'user', 'addr'})
+ fun.each(function(type)
+ table.insert(rcpts_user, r[type])
+ end, { 'user', 'addr' })
end, rcpts)
- if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
+ if fun.any(function(r)
+ return settings.whitelisted_rcpts:get_key(r)
+ end, rcpts_user) then
rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
return
end
@@ -439,7 +461,7 @@ local function ratelimit_cb(task)
local prefixes = {}
local nprefixes = 0
- for k,v in pairs(settings.limits) do
+ for k, v in pairs(settings.limits) do
nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
end
@@ -504,9 +526,11 @@ local function ratelimit_cb(task)
end
-- Don't do anything if pre-result has been already set
- if task:has_pre_result() then return end
+ if task:has_pre_result() then
+ return
+ end
- local _,nrcpt = task:has_recipients('smtp')
+ local _, nrcpt = task:has_recipients('smtp')
if not nrcpt or nrcpt <= 0 then
nrcpt = 1
end
@@ -518,19 +542,21 @@ local function ratelimit_cb(task)
now = lua_util.round(now * 1000.0) -- Get milliseconds
-- Now call check script for all defined prefixes
- for pr,value in pairs(prefixes) do
+ for pr, value in pairs(prefixes) do
local bucket = value.bucket
local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
local bincr = nrcpt
- if bucket.skip_recipients then bincr = 1 end
+ if bucket.skip_recipients then
+ bincr = 1
+ end
lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
value.name, pr, value.hash, bucket.burst, bucket.rate)
lua_redis.exec_redis_script(bucket_check_id,
- {key = value.hash, task = task, is_write = true},
+ { key = value.hash, task = task, is_write = true },
gen_check_cb(pr, bucket, value.name, value.hash),
- {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
- tostring(settings.expire), tostring(bincr)})
+ { value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
+ tostring(settings.expire), tostring(bincr) })
end
end
end
@@ -553,18 +579,20 @@ local function maybe_cleanup_pending(task)
lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data)
end
end
- local _,nrcpt = task:has_recipients('smtp')
+ local _, nrcpt = task:has_recipients('smtp')
if not nrcpt or nrcpt <= 0 then
nrcpt = 1
end
local bincr = nrcpt
- if bucket.skip_recipients then bincr = 1 end
+ if bucket.skip_recipients then
+ bincr = 1
+ end
local now = task:get_timeval(true)
now = lua_util.round(now * 1000.0) -- Get milliseconds
lua_redis.exec_redis_script(bucket_cleanup_id,
- {key = v.hash, task = task, is_write = true},
+ { key = v.hash, task = task, is_write = true },
cleanup_cb,
- {v.hash, tostring(now), tostring(settings.expire), tostring(bincr)})
+ { v.hash, tostring(now), tostring(settings.expire), tostring(bincr) })
end
end
end
@@ -590,7 +618,7 @@ local function ratelimit_update_cb(task)
end
local verdict = lua_verdict.get_specific_verdict(N, task)
- local _,nrcpt = task:has_recipients('smtp')
+ local _, nrcpt = task:has_recipients('smtp')
if not nrcpt or nrcpt <= 0 then
nrcpt = 1
end
@@ -624,14 +652,16 @@ local function ratelimit_update_cb(task)
end
local bincr = nrcpt
- if bucket.skip_recipients then bincr = 1 end
+ if bucket.skip_recipients then
+ bincr = 1
+ end
lua_redis.exec_redis_script(bucket_update_id,
- {key = v.hash, task = task, is_write = true},
+ { key = v.hash, task = task, is_write = true },
update_bucket_cb,
- {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
- tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
- tostring(settings.expire), tostring(bincr)})
+ { v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
+ tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
+ tostring(settings.expire), tostring(bincr) })
end
end
end
@@ -653,7 +683,7 @@ if opts then
if type(lim) == 'table' and lim.bucket then
if lim.bucket[1] then
- for _,bucket in ipairs(lim.bucket) do
+ for _, bucket in ipairs(lim.bucket) do
local b = parse_limit(t, bucket)
if not b then
@@ -673,7 +703,7 @@ if opts then
return
end
- buckets = {bucket}
+ buckets = { bucket }
end
settings.limits[t] = {
@@ -696,7 +726,7 @@ if opts then
buckets = parse_limit(t, lim)
if buckets then
settings.limits[t] = {
- buckets = {buckets}
+ buckets = { buckets }
}
end
end
@@ -706,7 +736,7 @@ if opts then
-- Display what's enabled
fun.each(function(s)
rspamd_logger.infox(rspamd_config, 'enabled ratelimit: %s', s)
- end, fun.map(function(n,d)
+ end, fun.map(function(n, d)
return string.format('%s [%s]', n,
table.concat(fun.totable(fun.map(function(v)
return string.format('symbol: %s, %s msgs burst, %s msgs/sec rate',
@@ -723,14 +753,14 @@ if opts then
if type(wrcpts) == 'string' then
if string.find(wrcpts, ',') then
settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
- lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
+ lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
else
settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
- 'Ratelimit whitelisted rcpts')
+ 'Ratelimit whitelisted rcpts')
end
elseif type(opts['whitelisted_rcpts']) == 'table' then
settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
- 'Ratelimit whitelisted rcpts')
+ 'Ratelimit whitelisted rcpts')
else
-- Stupid default...
settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
@@ -739,12 +769,12 @@ if opts then
if opts['whitelisted_ip'] then
settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
- 'Ratelimit whitelist ip map')
+ 'Ratelimit whitelist ip map')
end
if opts['whitelisted_user'] then
settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
- 'Ratelimit whitelist user map')
+ 'Ratelimit whitelist user map')
end
settings.custom_keywords = {}
@@ -754,7 +784,7 @@ if opts then
if ret then
opts['custom_keywords'] = {}
if type(res_or_err) == 'table' then
- for k,hdl in pairs(res_or_err) do
+ for k, hdl in pairs(res_or_err) do
settings['custom_keywords'][k] = hdl
end
elseif type(res_or_err) == 'function' then
@@ -783,7 +813,7 @@ if opts then
priority = lua_util.symbols_priorities.medium,
callback = ratelimit_cb,
flags = 'empty,nostat',
- augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)},
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) },
}
local id = rspamd_config:register_symbol(s)
@@ -792,9 +822,9 @@ if opts then
-- Display what's enabled
fun.each(function(set, lim)
if type(lim.buckets) == 'table' then
- for _,b in ipairs(lim.buckets) do
+ for _, b in ipairs(lim.buckets) do
if b.symbol then
- rspamd_config:register_symbol{
+ rspamd_config:register_symbol {
type = 'virtual',
name = b.symbol,
score = 0.0,
@@ -806,7 +836,7 @@ if opts then
end, settings.limits)
if settings.info_symbol then
- rspamd_config:register_symbol{
+ rspamd_config:register_symbol {
type = 'virtual',
name = settings.info_symbol,
score = 0.0,
@@ -814,7 +844,7 @@ if opts then
}
end
if settings.symbol then
- rspamd_config:register_symbol{
+ rspamd_config:register_symbol {
type = 'virtual',
name = settings.symbol,
score = 0.0, -- Might be overridden if needed
@@ -827,7 +857,7 @@ if opts then
name = 'RATELIMIT_UPDATE',
flags = 'explicit_disable,ignore_passthrough',
callback = ratelimit_update_cb,
- augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)},
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) },
}
end
end