diff options
-rw-r--r-- | lualib/plugins/ratelimit.lua | 155 | ||||
-rw-r--r-- | src/fuzzy_storage.c | 23 | ||||
-rw-r--r-- | src/plugins/lua/ratelimit.lua | 143 |
3 files changed, 183 insertions, 138 deletions
diff --git a/lualib/plugins/ratelimit.lua b/lualib/plugins/ratelimit.lua new file mode 100644 index 000000000..24afed1f8 --- /dev/null +++ b/lualib/plugins/ratelimit.lua @@ -0,0 +1,155 @@ +--[[ +Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local ts = require("tableshape").types + +local exports = {} + +local limit_parser +local function parse_string_limit(lim, no_error) + local function parse_time_suffix(s) + if s == 's' then + return 1 + elseif s == 'm' then + return 60 + elseif s == 'h' then + return 3600 + elseif s == 'd' then + return 86400 + end + end + local function parse_num_suffix(s) + if s == '' then + return 1 + elseif s == 'k' then + return 1000 + elseif s == 'm' then + return 1000000 + elseif s == 'g' then + return 1000000000 + end + end + local lpeg = require "lpeg" + + if not limit_parser then + local digit = lpeg.R("09") + limit_parser = {} + limit_parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + limit_parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + limit_parser.number = (limit_parser.integer * + (limit_parser.fractional ^ -1)) + + (lpeg.S("+-") * limit_parser.fractional) + limit_parser.time = lpeg.Cf(lpeg.Cc(1) * + (limit_parser.number / tonumber) * + ((lpeg.S("smhd") / parse_time_suffix) ^ -1), + function(acc, val) + return acc * val + end) + limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) * + (limit_parser.number / tonumber) * + ((lpeg.S("kmg") / parse_num_suffix) ^ -1), + function(acc, val) + return acc * val + end) + limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number * + (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * + limit_parser.time) + end + local t = lpeg.match(limit_parser.limit, lim) + + if t and t[1] and t[2] and t[2] ~= 0 then + return t[2], t[1] + end + + if not no_error then + rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim) + end + + return nil +end + +local function str_to_rate(str) + local divider, divisor = parse_string_limit(str, false) + + if not divisor then + rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str) + + return nil + end + + return divisor / divider +end + +local bucket_schema = ts.shape { + burst = ts.number + ts.string / lua_util.dehumanize_number, + rate = ts.number + ts.string / str_to_rate, + skip_recipients = ts.boolean:is_optional(), + symbol = ts.string:is_optional(), + message = ts.string:is_optional(), + skip_soft_reject = ts.boolean:is_optional(), +} + +exports.parse_limit = function(name, data) + if type(data) == 'table' then + -- 2 cases here: + -- * old limit in format [burst, rate] + -- * vector of strings in Andrew's string format (removed from 1.8.2) + -- * proper bucket table + if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then + -- Old style ratelimit + rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name) + if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then + return { + burst = data[1], + rate = data[2] + } + elseif data[1] ~= 0 then + rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name) + else + rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name) + end + + return nil + else + local parsed_bucket, err = bucket_schema:transform(data) + + if not parsed_bucket or err then + rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s', + name, err, data) + else + return parsed_bucket + end + end + elseif type(data) == 'string' then + local rep_rate, burst = parse_string_limit(data) + rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s', + name, data) + if rep_rate and burst then + return { + burst = burst, + rate = burst / rep_rate -- reciprocal + } + end + end + + return nil +end + +return exports
\ No newline at end of file diff --git a/src/fuzzy_storage.c b/src/fuzzy_storage.c index 2f889c759..257d22bcd 100644 --- a/src/fuzzy_storage.c +++ b/src/fuzzy_storage.c @@ -698,6 +698,11 @@ fuzzy_key_dtor(gpointer p) kh_destroy(fuzzy_key_ids_set, key->forbidden_ids); } + if (key->rl_bucket) { + /* TODO: save bucket stats */ + g_free(key->rl_bucket); + } + g_free(key); } } @@ -2827,6 +2832,10 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha rspamd_inet_address_hash, rspamd_inet_address_equal); key->stat = keystat; key->flags_stat = kh_init(fuzzy_key_flag_stat); + key->burst = NAN; + key->rate = NAN; + key->expire = NAN; + key->rl_bucket = NULL; /* Preallocate some space for flags */ kh_resize(fuzzy_key_flag_stat, key->flags_stat, 8); const unsigned char *pk = rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_PK, @@ -2874,6 +2883,20 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha } } } + + /* + * TODO: parse ratelimit using Lua code from `ratelimit` plugin to + * have unified form of settings + */ + const ucl_object_t *ratelimit = ucl_object_lookup(extensions, "ratelimit"); + + if (ratelimit && ucl_object_type(ratelimit) == UCL_STRING) { + } + + const ucl_object_t *expire = ucl_object_lookup(extensions, "expire"); + if (expire && ucl_object_type(expire) == UCL_STRING) { + struct tm tm; + } } msg_debug("loaded keypair %*bs", crypto_box_publickeybytes(), pk); diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index f3331e850..168d8d63a 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -29,8 +29,7 @@ local lua_util = require "lua_util" local lua_verdict = require "lua_verdict" local rspamd_hash = require "rspamd_cryptobox_hash" local lua_selectors = require "lua_selectors" -local ts = require("tableshape").types - +local ratelimit_common = require "plugins/ratelimit" -- A plugin that implements ratelimits using redis local E = {} @@ -76,138 +75,6 @@ local function load_scripts(_, _) bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params) end -local limit_parser -local function parse_string_limit(lim, no_error) - local function parse_time_suffix(s) - if s == 's' then - return 1 - elseif s == 'm' then - return 60 - elseif s == 'h' then - return 3600 - elseif s == 'd' then - return 86400 - end - end - local function parse_num_suffix(s) - if s == '' then - return 1 - elseif s == 'k' then - return 1000 - elseif s == 'm' then - return 1000000 - elseif s == 'g' then - return 1000000000 - end - end - local lpeg = require "lpeg" - - if not limit_parser then - local digit = lpeg.R("09") - limit_parser = {} - limit_parser.integer = (lpeg.S("+-") ^ -1) * - (digit ^ 1) - limit_parser.fractional = (lpeg.P(".")) * - (digit ^ 1) - limit_parser.number = (limit_parser.integer * - (limit_parser.fractional ^ -1)) + - (lpeg.S("+-") * limit_parser.fractional) - limit_parser.time = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("smhd") / parse_time_suffix) ^ -1), - function(acc, val) - return acc * val - end) - limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) * - (limit_parser.number / tonumber) * - ((lpeg.S("kmg") / parse_num_suffix) ^ -1), - function(acc, val) - return acc * val - end) - limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number * - (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * - limit_parser.time) - end - local t = lpeg.match(limit_parser.limit, lim) - - if t and t[1] and t[2] and t[2] ~= 0 then - return t[2], t[1] - end - - if not no_error then - rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim) - end - - return nil -end - -local function str_to_rate(str) - local divider, divisor = parse_string_limit(str, false) - - if not divisor then - rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str) - - return nil - end - - return divisor / divider -end - -local bucket_schema = ts.shape { - burst = ts.number + ts.string / lua_util.dehumanize_number, - rate = ts.number + ts.string / str_to_rate, - skip_recipients = ts.boolean:is_optional(), - symbol = ts.string:is_optional(), - message = ts.string:is_optional(), - skip_soft_reject = ts.boolean:is_optional(), -} - -local function parse_limit(name, data) - if type(data) == 'table' then - -- 2 cases here: - -- * old limit in format [burst, rate] - -- * vector of strings in Andrew's string format (removed from 1.8.2) - -- * proper bucket table - if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then - -- Old style ratelimit - rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name) - if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then - return { - burst = data[1], - rate = data[2] - } - elseif data[1] ~= 0 then - rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name) - else - rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name) - end - - return nil - else - local parsed_bucket, err = bucket_schema:transform(data) - - if not parsed_bucket or err then - rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s', - name, err, data) - else - return parsed_bucket - end - end - elseif type(data) == 'string' then - local rep_rate, burst = parse_string_limit(data) - rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s', - name, data) - if rep_rate and burst then - return { - burst = burst, - rate = burst / rep_rate -- reciprocal - } - end - end - - return nil -end - --- Check whether this addr is bounce local function check_bounce(from) return fun.any(function(b) @@ -490,7 +357,7 @@ local function ratelimit_cb(task) local ret, redis_key, bd = pcall(hdl, task) if ret then - local bucket = parse_limit(k, bd) + local bucket = ratelimit_common.parse_limit(k, bd) if bucket then prefixes[redis_key] = make_prefix(redis_key, k, bucket) end @@ -718,7 +585,7 @@ if opts then if lim.bucket[1] then for _, bucket in ipairs(lim.bucket) do - local b = parse_limit(t, bucket) + local b = ratelimit_common.parse_limit(t, bucket) if not b then rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', @@ -729,7 +596,7 @@ if opts then table.insert(buckets, b) end else - local bucket = parse_limit(t, lim.bucket) + local bucket = ratelimit_common.parse_limit(t, lim.bucket) if not bucket then rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', @@ -757,7 +624,7 @@ if opts then end else rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim) - buckets = parse_limit(t, lim) + buckets = ratelimit_common.parse_limit(t, lim) if buckets then settings.limits[t] = { buckets = { buckets } |