aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lualib/plugins/ratelimit.lua155
-rw-r--r--src/fuzzy_storage.c23
-rw-r--r--src/plugins/lua/ratelimit.lua143
3 files changed, 183 insertions, 138 deletions
diff --git a/lualib/plugins/ratelimit.lua b/lualib/plugins/ratelimit.lua
new file mode 100644
index 000000000..24afed1f8
--- /dev/null
+++ b/lualib/plugins/ratelimit.lua
@@ -0,0 +1,155 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local ts = require("tableshape").types
+
+local exports = {}
+
+local limit_parser
+local function parse_string_limit(lim, no_error)
+ local function parse_time_suffix(s)
+ if s == 's' then
+ return 1
+ elseif s == 'm' then
+ return 60
+ elseif s == 'h' then
+ return 3600
+ elseif s == 'd' then
+ return 86400
+ end
+ end
+ local function parse_num_suffix(s)
+ if s == '' then
+ return 1
+ elseif s == 'k' then
+ return 1000
+ elseif s == 'm' then
+ return 1000000
+ elseif s == 'g' then
+ return 1000000000
+ end
+ end
+ local lpeg = require "lpeg"
+
+ if not limit_parser then
+ local digit = lpeg.R("09")
+ limit_parser = {}
+ limit_parser.integer = (lpeg.S("+-") ^ -1) *
+ (digit ^ 1)
+ limit_parser.fractional = (lpeg.P(".")) *
+ (digit ^ 1)
+ limit_parser.number = (limit_parser.integer *
+ (limit_parser.fractional ^ -1)) +
+ (lpeg.S("+-") * limit_parser.fractional)
+ limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
+ (limit_parser.number / tonumber) *
+ ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
+ limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
+ (limit_parser.number / tonumber) *
+ ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
+ limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
+ (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
+ limit_parser.time)
+ end
+ local t = lpeg.match(limit_parser.limit, lim)
+
+ if t and t[1] and t[2] and t[2] ~= 0 then
+ return t[2], t[1]
+ end
+
+ if not no_error then
+ rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
+ end
+
+ return nil
+end
+
+local function str_to_rate(str)
+ local divider, divisor = parse_string_limit(str, false)
+
+ if not divisor then
+ rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
+
+ return nil
+ end
+
+ return divisor / divider
+end
+
+local bucket_schema = ts.shape {
+ burst = ts.number + ts.string / lua_util.dehumanize_number,
+ rate = ts.number + ts.string / str_to_rate,
+ skip_recipients = ts.boolean:is_optional(),
+ symbol = ts.string:is_optional(),
+ message = ts.string:is_optional(),
+ skip_soft_reject = ts.boolean:is_optional(),
+}
+
+exports.parse_limit = function(name, data)
+ if type(data) == 'table' then
+ -- 2 cases here:
+ -- * old limit in format [burst, rate]
+ -- * vector of strings in Andrew's string format (removed from 1.8.2)
+ -- * proper bucket table
+ if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
+ -- Old style ratelimit
+ rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
+ if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
+ return {
+ burst = data[1],
+ rate = data[2]
+ }
+ elseif data[1] ~= 0 then
+ rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
+ else
+ rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
+ end
+
+ return nil
+ else
+ local parsed_bucket, err = bucket_schema:transform(data)
+
+ if not parsed_bucket or err then
+ rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
+ name, err, data)
+ else
+ return parsed_bucket
+ end
+ end
+ elseif type(data) == 'string' then
+ local rep_rate, burst = parse_string_limit(data)
+ rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
+ name, data)
+ if rep_rate and burst then
+ return {
+ burst = burst,
+ rate = burst / rep_rate -- reciprocal
+ }
+ end
+ end
+
+ return nil
+end
+
+return exports \ No newline at end of file
diff --git a/src/fuzzy_storage.c b/src/fuzzy_storage.c
index 2f889c759..257d22bcd 100644
--- a/src/fuzzy_storage.c
+++ b/src/fuzzy_storage.c
@@ -698,6 +698,11 @@ fuzzy_key_dtor(gpointer p)
kh_destroy(fuzzy_key_ids_set, key->forbidden_ids);
}
+ if (key->rl_bucket) {
+ /* TODO: save bucket stats */
+ g_free(key->rl_bucket);
+ }
+
g_free(key);
}
}
@@ -2827,6 +2832,10 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha
rspamd_inet_address_hash, rspamd_inet_address_equal);
key->stat = keystat;
key->flags_stat = kh_init(fuzzy_key_flag_stat);
+ key->burst = NAN;
+ key->rate = NAN;
+ key->expire = NAN;
+ key->rl_bucket = NULL;
/* Preallocate some space for flags */
kh_resize(fuzzy_key_flag_stat, key->flags_stat, 8);
const unsigned char *pk = rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_PK,
@@ -2874,6 +2883,20 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha
}
}
}
+
+ /*
+ * TODO: parse ratelimit using Lua code from `ratelimit` plugin to
+ * have unified form of settings
+ */
+ const ucl_object_t *ratelimit = ucl_object_lookup(extensions, "ratelimit");
+
+ if (ratelimit && ucl_object_type(ratelimit) == UCL_STRING) {
+ }
+
+ const ucl_object_t *expire = ucl_object_lookup(extensions, "expire");
+ if (expire && ucl_object_type(expire) == UCL_STRING) {
+ struct tm tm;
+ }
}
msg_debug("loaded keypair %*bs", crypto_box_publickeybytes(), pk);
diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua
index f3331e850..168d8d63a 100644
--- a/src/plugins/lua/ratelimit.lua
+++ b/src/plugins/lua/ratelimit.lua
@@ -29,8 +29,7 @@ local lua_util = require "lua_util"
local lua_verdict = require "lua_verdict"
local rspamd_hash = require "rspamd_cryptobox_hash"
local lua_selectors = require "lua_selectors"
-local ts = require("tableshape").types
-
+local ratelimit_common = require "plugins/ratelimit"
-- A plugin that implements ratelimits using redis
local E = {}
@@ -76,138 +75,6 @@ local function load_scripts(_, _)
bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params)
end
-local limit_parser
-local function parse_string_limit(lim, no_error)
- local function parse_time_suffix(s)
- if s == 's' then
- return 1
- elseif s == 'm' then
- return 60
- elseif s == 'h' then
- return 3600
- elseif s == 'd' then
- return 86400
- end
- end
- local function parse_num_suffix(s)
- if s == '' then
- return 1
- elseif s == 'k' then
- return 1000
- elseif s == 'm' then
- return 1000000
- elseif s == 'g' then
- return 1000000000
- end
- end
- local lpeg = require "lpeg"
-
- if not limit_parser then
- local digit = lpeg.R("09")
- limit_parser = {}
- limit_parser.integer = (lpeg.S("+-") ^ -1) *
- (digit ^ 1)
- limit_parser.fractional = (lpeg.P(".")) *
- (digit ^ 1)
- limit_parser.number = (limit_parser.integer *
- (limit_parser.fractional ^ -1)) +
- (lpeg.S("+-") * limit_parser.fractional)
- limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
- (limit_parser.number / tonumber) *
- ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
- function(acc, val)
- return acc * val
- end)
- limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
- (limit_parser.number / tonumber) *
- ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
- function(acc, val)
- return acc * val
- end)
- limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
- (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
- limit_parser.time)
- end
- local t = lpeg.match(limit_parser.limit, lim)
-
- if t and t[1] and t[2] and t[2] ~= 0 then
- return t[2], t[1]
- end
-
- if not no_error then
- rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
- end
-
- return nil
-end
-
-local function str_to_rate(str)
- local divider, divisor = parse_string_limit(str, false)
-
- if not divisor then
- rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
-
- return nil
- end
-
- return divisor / divider
-end
-
-local bucket_schema = ts.shape {
- burst = ts.number + ts.string / lua_util.dehumanize_number,
- rate = ts.number + ts.string / str_to_rate,
- skip_recipients = ts.boolean:is_optional(),
- symbol = ts.string:is_optional(),
- message = ts.string:is_optional(),
- skip_soft_reject = ts.boolean:is_optional(),
-}
-
-local function parse_limit(name, data)
- if type(data) == 'table' then
- -- 2 cases here:
- -- * old limit in format [burst, rate]
- -- * vector of strings in Andrew's string format (removed from 1.8.2)
- -- * proper bucket table
- if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
- -- Old style ratelimit
- rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
- if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
- return {
- burst = data[1],
- rate = data[2]
- }
- elseif data[1] ~= 0 then
- rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
- else
- rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
- end
-
- return nil
- else
- local parsed_bucket, err = bucket_schema:transform(data)
-
- if not parsed_bucket or err then
- rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
- name, err, data)
- else
- return parsed_bucket
- end
- end
- elseif type(data) == 'string' then
- local rep_rate, burst = parse_string_limit(data)
- rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
- name, data)
- if rep_rate and burst then
- return {
- burst = burst,
- rate = burst / rep_rate -- reciprocal
- }
- end
- end
-
- return nil
-end
-
--- Check whether this addr is bounce
local function check_bounce(from)
return fun.any(function(b)
@@ -490,7 +357,7 @@ local function ratelimit_cb(task)
local ret, redis_key, bd = pcall(hdl, task)
if ret then
- local bucket = parse_limit(k, bd)
+ local bucket = ratelimit_common.parse_limit(k, bd)
if bucket then
prefixes[redis_key] = make_prefix(redis_key, k, bucket)
end
@@ -718,7 +585,7 @@ if opts then
if lim.bucket[1] then
for _, bucket in ipairs(lim.bucket) do
- local b = parse_limit(t, bucket)
+ local b = ratelimit_common.parse_limit(t, bucket)
if not b then
rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
@@ -729,7 +596,7 @@ if opts then
table.insert(buckets, b)
end
else
- local bucket = parse_limit(t, lim.bucket)
+ local bucket = ratelimit_common.parse_limit(t, lim.bucket)
if not bucket then
rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
@@ -757,7 +624,7 @@ if opts then
end
else
rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim)
- buckets = parse_limit(t, lim)
+ buckets = ratelimit_common.parse_limit(t, lim)
if buckets then
settings.limits[t] = {
buckets = { buckets }