--[[
Copyright (c) 2011-2017, Vsevolod Stakhov
Copyright (c) 2016-2017, Andrew Lewis

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- A plugin that implements ratelimits using redis

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local lua_util = require "lua_util"
-- Historical alias: require() caches modules in package.loaded, so this is
-- the very same table as lua_util
local rspamd_lua_utils = lua_util
local lua_redis = require "lua_redis"
local fun = require "fun"
local lua_maps = require "lua_maps"
local rspamd_hash = require "rspamd_cryptobox_hash"
local lua_selectors = require "lua_selectors"

local E = {}
local N = 'ratelimit'
local redis_params

-- Module defaults; overridden from the `ratelimit` config section
local settings = {
  -- Prefix of every bucket key stored in Redis
  prefix = 'RL',
  -- Limit definitions, filled while parsing the `rates` option
  limits = {},
  -- When false, messages coming from rspamc/controller are not checked
  allow_local = false,
  -- Bucket lifetime in seconds (2 days by default)
  expire = 2 * 24 * 60 * 60,

  -- Senders that are considered as bounce
  bounce_senders = {'postmaster', 'mailer-daemon', '', 'null',
                    'fetchmail-daemon', 'mdaemon'},
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = {'postmaster', 'mailer-daemon'},

  -- Per-message multipliers of the dynamic rate/burst: values above 1.0
  -- (ham) raise them, values below 1.0 (spam) lower them
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  -- Bounds x of the dynamic multipliers: they are kept within [1/x, x]
  max_rate_mult = 5,
  max_bucket_mult = 10,
}
-- Checks bucket, updating it if needed
-- KEYS[1] - bucket key: settings.prefix followed by a hash of the limit key
--           (built by make_prefix)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- Returns {limited, burst, dyn_rate, dyn_burst, leaked}: `limited` is 1 if
-- the message should be ratelimited and 0 if not; the remaining elements are
-- strings used for logging only.
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  local dynr, dynb, leaked = 0, 0, 0

  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, '0', '1', '1', '0'}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))

  -- Perform leak
  if burst > 0 then
    if last < now then
      local rate = tonumber(KEYS[3])
      dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
      rate = rate * dynr
      leaked = ((now - last) * rate)
      -- A bucket cannot leak more than it holds; without this clamp a long
      -- idle period would drive 'b' below zero
      if leaked > burst then
        leaked = burst
      end
      burst = burst - leaked
      redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
      redis.call('HSET', KEYS[1], 'l', KEYS[2])
    end
  else
    burst = 0
    redis.call('HSET', KEYS[1], 'b', '0')
  end

  dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

  if (burst + 1) > tonumber(KEYS[4]) * dynb then
    return {1, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
]]
local bucket_check_id
-- Updates a bucket after a message has been processed (see ratelimit_update_cb)
-- KEYS[1] - bucket key: settings.prefix followed by a hash of the limit key
--           (built by make_prefix)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier to apply
-- KEYS[4] - dynamic burst multiplier to apply
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- Returns {burst, dyn_rate, dyn_burst} for logging.
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')

  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  -- Moves a stored multiplier by `mult`, but only towards the inside of
  -- [1/limit, limit]: a multiplier sitting at one bound can still be moved
  -- back by a factor of the opposite direction instead of freezing forever
  local function adjust(field, mult, limit)
    local cur = tonumber(redis.call('HGET', KEYS[1], field)) / 10000
    if (mult > 1.0 and cur < limit) or (mult < 1.0 and cur > 1.0 / limit) then
      cur = cur * mult
      redis.call('HSET', KEYS[1], field, tostring(math.floor(cur * 10000)))
    end
    return cur
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local dr = adjust('dr', tonumber(KEYS[3]), tonumber(KEYS[5]))
  local db = adjust('db', tonumber(KEYS[4]), tonumber(KEYS[6]))

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
local bucket_update_id

-- message_func(task, limit_type, prefix, bucket)
-- Builds the soft reject message; replaced when the `message_func` option is set
local message_func = function(_, limit_type, _, _)
  return string.format('Ratelimit "%s" exceeded', limit_type)
end

-- Registers both bucket scripts with lua_redis (called from the on_load hook).
-- cfg and ev_base are part of the hook signature and are not used here.
local function load_scripts(cfg, ev_base)
  bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
  bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
end

local limit_parser

-- Parses a limit written as "<count>[kmg] / <period>[smhd]", e.g. "100 / 1m".
-- Returns period_in_seconds, count on success. Otherwise logs an error
-- (unless no_error is set) and returns nil.
local function parse_string_limit(lim, no_error)
  local function parse_time_suffix(s)
    if s == 's' then
      return 1
    elseif s == 'm' then
      return 60
    elseif s == 'h' then
      return 3600
    elseif s == 'd' then
      return 86400
    end
  end
  local function parse_num_suffix(s)
    if s == '' then
      return 1
    elseif s == 'k' then
      return 1000
    elseif s == 'm' then
      return 1000000
    elseif s == 'g' then
      return 1000000000
    end
  end
  local lpeg = require "lpeg"

  -- The grammar is built once and cached in the file-level limit_parser
  if not limit_parser then
    local digit = lpeg.R("09")
    limit_parser = {}
    limit_parser.integer = (lpeg.S("+-") ^ -1) * (digit ^ 1)
    limit_parser.fractional = (lpeg.P(".")) * (digit ^ 1)
    limit_parser.number = (limit_parser.integer * (limit_parser.fractional ^ -1)) +
        (lpeg.S("+-") * limit_parser.fractional)
    -- Number optionally followed by a time suffix, folded into seconds
    limit_parser.time = lpeg.Cf(
        lpeg.Cc(1) * (limit_parser.number / tonumber) *
            ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
        function(acc, val) return acc * val end)
    -- Number optionally followed by a k/m/g multiplier
    limit_parser.suffixed_number = lpeg.Cf(
        lpeg.Cc(1) * (limit_parser.number / tonumber) *
            ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
        function(acc, val) return acc * val end)
    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
        (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
        limit_parser.time)
  end

  local t = lpeg.match(limit_parser.limit, lim)

  -- A zero or negative period would yield an infinite or negative leak rate
  if t and t[1] and t[2] and t[2] > 0 then
    return t[2], t[1]
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end

-- Converts a limit definition from the config into a list of buckets
-- ({burst = number, rate = number, ...}); `rate` is in messages per second
-- (ratelimit_cb divides it by 1000 to obtain messages per millisecond).
-- Accepted forms:
--   * old limit in format [burst, rate]
--   * proper bucket table with numeric `burst` and `rate` fields
--   * vector of any supported form (parsed recursively and flattened)
--   * string in Andrew's "<count> / <period>" format
-- Invalid entries are dropped, so the result may be an empty table.
local function parse_limit(name, data)
  local buckets = {}

  if type(data) == 'table' then
    local old_burst, old_rate = tonumber(data[1]), tonumber(data[2])

    if #data == 2 and old_burst and old_rate then
      -- Old style ratelimit
      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
      if old_burst > 0 and old_rate > 0 then
        -- Store numbers so that the final filter does not drop numeric strings
        table.insert(buckets, {
          burst = old_burst,
          rate = old_rate
        })
      elseif old_burst ~= 0 then
        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
      else
        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
      end
    elseif data.burst ~= nil and data.rate ~= nil then
      -- Proper bucket table: copy it so that extra per-bucket fields (such as
      -- the ham/spam factors read by make_prefix) are preserved
      local bucket = {}
      for k, v in pairs(data) do
        bucket[k] = v
      end
      bucket.burst = tonumber(data.burst)
      bucket.rate = tonumber(data.rate)
      table.insert(buckets, bucket)
    else
      -- Recursively parse every element and flatten the resulting lists
      for _, d in ipairs(data) do
        for _, b in ipairs(parse_limit(name, d)) do
          table.insert(buckets, b)
        end
      end
    end
  elseif type(data) == 'string' then
    local rep_rate, burst = parse_string_limit(data)
    if rep_rate and burst then
      table.insert(buckets, {
        burst = burst,
        -- NOTE(review): this leaks one message per period regardless of the
        -- count ("100 / 1m" -> burst 100, refilled at 1 msg/min); confirm
        -- this is the intended meaning
        rate = 1.0 / rep_rate -- reciprocal
      })
    end
  end

  -- Filter valid
  return fun.totable(fun.filter(function(val)
    return type(val.burst) == 'number' and type(val.rate) == 'number'
  end, buckets))
end

--- Check whether this addr is bounce
local function check_bounce(from)
  return fun.any(function(b) return b == from end, settings.bounce_senders)
end

-- Key components for limits without a selector: the limit name is split on
-- '_' and each part is looked up here; get_value returns nil when the part is
-- not applicable to the task
local keywords = {
  ['ip'] = {
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then return tostring(ip) end
      return nil
    end,
  },
  ['rip'] = {
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end
      return nil
    end,
  },
  ['from'] = {
    ['get_value'] = function(task)
      local from = task:get_from(0)
      if ((from or E)[1] or E).addr then
        return string.lower(from[1]['addr'])
      end
      return nil
    end,
  },
  ['bounce'] = {
    ['get_value'] = function(task)
      local from = task:get_from(0)
      if not ((from or E)[1] or E).user then
        return '_'
      end
      if check_bounce(from[1]['user']) then return '_' else return nil end
    end,
  },
  ['asn'] = {
    ['get_value'] = function(task)
      local asn = task:get_mempool():get_variable('asn')
      if not asn then
        return nil
      else
        return asn
      end
    end,
  },
  ['user'] = {
    ['get_value'] = function(task)
      local auser = task:get_user()
      if not auser then
        return nil
      else
        return auser
      end
    end,
  },
  ['to'] = {
    ['get_value'] = function(task)
      return task:get_principal_recipient()
    end,
  },
  ['digest'] = {
    ['get_value'] = function(task)
      return task:get_digest()
    end,
  },
  ['attachments'] = {
    ['get_value'] = function(task)
      local parts = task:get_parts() or E
      local digests = {}
      for _, p in ipairs(parts) do
        if p:get_filename() then
          table.insert(digests, p:get_digest())
        end
      end
      if #digests > 0 then
        return table.concat(digests, '')
      end
      return nil
    end,
  },
  ['files'] = {
    ['get_value'] = function(task)
      local parts = task:get_parts() or E
      local files = {}
      for _, p in ipairs(parts) do
        local fname = p:get_filename()
        if fname then
          table.insert(files, fname)
        end
      end
      if #files > 0 then
        return table.concat(files, ':')
      end
      return nil
    end,
  },
}

-- Builds the key for a keyword-based limit named `rtype` (e.g. "to_ip").
-- Returns nil when any keyword has no value for this task.
local function gen_rate_key(task, rtype, bucket)
  -- The first component distinguishes buckets of different sizes
  local key_t = {tostring(lua_util.round(100000.0 / bucket.burst))}
  local key_keywords = lua_util.str_split(rtype, '_')
  local have_user = false

  for _, v in ipairs(key_keywords) do
    local ret

    if keywords[v] and type(keywords[v]['get_value']) == 'function' then
      ret = keywords[v]['get_value'](task)
    end
    if not ret then return nil end
    if v == 'user' then have_user = true end
    if type(ret) ~= 'string' then ret = tostring(ret) end
    table.insert(key_t, ret)
  end

  if have_user and not task:get_user() then
    return nil
  end

  return table.concat(key_t, ":")
end

-- Wraps a bucket for the Redis scripts. `hash` is the Redis key: the
-- configured prefix plus up to 24 base32 characters of a hash of redis_key.
-- Missing ham/spam factors are filled in on `bucket` itself from settings.
local function make_prefix(redis_key, name, bucket)
  local hash_len = 24
  if hash_len > #redis_key then hash_len = #redis_key end
  local hash = settings.prefix ..
      string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len)
  -- Fill defaults
  if not bucket.spam_factor_rate then
    bucket.spam_factor_rate = settings.spam_factor_rate
  end
  if not bucket.ham_factor_rate then
    bucket.ham_factor_rate = settings.ham_factor_rate
  end
  if not bucket.spam_factor_burst then
    bucket.spam_factor_burst = settings.spam_factor_burst
  end
  if not bucket.ham_factor_burst then
    bucket.ham_factor_burst = settings.ham_factor_burst
  end

  return {
    bucket = bucket,
    name = name,
    hash = hash
  }
end

-- Stores make_prefix() results for a single key or a list of keys into
-- `prefixes`; returns the number of entries written
local function add_bucket_keys(prefixes, keys, name, bucket)
  if type(keys) == 'string' then
    prefixes[keys] = make_prefix(keys, name, bucket)
    return 1
  end

  local n = 0
  fun.each(function(p)
    prefixes[p] = make_prefix(p, name, bucket)
    n = n + 1
  end, keys)

  return n
end

-- Computes the keys of limit `k` (definition `v`) for this task and adds them
-- to `prefixes`; returns the number of entries written.
-- NOTE(review): with a selector every bucket of a limit presumably yields the
-- same key, so only the last bucket survives in `prefixes` — confirm
local function limit_to_prefixes(task, k, v, prefixes)
  local n = 0

  for _, bucket in ipairs(v.buckets) do
    if v.selector then
      local selectors = lua_selectors.process_selectors(task, v.selector)
      if selectors then
        local combined = lua_selectors.combine_selectors(task, selectors, ':')
        n = n + add_bucket_keys(prefixes, combined, k, bucket)
      end
    else
      local prefix = gen_rate_key(task, k, bucket)
      if prefix then
        n = n + add_bucket_keys(prefixes, prefix, k, bucket)
      end
    end
  end

  return n
end

-- Returns true (and logs why) when the task matches the IP, recipient or
-- user whitelist and must not be ratelimited
local function is_whitelisted(task)
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return true
    end
  end

  -- Both the user part and the full address of every recipient are checked
  local rcpts = task:get_recipients()
  if rcpts and settings.whitelisted_rcpts then
    local rcpts_user = {}
    for _, r in ipairs(rcpts) do
      for _, field in ipairs({'user', 'addr'}) do
        if r[field] then
          rcpts_user[#rcpts_user + 1] = r[field]
        end
      end
    end

    for _, r in ipairs(rcpts_user) do
      if settings.whitelisted_rcpts:get_key(r) then
        rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
        return true
      end
    end
  end

  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return true
    end
  end

  return false
end

-- Prefilter: checks every applicable bucket in Redis and soft rejects (or
-- only inserts settings.symbol) when a bucket is full. The computed prefixes
-- are cached on the task for ratelimit_update_cb.
local function ratelimit_cb(task)
  if not settings.allow_local and
      rspamd_lua_utils.is_rspamc_or_controller(task) then
    return
  end

  if is_whitelisted(task) then return end

  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k, v in pairs(settings.limits) do
    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  end

  for k, hdl in pairs(settings.custom_keywords or E) do
    local ret, redis_key, bd = pcall(hdl, task)

    if ret then
      -- Skip handlers that returned no key: indexing prefixes with nil raises
      if redis_key then
        local bucket = parse_limit(k, bd)
        if bucket[1] then
          prefixes[redis_key] = make_prefix(redis_key, k, bucket[1])
          nprefixes = nprefixes + 1
        end
      end
    else
      rspamd_logger.errx(task, 'cannot call handler for %s: %s',
          k, redis_key)
    end
  end

  -- Builds the Redis reply handler for one bucket.
  -- Reply: {limited, burst, dyn_rate, dyn_burst, leaked}
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
        return
      end

      if type(data) ~= 'table' or not data[1] then return end

      lua_util.debugm(N, task,
          "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked",
          prefix, bucket.burst, bucket.rate,
          data[2], data[3], data[4], data[5])

      if data[1] ~= 1 then return end

      if settings.symbol then
        -- set symbol only and do NOT soft reject
        task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
        rspamd_logger.infox(task,
            'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4])
        return
      elseif settings.info_symbol then
        -- set INFO symbol and soft reject
        task:insert_result(settings.info_symbol, 1.0, lim_name .. "(" .. prefix .. ")")
      end

      rspamd_logger.infox(task,
          'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
          lim_name, prefix,
          bucket.burst, bucket.rate,
          data[2], data[3], data[4])
      task:set_pre_result('soft reject',
          message_func(task, lim_name, prefix, bucket))
    end
  end

  -- Don't do anything if pre-result has been already set
  if task:has_pre_result() then return end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = lua_util.round(rspamd_util.get_time() * 1000.0) -- Get milliseconds

    -- Now call check script for all defined prefixes
    for pr, value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms

      lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
          {key = value.hash, task = task, is_write = true},
          gen_check_cb(pr, bucket, value.name),
          {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
           tostring(settings.expire)})
    end
  end
end

-- Idempotent callback: adds the processed message to every cached bucket and
-- moves the dynamic multipliers using the ham or spam factors
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if not prefixes then return end

  if task:has_pre_result() then
    -- Already rate limited/greylisted, do nothing
    lua_util.debugm(N, task, 'pre-action has been set, do not update')
    return
  end

  local is_spam = not (task:get_metric_action() == 'no action')
  -- A single timestamp (milliseconds) for all buckets of this message
  local now = lua_util.round(rspamd_util.get_time() * 1000.0)

  -- Update each bucket
  for k, v in pairs(prefixes) do
    local bucket = v.bucket

    local function update_bucket_cb(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
            k, err)
      else
        data = data or E
        lua_util.debugm(N, task,
            "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
            v.name, k, v.hash,
            bucket.burst, bucket.rate,
            data[1], data[2], data[3])
      end
    end

    local mult_burst, mult_rate
    if is_spam then
      mult_burst = bucket.spam_factor_burst or 1.0
      mult_rate = bucket.spam_factor_rate or 1.0
    else
      mult_burst = bucket.ham_factor_burst or 1.0
      -- The rate multiplier must come from the rate factor, not the burst one
      mult_rate = bucket.ham_factor_rate or 1.0
    end

    lua_redis.exec_redis_script(bucket_update_id,
        {key = v.hash, task = task, is_write = true},
        update_bucket_cb,
        {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
         tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
         tostring(settings.expire)})
  end
end

local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      if type(lim) == 'table' and lim.selector and lim.bucket then
        local selector = lua_selectors.parse_selector(rspamd_config, lim.selector)

        if not selector then
          rspamd_logger.errx(rspamd_config, 'bad ratelimit selector for %s: "%s"',
              t, lim.selector)
          return
        end

        -- parse_limit always returns a table, so an invalid bucket shows up
        -- as an empty list
        local buckets = parse_limit(t, lim.bucket)

        if #buckets == 0 then
          rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
              t, lim.bucket)
          return
        end

        settings.limits[t] = {
          selector = selector,
          buckets = buckets
        }
      else
        local buckets = parse_limit(t, lim)

        if #buckets > 0 then
          settings.limits[t] = {
            buckets = buckets
          }
        end
      end
    end, opts['rates'])
  end

  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
      'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Ret, ret, ret: stupid legacy stuff:
  -- If we have a string with commas then load it as a static map,
  -- otherwise apply the normal logic of Rspamd maps
  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
          lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
          'Ratelimit whitelisted rcpts')
    end
  elseif type(wrcpts) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
  else
    -- Stupid default...
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
        'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
        'Ratelimit whitelist user map')
  end

  -- Custom keywords come from a Lua file that returns either a table of
  -- handlers or a single handler (registered under the name 'custom')
  settings.custom_keywords = {}
  if opts['custom_keywords'] then
    local path = opts['custom_keywords']
    local chunk, load_err = loadfile(path)

    if not chunk then
      -- Report the loadfile error instead of "attempt to call a nil value"
      rspamd_logger.errx(rspamd_config, 'cannot load %s: %s', path, load_err)
    else
      local ret, res_or_err = pcall(chunk)

      if ret then
        if type(res_or_err) == 'table' then
          for k, hdl in pairs(res_or_err) do
            settings.custom_keywords[k] = hdl
          end
        elseif type(res_or_err) == 'function' then
          settings.custom_keywords['custom'] = res_or_err
        end
      else
        rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s', path, res_or_err)
      end
    end
  end

  if opts['message_func'] then
    message_func = assert(load(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')
  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 7,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }
  end
end

-- Scripts are only needed (and redis_params only exists) when Redis is configured
if redis_params then
  rspamd_config:add_on_load(function(cfg, ev_base)
    load_scripts(cfg, ev_base)
  end)
end