You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_util.lua 43KB


  1. --[[
  2. Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  21. local lupa = require "lupa"
  22. local split_grammar = {}
  23. local spaces_split_grammar
  24. local space = lpeg.S ' \t\n\v\f\r'
  25. local nospace = 1 - space
  26. local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0)
  27. local match = lpeg.match
  28. local function shallowcopy(orig)
  29. local orig_type = type(orig)
  30. local copy
  31. if orig_type == 'table' then
  32. copy = {}
  33. for orig_key, orig_value in pairs(orig) do
  34. copy[orig_key] = orig_value
  35. end
  36. else
  37. copy = orig
  38. end
  39. return copy
  40. end
  41. local function deepcopy(orig)
  42. local orig_type = type(orig)
  43. local copy
  44. if orig_type == 'table' then
  45. copy = {}
  46. for orig_key, orig_value in next, orig, nil do
  47. copy[deepcopy(orig_key)] = deepcopy(orig_value)
  48. end
  49. if getmetatable(orig) then
  50. setmetatable(copy, deepcopy(getmetatable(orig)))
  51. end
  52. else
  53. -- number, string, boolean, etc
  54. copy = orig
  55. end
  56. return copy
  57. end
  58. lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', {
  59. keep_trailing_newline = true,
  60. autoescape = false,
  61. })
  62. lupa.filters.pbkdf = function(s)
  63. local cr = require "rspamd_cryptobox"
  64. return cr.pbkdf(s)
  65. end
  66. -- Dirty hacks to avoid shared state
  67. package.loaded['lupa'] = nil
  68. local lupa_orig = require "lupa"
  69. local function rspamd_str_split(s, sep)
  70. local gr
  71. if not sep then
  72. if not spaces_split_grammar then
  73. local _sep = space
  74. local elem = lpeg.C((1 - _sep) ^ 0)
  75. local p = lpeg.Ct(elem * (_sep * elem) ^ 0)
  76. spaces_split_grammar = p
  77. end
  78. gr = spaces_split_grammar
  79. else
  80. gr = split_grammar[sep]
  81. if not gr then
  82. local _sep
  83. if type(sep) == 'string' then
  84. _sep = lpeg.S(sep) -- Assume set
  85. else
  86. _sep = sep -- Assume lpeg object
  87. end
  88. local elem = lpeg.C((1 - _sep) ^ 0)
  89. local p = lpeg.Ct(elem * (_sep * elem) ^ 0)
  90. gr = p
  91. split_grammar[sep] = gr
  92. end
  93. end
  94. return gr:match(s)
  95. end
  96. --[[[
  97. -- @function lua_util.str_split(text, delimiter)
  98. -- Splits text into a numeric table by delimiter
  99. -- @param {string} text delimited text
  100. -- @param {string} delimiter the delimiter
  101. -- @return {table} numeric table containing string parts
  102. --]]
  103. exports.rspamd_str_split = rspamd_str_split
  104. exports.str_split = rspamd_str_split
  105. local function rspamd_str_trim(s)
  106. return match(ptrim, s)
  107. end
  108. exports.rspamd_str_trim = rspamd_str_trim
  109. --[[[
  110. -- @function lua_util.str_trim(text)
  111. -- Returns a string with no trailing and leading spaces
  112. -- @param {string} text input text
  113. -- @return {string} string with no trailing and leading spaces
  114. --]]
  115. exports.str_trim = rspamd_str_trim
  116. --[[[
  117. -- @function lua_util.str_startswith(text, prefix)
  118. -- @param {string} text
  119. -- @param {string} prefix
  120. -- @return {boolean} true if text starts with the specified prefix, false otherwise
  121. --]]
  122. exports.str_startswith = function(s, prefix)
  123. return s:sub(1, prefix:len()) == prefix
  124. end
  125. --[[[
  126. -- @function lua_util.str_endswith(text, suffix)
  127. -- @param {string} text
  128. -- @param {string} suffix
  129. -- @return {boolean} true if text ends with the specified suffix, false otherwise
  130. --]]
  131. exports.str_endswith = function(s, suffix)
  132. return s:find(suffix, -suffix:len(), true) ~= nil
  133. end
  134. --[[[
  135. -- @function lua_util.round(number, decimalPlaces)
  136. -- Round number to fixed number of decimal points
  137. -- @param {number} number number to round
  138. -- @param {number} decimalPlaces number of decimal points
  139. -- @return {number} rounded number
  140. --]]
  141. -- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  142. exports.round = function(num, numDecimalPlaces)
  143. local mult = 10 ^ (numDecimalPlaces or 0)
  144. if num >= 0 then
  145. return math.floor(num * mult + 0.5) / mult
  146. else
  147. return math.ceil(num * mult - 0.5) / mult
  148. end
  149. end
  150. --[[[
  151. -- @function lua_util.template(text, replacements)
  152. -- Replaces values in a text template
  153. -- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces.
  154. -- @param {string} text text containing variables
  155. -- @param {table} replacements key/value pairs for replacements
  156. -- @return {string} string containing replaced values
  157. -- @example
  158. -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  159. -- -- goop contains "HELLO LUA WORLD!"
  160. --]]
  161. exports.template = function(tmpl, keys)
  162. local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
  163. local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) }
  164. local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") }
  165. local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0)
  166. return lpeg.match(template_grammar, tmpl)
  167. end
  168. local function enrich_template_with_globals(env)
  169. local newenv = shallowcopy(env)
  170. newenv.paths = rspamd_paths
  171. newenv.env = rspamd_env
  172. return newenv
  173. end
  174. --[[[
  175. -- @function lua_util.jinja_template(text, env[, skip_global_env][, is_orig][, custom_filters])
  176. -- Replaces values in a text template according to jinja2 syntax
  177. -- @param {string} text text containing variables
  178. -- @param {table} replacements key/value pairs for replacements
  179. -- @param {boolean} skip_global_env don't export Rspamd superglobals
  180. -- @param {boolean} is_orig use the original lupa configuration with {% raw %}`{{`{% endraw %} for variables
  181. -- @param {table} custom_filters custom filters to use (or nil if not needed)
  182. -- @return {string} string containing replaced values
  183. -- @example
  184. -- lua_util.jinja_template("HELLO {=FOO=} {=BAR=}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  185. -- "HELLO LUA WORLD!"
  186. --]]
  187. exports.jinja_template = function(text, env, skip_global_env, is_orig, custom_filters)
  188. local lupa_to_use = is_orig and lupa_orig or lupa
  189. if not skip_global_env then
  190. env = enrich_template_with_globals(env)
  191. end
  192. local orig_filters = {}
  193. if type(custom_filters) == 'table' then
  194. for k, v in pairs(custom_filters) do
  195. orig_filters[k] = lupa_to_use.filters[k]
  196. lupa_to_use.filters[k] = v
  197. end
  198. end
  199. local result = lupa_to_use.expand(text, env)
  200. -- Restore custom filters
  201. if type(custom_filters) == 'table' then
  202. for k, _ in pairs(custom_filters) do
  203. lupa_to_use.filters[k] = orig_filters[k]
  204. end
  205. end
  206. return result
  207. end
  208. --[[[
  209. -- @function lua_util.jinja_file(filename, env[, skip_global_env][, is_orig][, custom_filters])
  210. -- Replaces values in a text template according to jinja2 syntax
  211. -- @param {string} filename name of file to expand
  212. -- @param {table} replacements key/value pairs for replacements
  213. -- @param {boolean} skip_global_env don't export Rspamd superglobals
  214. -- @param {boolean} is_orig use the original lupa configuration with {% raw %}`{{`{% endraw %} for variables
  215. -- @param {table} custom_filters custom filters to use (or nil if not needed)
  216. -- @return {string} string containing replaced values
  217. -- @example
  218. -- lua_util.jinja_template("HELLO {=FOO=} {=BAR=}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  219. -- "HELLO LUA WORLD!"
  220. --]]
  221. exports.jinja_template_file = function(filename, env, skip_global_env, is_orig, custom_filters)
  222. local lupa_to_use = is_orig and lupa_orig or lupa
  223. if not skip_global_env then
  224. env = enrich_template_with_globals(env)
  225. end
  226. local orig_filters = {}
  227. if type(custom_filters) == 'table' then
  228. for k, v in pairs(custom_filters) do
  229. orig_filters[k] = lupa_to_use.filters[k]
  230. lupa_to_use.filters[k] = v
  231. end
  232. end
  233. local result = lupa_to_use.expand_file(filename, env)
  234. -- Restore custom filters
  235. if type(custom_filters) == 'table' then
  236. for k, _ in pairs(custom_filters) do
  237. lupa_to_use.filters[k] = orig_filters[k]
  238. end
  239. end
  240. return result
  241. end
  242. exports.remove_email_aliases = function(email_addr)
  243. local function check_gmail_user(addr)
  244. -- Remove all points
  245. local no_dots_user = string.gsub(addr.user, '%.', '')
  246. local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$')
  247. if cap then
  248. return cap, rspamd_str_split(pluses, '+'), nil
  249. elseif no_dots_user ~= addr.user then
  250. return no_dots_user, {}, nil
  251. end
  252. return nil
  253. end
  254. local function check_address(addr)
  255. if addr.user then
  256. local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$')
  257. if cap then
  258. return cap, rspamd_str_split(pluses, '+'), nil
  259. end
  260. end
  261. return nil
  262. end
  263. local function set_addr(addr, new_user, new_domain)
  264. if new_user then
  265. addr.user = new_user
  266. end
  267. if new_domain then
  268. addr.domain = new_domain
  269. end
  270. if addr.domain then
  271. addr.addr = string.format('%s@%s', addr.user, addr.domain)
  272. else
  273. addr.addr = string.format('%s@', addr.user)
  274. end
  275. if addr.name and #addr.name > 0 then
  276. addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
  277. else
  278. addr.raw = string.format('<%s>', addr.addr)
  279. end
  280. end
  281. local function check_gmail(addr)
  282. local nu, tags, nd = check_gmail_user(addr)
  283. if nu then
  284. return nu, tags, nd
  285. end
  286. return nil
  287. end
  288. local function check_googlemail(addr)
  289. local nd = 'gmail.com'
  290. local nu, tags = check_gmail_user(addr)
  291. if nu then
  292. return nu, tags, nd
  293. end
  294. return nil, nil, nd
  295. end
  296. local specific_domains = {
  297. ['gmail.com'] = check_gmail,
  298. ['googlemail.com'] = check_googlemail,
  299. }
  300. if email_addr then
  301. if email_addr.domain and specific_domains[email_addr.domain] then
  302. local nu, tags, nd = specific_domains[email_addr.domain](email_addr)
  303. if nu or nd then
  304. set_addr(email_addr, nu, nd)
  305. return nu, tags
  306. end
  307. else
  308. local nu, tags, nd = check_address(email_addr)
  309. if nu or nd then
  310. set_addr(email_addr, nu, nd)
  311. return nu, tags
  312. end
  313. end
  314. return nil
  315. end
  316. end
  317. exports.is_rspamc_or_controller = function(task)
  318. local ua = task:get_request_header('User-Agent') or ''
  319. local pwd = task:get_request_header('Password')
  320. local is_rspamc = false
  321. if tostring(ua) == 'rspamc' or pwd then
  322. is_rspamc = true
  323. end
  324. return is_rspamc
  325. end
  326. --[[[
  327. -- @function lua_util.unpack(table)
  328. -- Converts numeric table to varargs
  329. -- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
  330. -- @param {table} table numerically indexed table to unpack
  331. -- @return {varargs} unpacked table elements
  332. --]]
  333. local unpack_function = table.unpack or unpack
  334. exports.unpack = function(t)
  335. return unpack_function(t)
  336. end
  337. --[[[
  338. -- @function lua_util.flatten(table)
  339. -- Flatten underlying tables in a single table
  340. -- @param {table} table table of tables
  341. -- @return {table} flattened table
  342. --]]
  343. exports.flatten = function(t)
  344. local res = {}
  345. for _, e in fun.iter(t) do
  346. for _, v in fun.iter(e) do
  347. res[#res + 1] = v
  348. end
  349. end
  350. return res
  351. end
  352. --[[[
  353. -- @function lua_util.spairs(table)
  354. -- Like `pairs` but keys are sorted lexicographically
  355. -- @param {table} table table containing key/value pairs
  356. -- @return {function} generator function returning key/value pairs
  357. --]]
  358. -- Sorted iteration:
  359. -- for k,v in spairs(t) do ... end
  360. --
  361. -- or with custom comparison:
  362. -- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
  363. --
  364. -- optional limit is also available (e.g. return top X elements)
  365. local function spairs(t, order, lim)
  366. -- collect the keys
  367. local keys = {}
  368. for k in pairs(t) do
  369. keys[#keys + 1] = k
  370. end
  371. -- if order function given, sort by it by passing the table and keys a, b,
  372. -- otherwise just sort the keys
  373. if order then
  374. table.sort(keys, function(a, b)
  375. return order(t, a, b)
  376. end)
  377. else
  378. table.sort(keys)
  379. end
  380. -- return the iterator function
  381. local i = 0
  382. return function()
  383. i = i + 1
  384. if not lim or i <= lim then
  385. if keys[i] then
  386. return keys[i], t[keys[i]]
  387. end
  388. end
  389. end
  390. end
  391. exports.spairs = spairs
  392. local lua_cfg_utils = require "lua_cfg_utils"
  393. exports.config_utils = lua_cfg_utils
  394. exports.disable_module = lua_cfg_utils.disable_module
  395. --[[[
  396. -- @function lua_util.disable_module(modname)
  397. -- Checks experimental plugins state and disable if needed
  398. -- @param {string} modname name of plugin to check
  399. -- @return {boolean} true if plugin should be enabled, false otherwise
  400. --]]
  401. local function check_experimental(modname)
  402. if rspamd_config:experimental_enabled() then
  403. return true
  404. else
  405. lua_cfg_utils.disable_module(modname, 'experimental')
  406. end
  407. return false
  408. end
  409. exports.check_experimental = check_experimental
  410. --[[[
  411. -- @function lua_util.list_to_hash(list)
  412. -- Converts numerically-indexed table to table indexed by values
  413. -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  414. -- @return {table} table indexed by values
  415. -- @example
  416. -- local h = lua_util.list_to_hash({"a", "b"})
  417. -- -- h contains {a = true, b = true}
  418. --]]
  419. local function list_to_hash(list)
  420. if type(list) == 'table' then
  421. if list[1] then
  422. local h = {}
  423. for _, e in ipairs(list) do
  424. h[e] = true
  425. end
  426. return h
  427. else
  428. return list
  429. end
  430. elseif type(list) == 'string' then
  431. local h = {}
  432. h[list] = true
  433. return h
  434. end
  435. end
  436. exports.list_to_hash = list_to_hash
  437. --[[[
  438. -- @function lua_util.nkeys(table|gen, param, state)
  439. -- Returns number of keys in a table (i.e. from both the array and hash parts combined)
  440. -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  441. -- @return {number} number of keys
  442. -- @example
  443. -- print(lua_util.nkeys({})) -- 0
  444. -- print(lua_util.nkeys({ "a", nil, "b" })) -- 2
  445. -- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2
  446. -- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3
  447. --
  448. --]]
  449. local function nkeys(gen, param, state)
  450. local n = 0
  451. if not param then
  452. for _, _ in pairs(gen) do
  453. n = n + 1
  454. end
  455. else
  456. for _, _ in fun.iter(gen, param, state) do
  457. n = n + 1
  458. end
  459. end
  460. return n
  461. end
  462. exports.nkeys = nkeys
  463. --[[[
  464. -- @function lua_util.parse_time_interval(str)
  465. -- Parses human readable time interval
  466. -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
  467. -- 'w' for weeks, 'y' for years
  468. -- @param {string} str input string
  469. -- @return {number|nil} parsed interval as seconds (might be fractional)
  470. --]]
  471. local function parse_time_interval(str)
  472. local function parse_time_suffix(s)
  473. if s == 's' then
  474. return 1
  475. elseif s == 'm' then
  476. return 60
  477. elseif s == 'h' then
  478. return 3600
  479. elseif s == 'd' then
  480. return 86400
  481. elseif s == 'w' then
  482. return 86400 * 7
  483. elseif s == 'y' then
  484. return 365 * 86400;
  485. end
  486. end
  487. local digit = lpeg.R("09")
  488. local parser = {}
  489. parser.integer = (lpeg.S("+-") ^ -1) *
  490. (digit ^ 1)
  491. parser.fractional = (lpeg.P(".")) *
  492. (digit ^ 1)
  493. parser.number = (parser.integer *
  494. (parser.fractional ^ -1)) +
  495. (lpeg.S("+-") * parser.fractional)
  496. parser.time = lpeg.Cf(lpeg.Cc(1) *
  497. (parser.number / tonumber) *
  498. ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1),
  499. function(acc, val)
  500. return acc * val
  501. end)
  502. local t = lpeg.match(parser.time, str)
  503. return t
  504. end
  505. exports.parse_time_interval = parse_time_interval
  506. --[[[
  507. -- @function lua_util.dehumanize_number(str)
  508. -- Parses human readable number
  509. -- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier,
  510. -- e.g. `10mb` equal to `10 * 1024 * 1024`
  511. -- @param {string} str input string
  512. -- @return {number|nil} parsed number
  513. --]]
  514. local function dehumanize_number(str)
  515. local function parse_suffix(s)
  516. if s == 'k' then
  517. return 1000
  518. elseif s == 'm' then
  519. return 1000000
  520. elseif s == 'g' then
  521. return 1e9
  522. elseif s == 'kb' then
  523. return 1024
  524. elseif s == 'mb' then
  525. return 1024 * 1024
  526. elseif s == 'gb' then
  527. return 1024 * 1024;
  528. end
  529. end
  530. local digit = lpeg.R("09")
  531. local parser = {}
  532. parser.integer = (lpeg.S("+-") ^ -1) *
  533. (digit ^ 1)
  534. parser.fractional = (lpeg.P(".")) *
  535. (digit ^ 1)
  536. parser.number = (parser.integer *
  537. (parser.fractional ^ -1)) +
  538. (lpeg.S("+-") * parser.fractional)
  539. parser.humanized_number = lpeg.Cf(lpeg.Cc(1) *
  540. (parser.number / tonumber) *
  541. (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1),
  542. function(acc, val)
  543. return acc * val
  544. end)
  545. local t = lpeg.match(parser.humanized_number, str)
  546. return t
  547. end
  548. exports.dehumanize_number = dehumanize_number
  549. --[[[
  550. -- @function lua_util.table_cmp(t1, t2)
  551. -- Compare two tables deeply
  552. --]]
  553. local function table_cmp(table1, table2)
  554. local avoid_loops = {}
  555. local function recurse(t1, t2)
  556. if type(t1) ~= type(t2) then
  557. return false
  558. end
  559. if type(t1) ~= "table" then
  560. return t1 == t2
  561. end
  562. if avoid_loops[t1] then
  563. return avoid_loops[t1] == t2
  564. end
  565. avoid_loops[t1] = t2
  566. -- Copy keys from t2
  567. local t2keys = {}
  568. local t2tablekeys = {}
  569. for k, _ in pairs(t2) do
  570. if type(k) == "table" then
  571. table.insert(t2tablekeys, k)
  572. end
  573. t2keys[k] = true
  574. end
  575. -- Let's iterate keys from t1
  576. for k1, v1 in pairs(t1) do
  577. local v2 = t2[k1]
  578. if type(k1) == "table" then
  579. -- if key is a table, we need to find an equivalent one.
  580. local ok = false
  581. for i, tk in ipairs(t2tablekeys) do
  582. if table_cmp(k1, tk) and recurse(v1, t2[tk]) then
  583. table.remove(t2tablekeys, i)
  584. t2keys[tk] = nil
  585. ok = true
  586. break
  587. end
  588. end
  589. if not ok then
  590. return false
  591. end
  592. else
  593. -- t1 has a key which t2 doesn't have, fail.
  594. if v2 == nil then
  595. return false
  596. end
  597. t2keys[k1] = nil
  598. if not recurse(v1, v2) then
  599. return false
  600. end
  601. end
  602. end
  603. -- if t2 has a key which t1 doesn't have, fail.
  604. if next(t2keys) then
  605. return false
  606. end
  607. return true
  608. end
  609. return recurse(table1, table2)
  610. end
  611. exports.table_cmp = table_cmp
  612. --[[[
  613. -- @function lua_util.table_merge(t1, t2)
  614. -- Merge two tables
  615. --]]
  616. local function table_merge(t1, t2)
  617. local res = {}
  618. local nidx = 1 -- for numeric indicies
  619. local it_func = function(k, v)
  620. if type(k) == 'number' then
  621. res[nidx] = v
  622. nidx = nidx + 1
  623. else
  624. res[k] = v
  625. end
  626. end
  627. for k, v in pairs(t1) do
  628. it_func(k, v)
  629. end
  630. for k, v in pairs(t2) do
  631. it_func(k, v)
  632. end
  633. return res
  634. end
  635. exports.table_merge = table_merge
  636. --[[[
  637. -- @function lua_util.table_cmp(task, name, value, stop_chars)
  638. -- Performs header folding
  639. --]]
  640. exports.fold_header = function(task, name, value, stop_chars)
  641. local how
  642. if task:has_flag("milter") then
  643. how = "lf"
  644. else
  645. how = task:get_newlines_type()
  646. end
  647. return rspamd_util.fold_header(name, value, how, stop_chars)
  648. end
  649. --[[[
  650. -- @function lua_util.override_defaults(defaults, override)
  651. -- Overrides values from defaults with override
  652. --]]
  653. local function override_defaults(def, override)
  654. -- Corner cases
  655. if not override or type(override) ~= 'table' then
  656. return def
  657. end
  658. if not def or type(def) ~= 'table' then
  659. return override
  660. end
  661. local res = {}
  662. for k, v in pairs(override) do
  663. if type(v) == 'table' then
  664. if def[k] and type(def[k]) == 'table' then
  665. -- Recursively override elements
  666. res[k] = override_defaults(def[k], v)
  667. else
  668. res[k] = v
  669. end
  670. else
  671. res[k] = v
  672. end
  673. end
  674. for k, v in pairs(def) do
  675. if type(res[k]) == 'nil' then
  676. res[k] = v
  677. end
  678. end
  679. return res
  680. end
  681. exports.override_defaults = override_defaults
  682. --[[[
  683. -- @function lua_util.filter_specific_urls(urls, params)
  684. -- params: {
  685. - - task - if needed to save in the cache
  686. - - limit <int> (default = 9999)
  687. - - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
  688. works only if number of unique eSLD less than `limit`
  689. - - need_emails <bool> (default = false)
  690. - - filter <callback> (default = nil)
  691. - - prefix <string> cache prefix (default = nil)
  692. -- }
  693. -- Apply heuristic in extracting of urls from `urls` table, this function
  694. -- tries its best to extract specific number of urls from a task based on
  695. -- their characteristics
  696. --]]
  697. exports.filter_specific_urls = function(urls, params)
  698. local cache_key
  699. if params.task and not params.no_cache then
  700. if params.prefix then
  701. cache_key = params.prefix
  702. else
  703. cache_key = string.format('sp_urls_%d%s%s%s', params.limit,
  704. tostring(params.need_emails or false),
  705. tostring(params.need_images or false),
  706. tostring(params.need_content or false))
  707. end
  708. local cached = params.task:cache_get(cache_key)
  709. if cached then
  710. return cached
  711. end
  712. end
  713. if not urls then
  714. return {}
  715. end
  716. if params.filter then
  717. urls = fun.totable(fun.filter(params.filter, urls))
  718. end
  719. -- Filter by tld:
  720. local tlds = {}
  721. local eslds = {}
  722. local ntlds, neslds = 0, 0
  723. local res = {}
  724. local nres = 0
  725. local function insert_url(str, u)
  726. if not res[str] then
  727. res[str] = u
  728. nres = nres + 1
  729. return true
  730. end
  731. return false
  732. end
  733. local function process_single_url(u, default_priority)
  734. local priority = default_priority or 1 -- Normal priority
  735. local flags = u:get_flags()
  736. if params.ignore_ip and flags.numeric then
  737. return
  738. end
  739. if flags.redirected then
  740. local redir = u:get_redirected() -- get the real url
  741. if params.ignore_redirected then
  742. -- Replace `u` with redir
  743. u = redir
  744. priority = 2
  745. else
  746. -- Process both redirected url and the original one
  747. process_single_url(redir, 2)
  748. end
  749. end
  750. if flags.image then
  751. if not params.need_images then
  752. -- Ignore url
  753. return
  754. else
  755. -- Penalise images in urls
  756. priority = 0
  757. end
  758. end
  759. local esld = u:get_tld()
  760. local str_hash = tostring(u)
  761. if esld then
  762. -- Special cases
  763. if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
  764. if flags.obscured then
  765. priority = 3
  766. else
  767. if (flags.has_user or flags.has_port) then
  768. priority = 2
  769. elseif (flags.subject or flags.phished) then
  770. priority = 2
  771. end
  772. end
  773. elseif flags.html_displayed then
  774. priority = 0
  775. end
  776. if not eslds[esld] then
  777. eslds[esld] = { { str_hash, u, priority } }
  778. neslds = neslds + 1
  779. else
  780. if #eslds[esld] < params.esld_limit then
  781. table.insert(eslds[esld], { str_hash, u, priority })
  782. end
  783. end
  784. -- eSLD - 1 part => tld
  785. local parts = rspamd_str_split(esld, '.')
  786. local tld = table.concat(fun.totable(fun.tail(parts)), '.')
  787. if not tlds[tld] then
  788. tlds[tld] = { { str_hash, u, priority } }
  789. ntlds = ntlds + 1
  790. else
  791. table.insert(tlds[tld], { str_hash, u, priority })
  792. end
  793. end
  794. end
  795. for _, u in ipairs(urls) do
  796. process_single_url(u)
  797. end
  798. local limit = params.limit
  799. limit = limit - nres
  800. if limit < 0 then
  801. limit = 0
  802. end
  803. if limit == 0 then
  804. res = exports.values(res)
  805. if params.task and not params.no_cache then
  806. params.task:cache_set(cache_key, res)
  807. end
  808. return res
  809. end
  810. -- Sort eSLDs and tlds
  811. local function sort_stuff(tbl)
  812. -- Sort according to max priority
  813. table.sort(tbl, function(e1, e2)
  814. -- Sort by priority so max priority is at the end
  815. table.sort(e1, function(tr1, tr2)
  816. return tr1[3] < tr2[3]
  817. end)
  818. table.sort(e2, function(tr1, tr2)
  819. return tr1[3] < tr2[3]
  820. end)
  821. if e1[#e1][3] ~= e2[#e2][3] then
  822. -- Sort by priority so max priority is at the beginning
  823. return e1[#e1][3] > e2[#e2][3]
  824. else
  825. -- Prefer less urls to more urls per esld
  826. return #e1 < #e2
  827. end
  828. end)
  829. return tbl
  830. end
  831. eslds = sort_stuff(exports.values(eslds))
  832. neslds = #eslds
  833. if neslds <= limit then
  834. -- Number of eslds < limit
  835. repeat
  836. local item_found = false
  837. for _, lurls in ipairs(eslds) do
  838. if #lurls > 0 then
  839. local last = table.remove(lurls)
  840. insert_url(last[1], last[2])
  841. limit = limit - 1
  842. item_found = true
  843. end
  844. end
  845. until limit <= 0 or not item_found
  846. res = exports.values(res)
  847. if params.task and not params.no_cache then
  848. params.task:cache_set(cache_key, res)
  849. end
  850. return res
  851. end
  852. tlds = sort_stuff(exports.values(tlds))
  853. ntlds = #tlds
  854. -- Number of tlds < limit
  855. while limit > 0 do
  856. for _, lurls in ipairs(tlds) do
  857. if #lurls > 0 then
  858. local last = table.remove(lurls)
  859. insert_url(last[1], last[2])
  860. limit = limit - 1
  861. end
  862. if limit == 0 then
  863. break
  864. end
  865. end
  866. end
  867. res = exports.values(res)
  868. if params.task and not params.no_cache then
  869. params.task:cache_set(cache_key, res)
  870. end
  871. return res
  872. end
  873. --[[[
  874. -- @function lua_util.extract_specific_urls(params)
  875. -- params: {
  876. - - task
  877. - - limit <int> (default = 9999)
  878. - - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
  879. works only if number of unique eSLD less than `limit`
  880. - - need_emails <bool> (default = false)
  881. - - filter <callback> (default = nil)
  882. - - prefix <string> cache prefix (default = nil)
  883. - - ignore_redirected <bool> (default = false)
  884. - - need_images <bool> (default = false)
  885. - - need_content <bool> (default = false)
  886. -- }
  887. -- Apply heuristic in extracting of urls from task, this function
  888. -- tries its best to extract specific number of urls from a task based on
  889. -- their characteristics
  890. --]]
  891. -- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
  892. exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
  893. local default_params = {
  894. limit = 9999,
  895. esld_limit = 9999,
  896. need_emails = false,
  897. need_images = false,
  898. need_content = false,
  899. filter = nil,
  900. prefix = nil,
  901. ignore_ip = false,
  902. ignore_redirected = false,
  903. no_cache = false,
  904. }
  905. local params
  906. if type(params_or_task) == 'table' and type(lim) == 'nil' then
  907. params = params_or_task
  908. else
  909. -- Deprecated call
  910. params = {
  911. task = params_or_task,
  912. limit = lim,
  913. need_emails = need_emails,
  914. filter = filter,
  915. prefix = prefix
  916. }
  917. end
  918. for k, v in pairs(default_params) do
  919. if type(params[k]) == 'nil' and v ~= nil then
  920. params[k] = v
  921. end
  922. end
  923. local url_params = {
  924. emails = params.need_emails,
  925. images = params.need_images,
  926. content = params.need_content,
  927. flags = params.flags, -- maybe nil
  928. flags_mode = params.flags_mode, -- maybe nil
  929. }
  930. -- Shortcut for cached stuff
  931. if params.task and not params.no_cache then
  932. local cache_key
  933. if params.prefix then
  934. cache_key = params.prefix
  935. else
  936. local cache_key_suffix
  937. if params.flags then
  938. cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '')
  939. else
  940. cache_key_suffix = string.format('%s%s%s',
  941. tostring(params.need_emails or false),
  942. tostring(params.need_images or false),
  943. tostring(params.need_content or false))
  944. end
  945. cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix)
  946. end
  947. local cached = params.task:cache_get(cache_key)
  948. if cached then
  949. return cached
  950. end
  951. end
  952. -- No cache version
  953. local urls = params.task:get_urls(url_params)
  954. return exports.filter_specific_urls(urls, params)
  955. end
  --[[[
  -- @function lua_util.deepcopy(table)
  -- params: {
  -- - table
  -- }
  -- Performs deep copy of the table, including metatables.
  -- Both keys and values of a table are copied recursively.
  --]]
  exports.deepcopy = deepcopy
  --[[[
  -- @function lua_util.deepsort(table[, sort_func])
  -- params: {
  -- - table
  -- - sort_func: optional comparator `function(a, b)` returning true when a goes before b
  -- }
  -- Performs recursive in-place sort of a table.
  -- The same comparator is used for the table itself and for every nested table.
  -- Values that are not tables are left untouched.
  --]]

  -- Raw `<` wrapped into a function so it can be guarded with pcall
  local function raw_less(e1, e2)
    return e1 < e2
  end

  -- Default comparator: a strict order that never raises.
  -- Values of different types are ordered by their type name; numbers and
  -- strings use `<`; `false` goes before `true`; anything else (tables,
  -- functions, userdata) uses `<` when a metamethod provides it and is treated
  -- as equal otherwise.
  local function default_sort_cmp(e1, e2)
    local t1, t2 = type(e1), type(e2)
    if t1 ~= t2 then
      return t1 < t2
    end
    if t1 == 'number' or t1 == 'string' then
      return e1 < e2
    elseif t1 == 'boolean' then
      return (not e1) and e2
    end
    local ok, less = pcall(raw_less, e1, e2)
    return ok and less or false
  end

  local function deepsort(tbl, sort_func)
    if type(tbl) ~= 'table' then
      return
    end
    table.sort(tbl, sort_func or default_sort_cmp)
    for _, value in next, tbl, nil do
      -- Pass the comparator down, otherwise nested tables would silently fall
      -- back to the default ordering
      deepsort(value, sort_func)
    end
  end
  -- Public name for the recursive in-place sort
  exports.deepsort = deepsort
  --[[[
  -- @function lua_util.shallowcopy(tbl)
  -- Performs shallow (and fast) copy of a table or another Lua type.
  -- For a table a new table with the same key/value pairs is returned (values
  -- are not copied); any other value is returned as is.
  --]]
  exports.shallowcopy = shallowcopy
  -- Debugging support
  local logger = require "rspamd_logger"
  -- When true, `debugm` logs for every module; it starts as true if the logger
  -- already reports the 'debug' level when this file is loaded
  local unconditional_debug = logger.log_level() == 'debug'
  -- Set of names with debug logging enabled: debug_modules[name] = true
  local debug_modules = {}
  -- Map alias name -> module name, filled by `add_debug_alias`
  local debug_aliases = {}
  -- Level value passed to `logger.logx` by `debugm`
  local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)
  --[[[
  -- @function lua_util.init_debug_logging(config)
  -- Reads the `logging` section of the config and enables debug logging:
  -- * `level = 'debug'` turns debug on for all modules
  -- * every name listed in `debug_modules` is enabled individually
  -- * every alias registered so far whose module is enabled gets enabled too
  -- Does nothing when debug is already enabled unconditionally.
  --]]
  exports.init_debug_logging = function(config)
    -- Fill debug modules from the config
    if unconditional_debug then
      return
    end

    local log_config = config:get_all_opt('logging')
    if not log_config then
      return
    end

    if log_config.level == 'debug' then
      unconditional_debug = true
    end

    if log_config.debug_modules then
      for _, m in ipairs(log_config.debug_modules) do
        debug_modules[m] = true
        logger.infox(config, 'enable debug for Lua module %s', m)
      end
    end

    -- `debug_aliases` is keyed by alias name, so the length operator is always
    -- 0 for it and must not be used as an emptiness check; just iterate
    for alias, mod in pairs(debug_aliases) do
      if debug_modules[mod] and not debug_modules[alias] then
        debug_modules[alias] = true
        logger.infox(config, 'enable debug for Lua module %s (%s aliased)',
            alias, mod)
      end
    end
  end
  -- Turns debug logging on for all modules, whatever the per-module set contains
  function exports.enable_debug_logging()
    unconditional_debug = true
  end
  -- Enables debug logging for each module name passed as an argument
  function exports.enable_debug_modules(...)
    local requested = { ... }
    for _, name in ipairs(requested) do
      debug_modules[name] = true
    end
  end
  -- Turns the unconditional debug mode off; modules enabled individually stay enabled
  function exports.disable_debug_logging()
    unconditional_debug = false
  end
  --[[[
  -- @function lua_util.debugm(module, [log_object], format, ...)
  -- Performs fast debug log for a specific module.
  -- Nothing is logged unless debug is enabled unconditionally or for `module`.
  -- When the second argument is a string it is taken as the format and an empty
  -- string is passed to the logger in place of the log object.
  --]]
  function exports.debugm(mod, obj_or_fmt, fmt_or_something, ...)
    if not (unconditional_debug or debug_modules[mod]) then
      return
    end

    -- NOTE(review): the literal 2 is presumably a stack level for the logger, so
    -- the calls below stay plain (non-tail) calls made directly from this function
    if type(obj_or_fmt) == 'string' then
      logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
    else
      logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
    end
  end
  --[[[
  -- @function lua_util.add_debug_alias(mod, alias)
  -- Add debugging alias so logging to `alias` will be treated as logging to `mod`.
  -- If debug is already enabled for `mod`, it is enabled for `alias` right away.
  --]]
  function exports.add_debug_alias(mod, alias)
    debug_aliases[alias] = mod

    if not debug_modules[mod] then
      return
    end

    debug_modules[alias] = true
    logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
        alias, mod)
  end
  ---[[[
  -- @function lua_util.get_task_verdict(task)
  -- Returns verdict for a task + score if certain, must be called from idempotent filters only.
  -- This is a thin wrapper: the result comes from `lua_verdict.get_default_verdict(task)`.
  -- Returns string:
  -- * `spam`: if a message has a score over the reject threshold and has more than one positive rule
  -- * `junk`: if a message has a score between the [add_header/rewrite subject] and reject thresholds and has more than two positive rules
  -- * `passthrough`: if a message has been passed through some short-circuit rule
  -- * `ham`: if a message has overall score below junk level **and** more than three negative rules, or negative total score
  -- * `uncertain`: all other cases
  --]]
  function exports.get_task_verdict(task)
    return require("lua_verdict").get_default_verdict(task)
  end
  ---[[[
  -- @function lua_util.maybe_obfuscate_string(subject, settings, prefix)
  -- Obfuscate string if enabled in settings. Also checks utf8 validity - if
  -- string is not valid utf8 then '???' is returned. Empty string returned as is.
  -- Supported settings:
  -- * <prefix>_privacy = false - subject privacy is off
  -- * <prefix>_privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject
  -- * <prefix>_privacy_prefix = 'obf' - prefix to show it's obfuscated
  -- * <prefix>_privacy_length = 16 - cut the length of the hash; if 0 or false full hash is returned
  -- @return obfuscated or validated subject
  --]]
  function exports.maybe_obfuscate_string(subject, settings, prefix)
    local hash = require 'rspamd_cryptobox_hash'

    -- nil and empty input are passed through untouched
    if not subject or subject == '' then
      return subject
    end

    if not rspamd_util.is_valid_utf8(subject) then
      return '???'
    end

    if not settings[prefix .. '_privacy'] then
      return subject
    end

    local alg = settings[prefix .. '_privacy_alg'] or 'blake2'
    local hasher = hash.create_specific(alg, subject)
    local cut_to = settings[prefix .. '_privacy_length']
    local digest
    if cut_to and cut_to > 0 then
      digest = hasher:hex():sub(1, cut_to)
    else
      digest = hasher:hex()
    end

    local marker = settings[prefix .. '_privacy_prefix']
    if marker and #marker > 0 then
      return marker .. ':' .. digest
    end

    return digest
  end
  ---[[[
  -- @function lua_util.callback_from_string(str)
  -- Converts a string like `return function(...) end` to lua function and return true and this function
  -- or returns false + error message.
  -- Three input shapes are recognised after trimming:
  -- * `return function ...` is evaluated as is
  -- * `function(...) ... end` gets `return ` prepended
  -- * anything else is wrapped into `return function(...) <code>; end`
  -- The error message is the compiler error when the code does not compile,
  -- the runtime error when evaluating the chunk fails, or a description of the
  -- unexpected result type when the chunk does not produce a function.
  -- @return status code and function object or an error message
  --]]]
  exports.callback_from_string = function(s)
    local loadstring = loadstring or load

    if not s or #s == 0 then
      return false, 'invalid or empty string'
    end

    s = exports.rspamd_str_trim(s)
    local inp

    if s:match('^return%s*function') then
      -- 'return function', can be evaluated directly
      inp = s
    elseif s:match('^function%s*%(') then
      inp = 'return ' .. s
    else
      -- Just a plain sequence
      inp = 'return function(...)\n' .. s .. '; end'
    end

    -- Compile first: on a syntax error the loader returns nil plus the message,
    -- and calling that nil through pcall would hide the real reason
    local chunk, compile_err = loadstring(inp)
    if not chunk then
      return false, compile_err
    end

    local ok, res_or_err = pcall(chunk)
    if not ok then
      return false, res_or_err
    end

    if type(res_or_err) ~= 'function' then
      return false, string.format('expected a function, got %s', type(res_or_err))
    end

    return true, res_or_err
  end
  ---[[[
  -- @function lua_util.keys(t)
  -- Returns all keys from a specific table
  -- @param {table} t input table (or iterator triplet)
  -- @return array of keys
  -- When a second argument is given, the three arguments are passed to
  -- `fun.iter` and the first value of each iteration step is collected;
  -- otherwise the first argument is walked with `pairs`.
  --]]]
  function exports.keys(gen, param, state)
    local collected = {}
    local n = 0

    if not param then
      for key in pairs(gen) do
        n = n + 1
        collected[n] = key
      end
      return collected
    end

    for key in fun.iter(gen, param, state) do
      n = n + 1
      collected[n] = key
    end

    return collected
  end
  ---[[[
  -- @function lua_util.values(t)
  -- Returns all values from a specific table
  -- @param {table} t input table
  -- @return array of values
  -- When a second argument is given, the three arguments are passed to
  -- `fun.iter` and the second value of each iteration step is collected;
  -- otherwise the first argument is walked with `pairs`.
  --]]]
  function exports.values(gen, param, state)
    local collected = {}
    local n = 0

    if not param then
      for _, value in pairs(gen) do
        n = n + 1
        collected[n] = value
      end
      return collected
    end

    for _, value in fun.iter(gen, param, state) do
      n = n + 1
      collected[n] = value
    end

    return collected
  end
  ---[[[
  -- @function lua_util.distance_sorted(t1, t2)
  -- Returns distance between two sorted tables t1 and t2
  -- @param {table} t1 input table
  -- @param {table} t2 input table
  -- @return distance between `t1` and `t2`
  -- Both arrays are walked with one cursor each for at most max(#t1, #t2)
  -- steps: equal elements advance both cursors, otherwise the cursor pointing
  -- at the smaller element advances and the distance grows by one. When one
  -- array is exhausted, the length difference is added, less the lead the other
  -- cursor already has.
  --]]]
  function exports.distance_sorted(t1, t2)
    local len1, len2 = #t1, #t2
    local ncomp = math.max(len1, len2)
    local ndiff = 0
    local i, j = 1, 1

    for _ = 1, ncomp do
      if j > len2 then
        -- second array exhausted
        ndiff = ndiff + ncomp - len2
        if i > j then
          ndiff = ndiff - (i - j)
        end
        break
      end

      if i > len1 then
        -- first array exhausted
        ndiff = ndiff + ncomp - len1
        if j > i then
          ndiff = ndiff - (j - i)
        end
        break
      end

      local left, right = t1[i], t2[j]
      if left == right then
        i = i + 1
        j = j + 1
      else
        if left < right then
          i = i + 1
        else
          j = j + 1
        end
        ndiff = ndiff + 1
      end
    end

    return ndiff
  end
  ---[[[
  -- @function lua_util.table_digest(t)
  -- Returns hash of all values if t[1] is set or all keys/values otherwise
  -- @param {table} t input array or map
  -- @return {string} base32 representation of the hash
  -- Array mode feeds `tostring` of every element (or the digest of a nested
  -- table) in index order. Map mode feeds `tostring` of every key followed by
  -- the value when it is a string, or by its digest when it is a table; values
  -- of other types contribute nothing. Map entries are fed in `pairs` order.
  --]]]
  local function table_digest(t)
    local hasher = require("rspamd_cryptobox_hash").create()

    if t[1] then
      for _, elt in ipairs(t) do
        if type(elt) ~= 'table' then
          hasher:update(tostring(elt))
        else
          hasher:update(table_digest(elt))
        end
      end
    else
      for key, value in pairs(t) do
        hasher:update(tostring(key))

        local vtype = type(value)
        if vtype == 'string' then
          hasher:update(value)
        elseif vtype == 'table' then
          hasher:update(table_digest(value))
        end
      end
    end

    return hasher:base32()
  end
  exports.table_digest = table_digest
  ---[[[
  -- @function lua_util.toboolean(v)
  -- Converts a string, a number or a boolean to boolean
  -- @param {string|number|boolean} v
  -- @return {boolean} v converted to boolean
  -- Strings '1', 'true', 'TRUE', 'True' give true and '0', 'false', 'FALSE',
  -- 'False' give false; a number is true unless it is 0; a boolean is returned
  -- unchanged. For anything else `false` plus an error message is returned.
  --]]]
  do
    -- Built once and shared by all calls
    local string_values = {
      ['1'] = true,
      ['true'] = true,
      ['TRUE'] = true,
      ['True'] = true,
      ['0'] = false,
      ['false'] = false,
      ['FALSE'] = false,
      ['False'] = false,
    }

    exports.toboolean = function(v)
      local vtype = type(v)

      if vtype == 'string' then
        local res = string_values[v]
        if res ~= nil then
          return res
        end
        return false, string.format('cannot convert %q to boolean', v)
      elseif vtype == 'number' then
        return v ~= 0
      elseif vtype == 'boolean' then
        return v
      end

      -- `%q` accepts only strings on Lua 5.1/LuaJIT, so formatting nil or a
      -- table with it would raise instead of reporting the conversion failure
      return false, string.format('cannot convert %s to boolean', tostring(v))
    end
  end
  ---[[[
  -- @function lua_util.config_check_local_or_authed(config, modname[, def_local[, def_authed]])
  -- Reads check_local and check_authed from the config as this is used in many modules
  -- @param {rspamd_config} config `rspamd_config` global
  -- @param {name} module name
  -- @param {boolean} def_local value of check_local when the config does not set it (false if omitted)
  -- @param {boolean} def_authed value of check_authed when the config does not set it (false if omitted)
  -- @return {table} array of two elements: `{ check_local, check_authed }`
  -- The module section is looked at first; the `options` section is consulted
  -- only when the module section sets neither of the two keys as a boolean.
  --]]]
  function exports.config_check_local_or_authed(rspamd_config, modname, def_local, def_authed)
    local check_local = def_local or false
    local check_authed = def_authed or false

    local sections = { modname, 'options' }
    for idx = 1, 2 do
      local found = false
      local opts = rspamd_config:get_all_opt(sections[idx])

      if type(opts) == 'table' then
        local from_local, from_authed = opts['check_local'], opts['check_authed']
        if type(from_local) == 'boolean' then
          check_local = from_local
          found = true
        end
        if type(from_authed) == 'boolean' then
          check_authed = from_authed
          found = true
        end
      end

      if found then
        break
      end
    end

    return { check_local, check_authed }
  end
  ---[[[
  -- @function lua_util.is_skip_local_or_authed(task, conf[, ip])
  -- Returns `true` if local or authenticated task should be skipped for this module
  -- @param {rspamd_task} task
  -- @param {table} conf table returned from `config_check_local_or_authed`;
  --   a missing table is treated as `{ false, false }`
  -- @param {rspamd_ip} ip optional ip address (can be obtained from a task)
  -- @return {boolean} true if check should be skipped
  --]]]
  function exports.is_skip_local_or_authed(task, conf, ip)
    ip = ip or task:get_from_ip()
    conf = conf or { false, false }

    -- conf[2] is check_authed: skip authenticated users unless it is set
    if not conf[2] and task:get_user() then
      return true
    end

    -- conf[1] is check_local: skip local addresses unless it is set
    if not conf[1] and type(ip) == 'userdata' and ip:is_local() then
      return true
    end

    return false
  end
  ---[[[
  -- @function lua_util.maybe_smtp_quote_value(str)
  -- Checks string for the forbidden elements (tspecials in RFC) and quotes the string if needed
  -- @param {string} str input string
  -- @return {string} original or quoted string
  -- When quoting, both `"` and `\` inside the value are escaped with a
  -- backslash, so the result stays a well-formed quoted string even if the
  -- value ends with a backslash.
  --]]]
  local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v"
  local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1)
  exports.maybe_smtp_quote_value = function(str)
    if special_match:match(str) then
      -- Only the first result of gsub (the string) is wanted, not the count
      local escaped = str:gsub('[\\"]', '\\%0')
      return '"' .. escaped .. '"'
    end

    return str
  end
  ---[[[
  -- @function lua_util.shuffle(table)
  -- Performs in-place shuffling of a table
  -- @param {table} tbl table to shuffle
  -- @return {table} same table
  -- Walks the array from the end and swaps each position with a random position
  -- at or before it. Picking the partner from the whole array at every step
  -- does not produce all permutations with equal probability.
  --]]]
  exports.shuffle = function(tbl)
    for i = #tbl, 2, -1 do
      local j = math.random(i)
      tbl[i], tbl[j] = tbl[j], tbl[i]
    end
    return tbl
  end
  1366. --
  -- Lookup table: two hex digits (in any mix of upper and lower case) -> byte
  local hex_table = {}
  do
    local hex_digits = '0123456789abcdefABCDEF'
    for hi = 1, #hex_digits do
      for lo = 1, #hex_digits do
        local pair = hex_digits:sub(hi, hi) .. hex_digits:sub(lo, lo)
        hex_table[pair] = string.char(tonumber(pair, 16))
      end
    end
  end
  ---[[[
  -- @function lua_util.unhex(str)
  -- Decode hex encoded string
  -- @param {string} str string to decode
  -- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is)
  -- The input is consumed two characters at a time from its start; a trailing
  -- single character is kept unchanged. Exactly one value is returned.
  --]]]
  exports.unhex = function(str)
    -- Parentheses drop the substitution count that gsub returns as well
    return (str:gsub('(..)', hex_table))
  end
  -- Upstream lists created so far, keyed by the url string they were built from
  local http_upstream_lists = {}
  ---[[[
  -- @function lua_util.http_upstreams_by_url(pool, url)
  -- Returns a cached or new upstreams list that corresponds to the specific url
  -- @param {mempool} pool memory pool to use (typically static pool from rspamd_config)
  -- @param {string} url full url
  -- @return {upstreams_list} object to get upstream from an url, or nil when
  --   the url cannot be parsed or the list cannot be created
  -- The port defaults to 443 for `https` and to 80 otherwise; the protocol
  -- defaults to `http`.
  --]]]
  local function http_upstreams_by_url(pool, url)
    local rspamd_url = require "rspamd_url"

    local known = http_upstream_lists[url]
    if known then
      return known
    end

    local parsed = rspamd_url.create(pool, url)
    if not parsed then
      return nil
    end

    local host = parsed:get_host()
    local proto = parsed:get_protocol() or 'http'
    local port = parsed:get_port()
    if not port then
      if proto == 'https' then
        port = 443
      else
        port = 80
      end
    end

    local created = require("rspamd_upstream_list").create(host, port)
    if not created then
      return nil
    end

    http_upstream_lists[url] = created
    return created
  end
  exports.http_upstreams_by_url = http_upstreams_by_url
  ---[[[
  -- @function lua_util.dns_timeout_augmentation(cfg)
  -- Returns an augmentation suitable to define DNS timeout for a module
  -- @param {rspamd_config} cfg object providing `get_dns_timeout()`
  -- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout
  --   (0.0 is used when the config reports no timeout)
  --]]]
  local function dns_timeout_augmentation(cfg)
    local timeout = cfg:get_dns_timeout() or 0.0
    return ('timeout=%f'):format(timeout)
  end
  exports.dns_timeout_augmentation = dns_timeout_augmentation
  ---[[[
  --- @function lua_util.strip_lua_comments(lua_code)
  -- Strips single-line and multi-line comments from a given Lua code string and removes
  -- any extra spaces or newlines.
  --
  -- The code is scanned left to right, so comment markers that appear inside
  -- quoted or long-bracket strings are left alone, and long comments of any
  -- level are recognised. A long comment is replaced by a single space so the
  -- tokens around it do not merge; a line comment is removed up to (not
  -- including) the end of line. Finally every run of whitespace, including
  -- whitespace inside string literals, is collapsed into one space.
  --
  -- @param lua_code The Lua code string to strip comments from.
  -- @return The resulting Lua code string with comments and extra spaces removed.
  --
  ---]]]
  local function strip_lua_comments(lua_code)
    local len = #lua_code
    local pieces, npieces = {}, 0
    local pos = 1

    -- Appends lua_code[from..to] to the output (no-op for an empty range)
    local function keep(from, to)
      if to >= from then
        npieces = npieces + 1
        pieces[npieces] = lua_code:sub(from, to)
      end
    end

    -- Index of the last character of the closing long bracket of the given
    -- level, searched from `from`; end of input when the bracket is not closed
    local function long_bracket_end(level, from)
      local _, stop = lua_code:find(']' .. level .. ']', from, true)
      return stop or len
    end

    while pos <= len do
      -- Jump to the next character that may start a comment or a string
      local start = lua_code:find('[%-"\'%[]', pos)
      if not start then
        keep(pos, len)
        break
      end
      keep(pos, start - 1)

      local ch = lua_code:sub(start, start)
      if ch == '-' then
        if lua_code:sub(start + 1, start + 1) ~= '-' then
          -- A lone minus sign
          keep(start, start)
          pos = start + 1
        else
          local level = lua_code:match('^%[(=*)%[', start + 2)
          if level then
            -- Long comment
            npieces = npieces + 1
            pieces[npieces] = ' '
            pos = long_bracket_end(level, start + 4 + #level) + 1
          else
            -- Line comment: the end of line itself is kept
            pos = lua_code:find('[\r\n]', start + 2) or (len + 1)
          end
        end
      elseif ch == '[' then
        local level = lua_code:match('^%[(=*)%[', start)
        if level then
          -- Long string: copied verbatim
          local stop = long_bracket_end(level, start + 2 + #level)
          keep(start, stop)
          pos = stop + 1
        else
          -- Plain indexing bracket
          keep(start, start)
          pos = start + 1
        end
      else
        -- Quoted string: copied verbatim up to the matching quote, skipping
        -- backslash escapes; an unescaped newline ends an unterminated string
        local stop = start + 1
        while stop <= len do
          local c = lua_code:sub(stop, stop)
          if c == '\\' then
            stop = stop + 2
          elseif c == ch or c == '\n' then
            break
          else
            stop = stop + 1
          end
        end
        if stop > len then
          stop = len
        end
        keep(start, stop)
        pos = stop + 1
      end
    end

    -- Remove extra spaces and newlines; parentheses drop gsub's count
    return (table.concat(pieces):gsub('%s+', ' '))
  end
  -- Public name: lua_util.strip_lua_comments
  exports.strip_lua_comments = strip_lua_comments
  ---[[[
  -- @function lua_util.join_path(...)
  -- Joins path components into a single path string using the appropriate separator
  -- for the current operating system.
  --
  -- The separator is the first character of `package.config`; `/` is used when
  -- that string is missing or empty. Components are concatenated as they are,
  -- no separators are added, removed or normalised inside them.
  --
  -- @param ... Any number of path components to join together.
  -- @return A single path string, with components separated by the appropriate separator.
  --
  ---]]]
  local path_sep = (package.config or ''):sub(1, 1)
  if path_sep == '' then
    -- string.sub never returns nil, so the fallback has to test for emptiness
    path_sep = '/'
  end
  local function join_path(...)
    -- Join components using separator
    return table.concat({ ... }, path_sep)
  end
  -- Public name: lua_util.join_path
  exports.join_path = join_path
  -- Short unit test for sanity
  -- Runs every time this file is loaded; the expected string depends on the
  -- separator detected for the platform
  if path_sep == '/' then
    assert(join_path('/path', 'to', 'file') == '/path/to/file')
  else
    -- NOTE(review): this branch expects a backslash for any separator other than '/'
    assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file')
  end
  -- Defines symbols priorities for common usage in prefilters/postfilters
  -- Named levels, from the largest value to the smallest
  exports.symbols_priorities = {
    top = 10, -- Symbols must be executed first (or last), such as settings
    high = 9, -- Example: asn
    medium = 5, -- Everything should use this as default
    low = 0,
  }
  -- The module's public table is the value returned by this file
  return exports