You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

ratelimit.lua 27KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>
  4. Licensed under the Apache License, Version 2.0 (the "License");
  5. you may not use this file except in compliance with the License.
  6. You may obtain a copy of the License at
  7. http://www.apache.org/licenses/LICENSE-2.0
  8. Unless required by applicable law or agreed to in writing, software
  9. distributed under the License is distributed on an "AS IS" BASIS,
  10. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. See the License for the specific language governing permissions and
  12. limitations under the License.
  13. ]]--
  14. if confighelp then
  15. return
  16. end
  17. local rspamd_logger = require "rspamd_logger"
  18. local rspamd_util = require "rspamd_util"
  19. local rspamd_lua_utils = require "lua_util"
  20. local lua_redis = require "lua_redis"
  21. local fun = require "fun"
  22. local lua_maps = require "lua_maps"
  23. local lua_util = require "lua_util"
  24. local lua_verdict = require "lua_verdict"
  25. local rspamd_hash = require "rspamd_cryptobox_hash"
  26. local lua_selectors = require "lua_selectors"
  27. local ts = require("tableshape").types
  28. -- A plugin that implements ratelimits using redis
  29. local E = {}
  30. local N = 'ratelimit'
  31. local redis_params
  32. -- Senders that are considered as bounce
  33. local settings = {
  34. bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' },
  35. -- Do not check ratelimits for these recipients
  36. whitelisted_rcpts = { 'postmaster', 'mailer-daemon' },
  37. prefix = 'RL',
  38. ham_factor_rate = 1.01,
  39. spam_factor_rate = 0.99,
  40. ham_factor_burst = 1.02,
  41. spam_factor_burst = 0.98,
  42. max_rate_mult = 5,
  43. max_bucket_mult = 10,
  44. expire = 60 * 60 * 24 * 2, -- 2 days by default
  45. limits = {},
  46. allow_local = false,
  47. prefilter = true,
  48. }
-- Checks a bucket, leaking it first if needed; it does NOT add the current
-- message to the burst (that is done by bucket_update_script).
-- KEYS[1] - bucket key: settings.prefix followed by a hash of the limit key
--           (see make_prefix)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket (seconds, passed to EXPIRE)
-- KEYS[6] - number of recipients (amount the message would add)
-- Returns a table {limited, burst, dyn_rate, dyn_burst, leaked}:
--   limited is 1 if the message should be ratelimited and 0 if not; the
--   other fields are strings used for logging. dyn_rate/dyn_burst are '0'
--   on paths that do not read them from Redis (e.g. when the burst is 0).
-- Redis hash fields used:
-- l - last hit (milliseconds)
-- b - current burst
-- dr - current dynamic rate multiplier (*10000)
-- db - current dynamic burst multiplier (*10000)
local bucket_check_script = [[
local last = redis.call('HGET', KEYS[1], 'l')
local now = tonumber(KEYS[2])
local dynr, dynb, leaked = 0, 0, 0
if not last then
-- New bucket
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('HSET', KEYS[1], 'b', '0')
redis.call('HSET', KEYS[1], 'dr', '10000')
redis.call('HSET', KEYS[1], 'db', '10000')
redis.call('EXPIRE', KEYS[1], KEYS[5])
return {0, '0', '1', '1', '0'}
end
last = tonumber(last)
local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
-- Perform leak
if burst > 0 then
if last < tonumber(KEYS[2]) then
local rate = tonumber(KEYS[3])
dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
if dynr == 0 then dynr = 0.0001 end
rate = rate * dynr
leaked = ((now - last) * rate)
if leaked > burst then leaked = burst end
burst = burst - leaked
redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
redis.call('HSET', KEYS[1], 'l', KEYS[2])
end
dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0
if dynb == 0 then dynb = 0.0001 end
if burst > 0 and (burst + tonumber(KEYS[6])) > tonumber(KEYS[4]) * dynb then
return {1, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
end
else
burst = 0
redis.call('HSET', KEYS[1], 'b', '0')
end
return {0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked)}
]]
-- Script id returned by lua_redis.add_redis_script; set in load_scripts()
local bucket_check_id
-- Updates a bucket after the message has been scanned: scales the dynamic
-- multipliers by the verdict factors (within their limits) and adds the
-- message to the burst.
-- KEYS[1] - bucket key: settings.prefix followed by a hash of the limit key
--           (see make_prefix)
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier to apply (>1 grows, <1 shrinks)
-- KEYS[4] - dynamic burst multiplier to apply (>1 grows, <1 shrinks)
-- KEYS[5] - max dynamic rate multiplier (min: 1/x); a value <= 1 disables
--           rate adjustment
-- KEYS[6] - max dynamic burst multiplier (min: 1/x); a value <= 1 disables
--           burst adjustment
-- KEYS[7] - expire for a bucket (seconds, passed to EXPIRE)
-- KEYS[8] - number of recipients (amount added to the burst)
-- Returns {burst before the increment (clamped at 0), dyn_rate, dyn_burst}
-- as strings, or {1, 1, 1} when the bucket is created.
-- NOTE(review): `now` is computed but unused in this script.
-- Redis hash fields used:
-- l - last hit (milliseconds)
-- b - current burst
-- dr - current dynamic rate multiplier (*10000)
-- db - current dynamic burst multiplier (*10000)
local bucket_update_script = [[
local last = redis.call('HGET', KEYS[1], 'l')
local now = tonumber(KEYS[2])
if not last then
-- New bucket
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('HSET', KEYS[1], 'b', '1')
redis.call('HSET', KEYS[1], 'dr', '10000')
redis.call('HSET', KEYS[1], 'db', '10000')
redis.call('EXPIRE', KEYS[1], KEYS[7])
return {1, 1, 1}
end
local dr, db = 1.0, 1.0
if tonumber(KEYS[5]) > 1 then
local rate_mult = tonumber(KEYS[3])
local rate_limit = tonumber(KEYS[5])
dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000
if rate_mult > 1.0 and dr < rate_limit then
dr = dr * rate_mult
if dr > 0.0001 then
redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
else
redis.call('HSET', KEYS[1], 'dr', '1')
end
elseif rate_mult < 1.0 and dr > (1.0 / rate_limit) then
dr = dr * rate_mult
if dr > 0.0001 then
redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
else
redis.call('HSET', KEYS[1], 'dr', '1')
end
end
end
if tonumber(KEYS[6]) > 1 then
local rate_mult = tonumber(KEYS[4])
local rate_limit = tonumber(KEYS[6])
db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
if rate_mult > 1.0 and db < rate_limit then
db = db * rate_mult
if db > 0.0001 then
redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
else
redis.call('HSET', KEYS[1], 'db', '1')
end
elseif rate_mult < 1.0 and db > (1.0 / rate_limit) then
db = db * rate_mult
if db > 0.0001 then
redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
else
redis.call('HSET', KEYS[1], 'db', '1')
end
end
end
local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
if burst < 0 then burst = 0 end
redis.call('HINCRBYFLOAT', KEYS[1], 'b', tonumber(KEYS[8]))
redis.call('HSET', KEYS[1], 'l', KEYS[2])
redis.call('EXPIRE', KEYS[1], KEYS[7])
return {tostring(burst), tostring(dr), tostring(db)}
]]
-- Script id returned by lua_redis.add_redis_script; set in load_scripts()
local bucket_update_id
  177. -- message_func(task, limit_type, prefix, bucket, limit_key)
  178. local message_func = function(_, limit_type, _, _, _)
  179. return string.format('Ratelimit "%s" exceeded', limit_type)
  180. end
  181. local function load_scripts(cfg, ev_base)
  182. bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
  183. bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
  184. end
  185. local limit_parser
  186. local function parse_string_limit(lim, no_error)
  187. local function parse_time_suffix(s)
  188. if s == 's' then
  189. return 1
  190. elseif s == 'm' then
  191. return 60
  192. elseif s == 'h' then
  193. return 3600
  194. elseif s == 'd' then
  195. return 86400
  196. end
  197. end
  198. local function parse_num_suffix(s)
  199. if s == '' then
  200. return 1
  201. elseif s == 'k' then
  202. return 1000
  203. elseif s == 'm' then
  204. return 1000000
  205. elseif s == 'g' then
  206. return 1000000000
  207. end
  208. end
  209. local lpeg = require "lpeg"
  210. if not limit_parser then
  211. local digit = lpeg.R("09")
  212. limit_parser = {}
  213. limit_parser.integer =
  214. (lpeg.S("+-") ^ -1) *
  215. (digit ^ 1)
  216. limit_parser.fractional =
  217. (lpeg.P(".") ) *
  218. (digit ^ 1)
  219. limit_parser.number =
  220. (limit_parser.integer *
  221. (limit_parser.fractional ^ -1)) +
  222. (lpeg.S("+-") * limit_parser.fractional)
  223. limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
  224. (limit_parser.number / tonumber) *
  225. ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
  226. function (acc, val) return acc * val end)
  227. limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
  228. (limit_parser.number / tonumber) *
  229. ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
  230. function (acc, val) return acc * val end)
  231. limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
  232. (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
  233. limit_parser.time)
  234. end
  235. local t = lpeg.match(limit_parser.limit, lim)
  236. if t and t[1] and t[2] and t[2] ~= 0 then
  237. return t[2], t[1]
  238. end
  239. if not no_error then
  240. rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  241. end
  242. return nil
  243. end
  244. local function str_to_rate(str)
  245. local divider,divisor = parse_string_limit(str, false)
  246. if not divisor then
  247. rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
  248. return nil
  249. end
  250. return divisor / divider
  251. end
  252. local bucket_schema = ts.shape{
  253. burst = ts.number + ts.string / lua_util.dehumanize_number,
  254. rate = ts.number + ts.string / str_to_rate,
  255. skip_recipients = ts.boolean:is_optional(),
  256. symbol = ts.string:is_optional(),
  257. message = ts.string:is_optional(),
  258. skip_soft_reject = ts.boolean:is_optional(),
  259. }
  260. local function parse_limit(name, data)
  261. if type(data) == 'table' then
  262. -- 2 cases here:
  263. -- * old limit in format [burst, rate]
  264. -- * vector of strings in Andrew's string format (removed from 1.8.2)
  265. -- * proper bucket table
  266. if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
  267. -- Old style ratelimit
  268. rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
  269. if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
  270. return {
  271. burst = data[1],
  272. rate = data[2]
  273. }
  274. elseif data[1] ~= 0 then
  275. rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
  276. else
  277. rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
  278. end
  279. return nil
  280. else
  281. local parsed_bucket,err = bucket_schema:transform(data)
  282. if not parsed_bucket or err then
  283. rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
  284. name, err, data)
  285. else
  286. return parsed_bucket
  287. end
  288. end
  289. elseif type(data) == 'string' then
  290. local rep_rate, burst = parse_string_limit(data)
  291. rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
  292. name, data)
  293. if rep_rate and burst then
  294. return {
  295. burst = burst,
  296. rate = burst / rep_rate -- reciprocal
  297. }
  298. end
  299. end
  300. return nil
  301. end
  302. --- Check whether this addr is bounce
  303. local function check_bounce(from)
  304. return fun.any(function(b) return b == from end, settings.bounce_senders)
  305. end
  306. local keywords = {
  307. ['ip'] = {
  308. ['get_value'] = function(task)
  309. local ip = task:get_ip()
  310. if ip and ip:is_valid() then return tostring(ip) end
  311. return nil
  312. end,
  313. },
  314. ['rip'] = {
  315. ['get_value'] = function(task)
  316. local ip = task:get_ip()
  317. if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end
  318. return nil
  319. end,
  320. },
  321. ['from'] = {
  322. ['get_value'] = function(task)
  323. local from = task:get_from(0)
  324. if ((from or E)[1] or E).addr then
  325. return string.lower(from[1]['addr'])
  326. end
  327. return nil
  328. end,
  329. },
  330. ['bounce'] = {
  331. ['get_value'] = function(task)
  332. local from = task:get_from(0)
  333. if not ((from or E)[1] or E).user then
  334. return '_'
  335. end
  336. if check_bounce(from[1]['user']) then return '_' else return nil end
  337. end,
  338. },
  339. ['asn'] = {
  340. ['get_value'] = function(task)
  341. local asn = task:get_mempool():get_variable('asn')
  342. if not asn then
  343. return nil
  344. else
  345. return asn
  346. end
  347. end,
  348. },
  349. ['user'] = {
  350. ['get_value'] = function(task)
  351. local auser = task:get_user()
  352. if not auser then
  353. return nil
  354. else
  355. return auser
  356. end
  357. end,
  358. },
  359. ['to'] = {
  360. ['get_value'] = function(task)
  361. return task:get_principal_recipient()
  362. end,
  363. },
  364. ['digest'] = {
  365. ['get_value'] = function(task)
  366. return task:get_digest()
  367. end,
  368. },
  369. ['attachments'] = {
  370. ['get_value'] = function(task)
  371. local parts = task:get_parts() or E
  372. local digests = {}
  373. for _,p in ipairs(parts) do
  374. if p:get_filename() then
  375. table.insert(digests, p:get_digest())
  376. end
  377. end
  378. if #digests > 0 then
  379. return table.concat(digests, '')
  380. end
  381. return nil
  382. end,
  383. },
  384. ['files'] = {
  385. ['get_value'] = function(task)
  386. local parts = task:get_parts() or E
  387. local files = {}
  388. for _,p in ipairs(parts) do
  389. local fname = p:get_filename()
  390. if fname then
  391. table.insert(files, fname)
  392. end
  393. end
  394. if #files > 0 then
  395. return table.concat(files, ':')
  396. end
  397. return nil
  398. end,
  399. },
  400. }
  401. local function gen_rate_key(task, rtype, bucket)
  402. local key_t = {tostring(lua_util.round(100000.0 / bucket.burst))}
  403. local key_keywords = lua_util.str_split(rtype, '_')
  404. local have_user = false
  405. for _, v in ipairs(key_keywords) do
  406. local ret
  407. if keywords[v] and type(keywords[v]['get_value']) == 'function' then
  408. ret = keywords[v]['get_value'](task)
  409. end
  410. if not ret then return nil end
  411. if v == 'user' then have_user = true end
  412. if type(ret) ~= 'string' then ret = tostring(ret) end
  413. table.insert(key_t, ret)
  414. end
  415. if have_user and not task:get_user() then
  416. return nil
  417. end
  418. return table.concat(key_t, ":")
  419. end
  420. local function make_prefix(redis_key, name, bucket)
  421. local hash_len = 24
  422. if hash_len > #redis_key then hash_len = #redis_key end
  423. local hash = settings.prefix ..
  424. string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len)
  425. -- Fill defaults
  426. if not bucket.spam_factor_rate then
  427. bucket.spam_factor_rate = settings.spam_factor_rate
  428. end
  429. if not bucket.ham_factor_rate then
  430. bucket.ham_factor_rate = settings.ham_factor_rate
  431. end
  432. if not bucket.spam_factor_burst then
  433. bucket.spam_factor_burst = settings.spam_factor_burst
  434. end
  435. if not bucket.ham_factor_burst then
  436. bucket.ham_factor_burst = settings.ham_factor_burst
  437. end
  438. return {
  439. bucket = bucket,
  440. name = name,
  441. hash = hash
  442. }
  443. end
  444. local function limit_to_prefixes(task, k, v, prefixes)
  445. local n = 0
  446. for _,bucket in ipairs(v.buckets) do
  447. if v.selector then
  448. local selectors = lua_selectors.process_selectors(task, v.selector)
  449. if selectors then
  450. local combined = lua_selectors.combine_selectors(task, selectors, ':')
  451. if type(combined) == 'string' then
  452. prefixes[combined] = make_prefix(combined, k, bucket)
  453. n = n + 1
  454. else
  455. fun.each(function(p)
  456. prefixes[p] = make_prefix(p, k, bucket)
  457. n = n + 1
  458. end, combined)
  459. end
  460. end
  461. else
  462. local prefix = gen_rate_key(task, k, bucket)
  463. if prefix then
  464. if type(prefix) == 'string' then
  465. prefixes[prefix] = make_prefix(prefix, k, bucket)
  466. n = n + 1
  467. else
  468. fun.each(function(p)
  469. prefixes[p] = make_prefix(p, k, bucket)
  470. n = n + 1
  471. end, prefix)
  472. end
  473. end
  474. end
  475. end
  476. return n
  477. end
  478. local function ratelimit_cb(task)
  479. if not settings.allow_local and
  480. rspamd_lua_utils.is_rspamc_or_controller(task) then return end
  481. -- Get initial task data
  482. local ip = task:get_from_ip()
  483. if ip and ip:is_valid() and settings.whitelisted_ip then
  484. if settings.whitelisted_ip:get_key(ip) then
  485. -- Do not check whitelisted ip
  486. rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
  487. return
  488. end
  489. end
  490. -- Parse all rcpts
  491. local rcpts = task:get_recipients()
  492. local rcpts_user = {}
  493. if rcpts then
  494. fun.each(function(r)
  495. fun.each(function(type) table.insert(rcpts_user, r[type]) end, {'user', 'addr'})
  496. end, rcpts)
  497. if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
  498. rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
  499. return
  500. end
  501. end
  502. -- Get user (authuser)
  503. if settings.whitelisted_user then
  504. local auser = task:get_user()
  505. if settings.whitelisted_user:get_key(auser) then
  506. rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
  507. return
  508. end
  509. end
  510. -- Now create all ratelimit prefixes
  511. local prefixes = {}
  512. local nprefixes = 0
  513. for k,v in pairs(settings.limits) do
  514. nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  515. end
  516. for k, hdl in pairs(settings.custom_keywords or E) do
  517. local ret, redis_key, bd = pcall(hdl, task)
  518. if ret then
  519. local bucket = parse_limit(k, bd)
  520. if bucket then
  521. prefixes[redis_key] = make_prefix(redis_key, k, bucket)
  522. end
  523. nprefixes = nprefixes + 1
  524. else
  525. rspamd_logger.errx(task, 'cannot call handler for %s: %s',
  526. k, redis_key)
  527. end
  528. end
  529. local function gen_check_cb(prefix, bucket, lim_name, lim_key)
  530. return function(err, data)
  531. if err then
  532. rspamd_logger.errx('cannot check limit %s: %s %s', prefix, err, data)
  533. elseif type(data) == 'table' and data[1] then
  534. lua_util.debugm(N, task,
  535. "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked",
  536. prefix, bucket.burst, bucket.rate,
  537. data[2], data[3], data[4], data[5])
  538. if data[1] == 1 then
  539. -- set symbol only and do NOT soft reject
  540. if bucket.symbol then
  541. -- Per bucket symbol
  542. task:insert_result(bucket.symbol, 1.0,
  543. string.format('%s(%s)', lim_name, lim_key))
  544. else
  545. if settings.symbol then
  546. task:insert_result(settings.symbol, 1.0,
  547. string.format('%s(%s)', lim_name, lim_key))
  548. elseif settings.info_symbol then
  549. task:insert_result(settings.info_symbol, 1.0,
  550. string.format('%s(%s)', lim_name, lim_key))
  551. end
  552. end
  553. rspamd_logger.infox(task,
  554. 'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn); redis key: %s',
  555. lim_name, prefix,
  556. bucket.burst, bucket.rate,
  557. data[2], data[3], data[4], lim_key)
  558. if not settings.symbol and not bucket.skip_soft_reject then
  559. if not bucket.message then
  560. task:set_pre_result('soft reject',
  561. message_func(task, lim_name, prefix, bucket, lim_key), N)
  562. else
  563. task:set_pre_result('soft reject', bucket.message)
  564. end
  565. end
  566. end
  567. end
  568. end
  569. end
  570. -- Don't do anything if pre-result has been already set
  571. if task:has_pre_result() then return end
  572. local _,nrcpt = task:has_recipients('smtp')
  573. if not nrcpt or nrcpt <= 0 then
  574. nrcpt = 1
  575. end
  576. if nprefixes > 0 then
  577. -- Save prefixes to the cache to allow update
  578. task:cache_set('ratelimit_prefixes', prefixes)
  579. local now = rspamd_util.get_time()
  580. now = lua_util.round(now * 1000.0) -- Get milliseconds
  581. -- Now call check script for all defined prefixes
  582. for pr,value in pairs(prefixes) do
  583. local bucket = value.bucket
  584. local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
  585. local bincr = nrcpt
  586. if bucket.skip_recipients then bincr = 1 end
  587. lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
  588. value.name, pr, value.hash, bucket.burst, bucket.rate)
  589. lua_redis.exec_redis_script(bucket_check_id,
  590. {key = value.hash, task = task, is_write = true},
  591. gen_check_cb(pr, bucket, value.name, value.hash),
  592. {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
  593. tostring(settings.expire), tostring(bincr)})
  594. end
  595. end
  596. end
  597. local function ratelimit_update_cb(task)
  598. if task:has_flag('skip') then return end
  599. if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
  600. local prefixes = task:cache_get('ratelimit_prefixes')
  601. if prefixes then
  602. if task:has_pre_result() then
  603. -- Already rate limited/greylisted, do nothing
  604. lua_util.debugm(N, task, 'pre-action has been set, do not update')
  605. return
  606. end
  607. local verdict = lua_verdict.get_specific_verdict(N, task)
  608. local _,nrcpt = task:has_recipients('smtp')
  609. if not nrcpt or nrcpt <= 0 then
  610. nrcpt = 1
  611. end
  612. -- Update each bucket
  613. for k, v in pairs(prefixes) do
  614. local bucket = v.bucket
  615. local function update_bucket_cb(err, data)
  616. if err then
  617. rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
  618. k, err)
  619. else
  620. lua_util.debugm(N, task,
  621. "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
  622. v.name, k, v.hash,
  623. bucket.burst, bucket.rate,
  624. data[1], data[2], data[3])
  625. end
  626. end
  627. local now = rspamd_util.get_time()
  628. now = lua_util.round(now * 1000.0) -- Get milliseconds
  629. local mult_burst = 1.0
  630. local mult_rate = 1.0
  631. if verdict == 'spam' or verdict == 'junk' then
  632. mult_burst = bucket.spam_factor_burst or 1.0
  633. mult_rate = bucket.spam_factor_rate or 1.0
  634. elseif verdict == 'ham' then
  635. mult_burst = bucket.ham_factor_burst or 1.0
  636. mult_rate = bucket.ham_factor_rate or 1.0
  637. end
  638. local bincr = nrcpt
  639. if bucket.skip_recipients then bincr = 1 end
  640. lua_redis.exec_redis_script(bucket_update_id,
  641. {key = v.hash, task = task, is_write = true},
  642. update_bucket_cb,
  643. {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
  644. tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
  645. tostring(settings.expire), tostring(bincr)})
  646. end
  647. end
  648. end
  649. local opts = rspamd_config:get_all_opt(N)
  650. if opts then
  651. settings = lua_util.override_defaults(settings, opts)
  652. if opts['limit'] then
  653. rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  654. end
  655. if opts['rates'] and type(opts['rates']) == 'table' then
  656. -- new way of setting limits
  657. fun.each(function(t, lim)
  658. local buckets = {}
  659. if type(lim) == 'table' and lim.bucket then
  660. if lim.bucket[1] then
  661. for _,bucket in ipairs(lim.bucket) do
  662. local b = parse_limit(t, bucket)
  663. if not b then
  664. rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
  665. t, b)
  666. return
  667. end
  668. table.insert(buckets, b)
  669. end
  670. else
  671. local bucket = parse_limit(t, lim.bucket)
  672. if not bucket then
  673. rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
  674. t, lim.bucket)
  675. return
  676. end
  677. buckets = {bucket}
  678. end
  679. settings.limits[t] = {
  680. buckets = buckets
  681. }
  682. if lim.selector then
  683. local selector = lua_selectors.parse_selector(rspamd_config, lim.selector)
  684. if not selector then
  685. rspamd_logger.errx(rspamd_config, 'bad ratelimit selector for %s: "%s"',
  686. t, lim.selector)
  687. settings.limits[t] = nil
  688. return
  689. end
  690. settings.limits[t].selector = selector
  691. end
  692. else
  693. rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim)
  694. buckets = parse_limit(t, lim)
  695. if buckets then
  696. settings.limits[t] = {
  697. buckets = {buckets}
  698. }
  699. end
  700. end
  701. end, opts['rates'])
  702. end
  703. -- Display what's enabled
  704. fun.each(function(s)
  705. rspamd_logger.infox(rspamd_config, 'enabled ratelimit: %s', s)
  706. end, fun.map(function(n,d)
  707. return string.format('%s [%s]', n,
  708. table.concat(fun.totable(fun.map(function(v)
  709. return string.format('symbol: %s, %s msgs burst, %s msgs/sec rate',
  710. v.symbol, v.burst, v.rate)
  711. end, d.buckets)), '; ')
  712. )
  713. end, settings.limits))
  714. -- Ret, ret, ret: stupid legacy stuff:
  715. -- If we have a string with commas then load it as as static map
  716. -- otherwise, apply normal logic of Rspamd maps
  717. local wrcpts = opts['whitelisted_rcpts']
  718. if type(wrcpts) == 'string' then
  719. if string.find(wrcpts, ',') then
  720. settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
  721. lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
  722. else
  723. settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
  724. 'Ratelimit whitelisted rcpts')
  725. end
  726. elseif type(opts['whitelisted_rcpts']) == 'table' then
  727. settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
  728. 'Ratelimit whitelisted rcpts')
  729. else
  730. -- Stupid default...
  731. settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
  732. settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  733. end
  734. if opts['whitelisted_ip'] then
  735. settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
  736. 'Ratelimit whitelist ip map')
  737. end
  738. if opts['whitelisted_user'] then
  739. settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
  740. 'Ratelimit whitelist user map')
  741. end
  742. settings.custom_keywords = {}
  743. if opts['custom_keywords'] then
  744. local ret, res_or_err = pcall(loadfile(opts['custom_keywords']))
  745. if ret then
  746. opts['custom_keywords'] = {}
  747. if type(res_or_err) == 'table' then
  748. for k,hdl in pairs(res_or_err) do
  749. settings['custom_keywords'][k] = hdl
  750. end
  751. elseif type(res_or_err) == 'function' then
  752. settings['custom_keywords']['custom'] = res_or_err
  753. end
  754. else
  755. rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
  756. opts['custom_keywords'], res_or_err)
  757. settings['custom_keywords'] = {}
  758. end
  759. end
  760. if opts['message_func'] then
  761. message_func = assert(load(opts['message_func']))()
  762. end
  763. redis_params = lua_redis.parse_redis_server('ratelimit')
  764. if not redis_params then
  765. rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
  766. lua_util.disable_module(N, "redis")
  767. else
  768. local s = {
  769. type = settings.prefilter and 'prefilter' or 'callback',
  770. name = 'RATELIMIT_CHECK',
  771. priority = 7,
  772. callback = ratelimit_cb,
  773. flags = 'empty,nostat',
  774. }
  775. local id = rspamd_config:register_symbol(s)
  776. -- Register per bucket symbols
  777. -- Display what's enabled
  778. fun.each(function(set, lim)
  779. if type(lim.buckets) == 'table' then
  780. for _,b in ipairs(lim.buckets) do
  781. if b.symbol then
  782. rspamd_config:register_symbol{
  783. type = 'virtual',
  784. name = b.symbol,
  785. score = 0.0,
  786. parent = id
  787. }
  788. end
  789. end
  790. end
  791. end, settings.limits)
  792. if settings.info_symbol then
  793. rspamd_config:register_symbol{
  794. type = 'virtual',
  795. name = settings.info_symbol,
  796. score = 0.0,
  797. parent = id
  798. }
  799. end
  800. if settings.symbol then
  801. rspamd_config:register_symbol{
  802. type = 'virtual',
  803. name = settings.symbol,
  804. score = 0.0, -- Might be overridden if needed
  805. parent = id
  806. }
  807. end
  808. rspamd_config:register_symbol {
  809. type = 'idempotent',
  810. name = 'RATELIMIT_UPDATE',
  811. callback = ratelimit_update_cb,
  812. }
  813. end
  814. end
  815. rspamd_config:add_on_load(function(cfg, ev_base, worker)
  816. load_scripts(cfg, ev_base)
  817. end)