--[[
Copyright (c) 2011-2017, Vsevolod Stakhov
Copyright (c) 2016-2017, Andrew Lewis

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- A plugin that implements ratelimits using redis

local E = {}
local N = 'ratelimit'
local redis_params

-- Default settings; values from the module configuration override these.
local settings = {
  -- Senders that are considered as bounce
  bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' },
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = { 'postmaster', 'mailer-daemon' },
  prefix = 'RL',
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  max_rate_mult = 5,
  max_bucket_mult = 10,
  expire = 60 * 60 * 24 * 2, -- 2 days by default
  limits = {},
  allow_local = false,
}

-- Checks bucket, updating it if needed
-- KEYS[1] - redis hash key of the bucket (settings.prefix followed by a
--           truncated base32 hash of the generated rate key)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- Returns a table {limited, burst, dyn_rate, dyn_burst}, where `limited` is
-- 1 if the message should be ratelimited and 0 if not.
-- Fractional values are returned via tostring(): Redis converts plain Lua
-- numbers in a script reply to integers, which would drop the fraction.
-- Redis hash fields used:
-- l - last hit
-- b - current burst
-- dr - current dynamic rate multiplier (*10000)
-- db - current dynamic burst multiplier (*10000)
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  local dynr, dynb = 0, 0
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, 0, 1, 1}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  -- Perform leak
  if burst > 0 then
    if last < tonumber(KEYS[2]) then
      local rate = tonumber(KEYS[3])
      dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
      rate = rate * dynr
      local leaked = ((now - last) * rate)
      burst = burst - leaked
      redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
    end
    dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

    if burst * dynb > tonumber(KEYS[4]) then
      return {1, tostring(burst), tostring(dynr), tostring(dynb)}
    end
  else
    burst = 0
    redis.call('HSET', KEYS[1], 'b', '0')
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb)}
]]
local bucket_check_id

-- Updates a bucket
-- KEYS[1] - redis hash key of the bucket (same key as in the check script)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier
-- KEYS[4] - dynamic burst multiplier
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- Returns a table {burst_before_increment, dyn_rate, dyn_burst}
-- ({1, 1, 1} for a freshly created bucket).
-- Redis hash fields used:
-- l - last hit
-- b - current burst
-- dr - current dynamic rate multiplier
-- db - current dynamic burst multiplier
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
  local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000

  if dr < tonumber(KEYS[5]) and dr > 1.0 / tonumber(KEYS[5]) then
    dr = dr * tonumber(KEYS[3])
    redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
  end

  if db < tonumber(KEYS[6]) and db > 1.0 / tonumber(KEYS[6]) then
    db = db * tonumber(KEYS[4])
    redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
  end

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
local bucket_update_id

-- Builds the soft reject message. May be replaced from the configuration
-- (`message_func` option).
-- message_func(task, limit_type, prefix, bucket)
local message_func = function(_, limit_type, _, _)
  return string.format('Ratelimit "%s" exceeded', limit_type)
end

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_lua_utils = require "lua_util"
local lua_redis = require "lua_redis"
local fun = require "fun"
local lua_maps = require "lua_maps"
local lua_util = require "lua_util"
local rspamd_hash = require "rspamd_cryptobox_hash"

--- Registers both bucket scripts with lua_redis and remembers their ids.
-- Does nothing when no redis servers are configured: in that case the module
-- is disabled and no callback that would execute the scripts is registered.
local function load_scripts(cfg, ev_base)
  if not redis_params then
    return
  end

  bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
  bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
end

-- Lazily built table of lpeg patterns used by parse_string_limit
local limit_parser

--- Parses a textual limit of the form `<number>[k|m|g] / <number>[s|m|h|d]`.
-- @param lim string to parse
-- @param no_error if true, do not log an error on parse failure
-- @return the time part in seconds followed by the (suffix-scaled) number
--         part, or nil if the string does not match or the time part is zero
local function parse_string_limit(lim, no_error)
  local function parse_time_suffix(s)
    if s == 's' then
      return 1
    elseif s == 'm' then
      return 60
    elseif s == 'h' then
      return 3600
    elseif s == 'd' then
      return 86400
    end
  end

  local function parse_num_suffix(s)
    if s == '' then
      return 1
    elseif s == 'k' then
      return 1000
    elseif s == 'm' then
      return 1000000
    elseif s == 'g' then
      return 1000000000
    end
  end

  local lpeg = require "lpeg"

  if not limit_parser then
    local digit = lpeg.R("09")
    limit_parser = {}
    limit_parser.integer =
      (lpeg.S("+-") ^ -1) *
      (digit ^ 1)
    limit_parser.fractional =
      (lpeg.P(".")) *
      (digit ^ 1)
    limit_parser.number =
      (limit_parser.integer *
        (limit_parser.fractional ^ -1)) +
      (lpeg.S("+-") * limit_parser.fractional)
    -- Folds the number and the optional suffix multiplier into one value
    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
      (limit_parser.number / tonumber) *
      ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
      function(acc, val) return acc * val end)
    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
      (limit_parser.number / tonumber) *
      ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
      function(acc, val) return acc * val end)
    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
      (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
      limit_parser.time)
  end

  local t = lpeg.match(limit_parser.limit, lim)

  if t and t[1] and t[2] and t[2] ~= 0 then
    return t[2], t[1]
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end

--- Check whether this addr is bounce
local function check_bounce(from)
  return fun.any(function(b) return b == from end, settings.bounce_senders)
end

-- Built-in key components. A limit name is split on '_' and every part is
-- looked up here; `get_value` returns the component value for a task or nil
-- when the limit is not applicable to this task.
local keywords = {
  ['ip'] = {
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then return tostring(ip) end
      return nil
    end,
  },
  ['rip'] = {
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end
      return nil
    end,
  },
  ['from'] = {
    ['get_value'] = function(task)
      local from = task:get_from(0)
      if ((from or E)[1] or E).addr then
        return string.lower(from[1]['addr'])
      end
      return nil
    end,
  },
  ['bounce'] = {
    ['get_value'] = function(task)
      local from = task:get_from(0)
      -- A missing sender user part is treated as a bounce as well
      if not ((from or E)[1] or E).user then
        return '_'
      end
      if check_bounce(from[1]['user']) then return '_' else return nil end
    end,
  },
  ['asn'] = {
    ['get_value'] = function(task)
      local asn = task:get_mempool():get_variable('asn')
      if not asn then return nil else return asn end
    end,
  },
  ['user'] = {
    ['get_value'] = function(task)
      local auser = task:get_user()
      if not auser then return nil else return auser end
    end,
  },
  ['to'] = {
    ['get_value'] = function(task)
      return task:get_principal_recipient()
    end,
  },
}

--- Builds the textual key for one bucket of a limit.
-- The key is `round(100000 / bucket[1])` followed by the value of every
-- keyword of `rtype`, joined with ':'.
-- @param task task object
-- @param rtype limit name, keywords separated by '_' (e.g. `to_ip`)
-- @param bucket bucket table; only bucket[1] is read here
-- @return key string, or nil if any keyword is unknown or has no value
local function gen_rate_key(task, rtype, bucket)
  local key_t = {tostring(lua_util.round(100000.0 / bucket[1]))}
  local key_keywords = lua_util.str_split(rtype, '_')
  local have_user = false

  for _, v in ipairs(key_keywords) do
    local ret

    if keywords[v] and type(keywords[v]['get_value']) == 'function' then
      ret = keywords[v]['get_value'](task)
    end
    if not ret then return nil end
    if v == 'user' then have_user = true end
    if type(ret) ~= 'string' then ret = tostring(ret) end
    table.insert(key_t, ret)
  end

  if have_user and not task:get_user() then
    return nil
  end

  return table.concat(key_t, ":")
end

--- Prefilter callback: checks every applicable bucket in redis and either
-- inserts `settings.symbol` or sets a soft reject when a bucket is exceeded.
local function ratelimit_cb(task)
  if not settings.allow_local and
      rspamd_lua_utils.is_rspamc_or_controller(task) then
    return
  end

  -- Get initial task data
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return
    end
  end

  -- Parse all rcpts
  local rcpts = task:get_recipients()
  local rcpts_user = {}
  if rcpts then
    -- Both the user part and the full address are matched against the map
    fun.each(function(r)
      fun.each(function(field) table.insert(rcpts_user, r[field]) end, {'user', 'addr'})
    end, rcpts)

    if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
      return
    end
  end

  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    -- Unauthenticated messages have no user to look up
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return
    end
  end

  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k, v in pairs(settings.limits) do
    for _, bucket in ipairs(v) do
      local prefix = gen_rate_key(task, k, bucket)

      if prefix then
        local hash_len = 24
        if hash_len > #prefix then hash_len = #prefix end
        local hash = settings.prefix ..
            string.sub(rspamd_hash.create(prefix):base32(), 1, hash_len)
        prefixes[prefix] = {
          bucket = bucket,
          name = k,
          hash = hash
        }
        nprefixes = nprefixes + 1
      end
    end
  end

  -- Builds the redis reply handler for a single bucket
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
      elseif type(data) == 'table' and data[1] and data[1] == 1 then
        -- set symbol only and do NOT soft reject
        if settings.symbol then
          task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
          rspamd_logger.infox(task,
              'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
              lim_name, prefix,
              bucket[2], bucket[1],
              data[2], data[3], data[4])
          return
        -- set INFO symbol and soft reject
        elseif settings.info_symbol then
          task:insert_result(settings.info_symbol, 1.0,
              lim_name .. "(" .. prefix .. ")")
        end
        rspamd_logger.infox(task,
            'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket[2], bucket[1],
            data[2], data[3], data[4])
        task:set_pre_result('soft reject',
            message_func(task, lim_name, prefix, bucket))
      end
    end
  end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds
    -- Now call check script for all defined prefixes

    for pr, value in pairs(prefixes) do
      local bucket = value.bucket
      -- NOTE(review): for limits parsed from strings bucket[1] is the time
      -- part in seconds, so the numeric part of the limit only affects the
      -- burst, not this leak rate - confirm this is intended.
      local rate = (1.0 / bucket[1]) / 1000.0 -- Leak rate in messages/ms
      rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket[2], bucket[1])
      lua_redis.exec_redis_script(bucket_check_id,
          {key = value.hash, task = task, is_write = true},
          gen_check_cb(pr, bucket, value.name),
          {value.hash, tostring(now), tostring(rate), tostring(bucket[2]),
              tostring(settings.expire)})
    end
  end
end

--- Idempotent callback: updates every bucket that was checked for this task,
-- adjusting the dynamic multipliers according to the final action.
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if prefixes then
    local action = task:get_metric_action()
    local is_spam = true

    if action == 'soft reject' then
      -- Already rate limited/greylisted, do nothing
      rspamd_logger.debugm(N, task, 'already soft rejected, do not update')
      return
    elseif action == 'no action' then
      is_spam = false
    end

    local mult_burst = settings.ham_factor_burst
    local mult_rate = settings.ham_factor_rate

    if is_spam then
      mult_burst = settings.spam_factor_burst
      mult_rate = settings.spam_factor_rate
    end

    -- One timestamp is shared by all bucket updates of this task
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds

    -- Update each bucket
    for k, v in pairs(prefixes) do
      local bucket = v.bucket
      local function update_bucket_cb(err, data)
        if err then
          rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
              k, err)
        else
          rspamd_logger.debugm(N, task,
              "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
              v.name, k, v.hash,
              bucket[2], bucket[1],
              data[1], data[2], data[3])
        end
      end

      lua_redis.exec_redis_script(bucket_update_id,
          {key = v.hash, task = task, is_write = true},
          update_bucket_cb,
          {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
              tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
              tostring(settings.expire)})
    end
  end
end

local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      if type(lim) == 'table' then
        settings.limits[t] = {}
        if #lim == 2 and tonumber(lim[1]) and tonumber(lim[2]) then
          -- Old style ratelimit: a pair whose first element is the burst and
          -- whose second element is inverted to fill the bucket's first slot
          local old_burst, old_rate = tonumber(lim[1]), tonumber(lim[2])
          rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', t)
          if old_burst > 0 and old_rate > 0 then
            table.insert(settings.limits[t], {1.0 / old_rate, old_burst})
          elseif old_burst ~= 0 then
            rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', t)
          else
            rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', t)
          end
        else
          fun.each(function(l)
            local plim, size = parse_string_limit(l)
            if plim then
              table.insert(settings.limits[t], {plim, size})
            end
          end, lim)
        end
      elseif type(lim) == 'string' then
        local plim, size = parse_string_limit(lim)
        if plim then
          settings.limits[t] = { {plim, size} }
        end
      end
    end, opts['rates'])
  end

  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
      'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Legacy handling of whitelisted_rcpts:
  -- if we have a string with commas then load it as a static map,
  -- otherwise apply the normal logic of Rspamd maps
  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
          lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
          'Ratelimit whitelisted rcpts')
    end
  elseif type(opts['whitelisted_rcpts']) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
  else
    -- Nothing configured: turn the built-in default list into a map
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
        'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
        'Ratelimit whitelist user map')
  end

  -- NOTE(review): custom keywords are loaded here and their `init` functions
  -- are called below, but gen_rate_key only consults the built-in `keywords`
  -- table - confirm whether they are meant to be looked up there as well.
  if opts['custom_keywords'] then
    settings.custom_keywords = dofile(opts['custom_keywords'])
  end

  -- The option holds Lua source; the chunk must return the message function
  if opts['message_func'] then
    message_func = assert(load(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')

  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 4,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    -- The check symbol is registered under the configured symbol name, if any
    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }

    if settings.custom_keywords then
      for _, v in pairs(settings.custom_keywords) do
        if type(v) == 'table' and type(v['init']) == 'function' then
          v['init']()
        end
      end
    end
  end
end

rspamd_config:add_on_load(function(cfg, ev_base, worker)
  load_scripts(cfg, ev_base)
end)