You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_util.lua 42KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639
  1. --[[
  2. Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  21. local lupa = require "lupa"
  -- Cache of compiled split grammars, keyed by the separator they were built from
  local split_grammar = {}
  -- Grammar used when no separator is given (split on whitespace); built lazily
  local spaces_split_grammar

  -- Shared lpeg building blocks
  local space = lpeg.S(' \t\n\v\f\r')
  local nospace = 1 - space
  -- Skips leading whitespace and captures everything up to the last non-space character
  local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0)
  local match = lpeg.match

  -- Template delimiters are `{% %}`, `{= =}` and `{# #}`
  -- (presumably block / variable / comment tags, per lupa's configure signature)
  lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', {
    keep_trailing_newline = true,
    autoescape = false,
  })
  -- Template filter `pbkdf`: passes the filtered value through rspamd_cryptobox.pbkdf.
  -- The cryptobox module is required on use rather than at module load.
  lupa.filters.pbkdf = function(s)
    return require("rspamd_cryptobox").pbkdf(s)
  end
  -- Builds a grammar that cuts the subject on `sep_patt` and collects every piece
  -- (empty pieces included) into a single table capture
  local function make_split_grammar(sep_patt)
    local piece = lpeg.C((1 - sep_patt) ^ 0)
    return lpeg.Ct(piece * (sep_patt * piece) ^ 0)
  end

  -- Splits `s` on `sep`. `sep` may be a string (treated as a set of separator
  -- characters) or an lpeg pattern; when omitted, whitespace is the separator.
  -- Compiled grammars are cached per separator.
  local function rspamd_str_split(s, sep)
    if not sep then
      if not spaces_split_grammar then
        spaces_split_grammar = make_split_grammar(space)
      end
      return spaces_split_grammar:match(s)
    end

    local grammar = split_grammar[sep]
    if not grammar then
      local sep_patt = sep -- anything that is not a string is assumed to be an lpeg object
      if type(sep) == 'string' then
        sep_patt = lpeg.S(sep) -- a string is a set of separator characters
      end
      grammar = make_split_grammar(sep_patt)
      split_grammar[sep] = grammar
    end

    return grammar:match(s)
  end
  --[[[
  -- @function lua_util.str_split(text, delimiter)
  -- Splits text into a numeric table by delimiter
  -- @param {string} text delimited text
  -- @param {string} delimiter set of delimiter characters (every character of the string is a separator) or an lpeg pattern; when omitted, text is split on whitespace
  -- @return {table} numeric table containing string parts
  --]]
  70. exports.rspamd_str_split = rspamd_str_split
  71. exports.str_split = rspamd_str_split
  -- Strips leading and trailing whitespace by running the shared `ptrim` pattern over `s`
  local function rspamd_str_trim(s)
    local trimmed = match(ptrim, s)
    return trimmed
  end
  75. exports.rspamd_str_trim = rspamd_str_trim
  76. --[[[
  77. -- @function lua_util.str_trim(text)
  78. -- Returns a string with no trailing and leading spaces
  79. -- @param {string} text input text
  80. -- @return {string} string with no trailing and leading spaces
  81. --]]
  82. exports.str_trim = rspamd_str_trim
  --[[[
  -- @function lua_util.str_startswith(text, prefix)
  -- Checks whether the leading `#prefix` characters of text are equal to prefix
  -- @param {string} text
  -- @param {string} prefix
  -- @return {boolean} true if text starts with the specified prefix, false otherwise
  --]]
  exports.str_startswith = function(s, prefix)
    local head = s:sub(1, prefix:len())
    return head == prefix
  end
  --[[[
  -- @function lua_util.str_endswith(text, suffix)
  -- Checks whether text ends with suffix (plain comparison, no pattern matching)
  -- @param {string} text
  -- @param {string} suffix
  -- @return {boolean} true if text ends with the specified suffix, false otherwise
  --]]
  exports.str_endswith = function(s, suffix)
    -- Start the plain search `#suffix` characters before the end of the text,
    -- so the only place the suffix can still be found is the very end
    local search_from = -suffix:len()
    local found_at = s:find(suffix, search_from, true)
    return found_at ~= nil
  end
  --[[[
  -- @function lua_util.round(number, decimalPlaces)
  -- Round number to fixed number of decimal points (halves are rounded away from zero)
  -- @param {number} number number to round
  -- @param {number} decimalPlaces number of decimal points (0 when omitted)
  -- @return {number} rounded number
  --]]
  -- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  exports.round = function(num, numDecimalPlaces)
    local scale = 10 ^ (numDecimalPlaces or 0)
    local rounded
    if num >= 0 then
      rounded = math.floor(num * scale + 0.5)
    else
      rounded = math.ceil(num * scale - 0.5)
    end
    return rounded / scale
  end
  --[[[
  -- @function lua_util.template(text, replacements)
  -- Replaces values in a text template
  -- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces.
  -- @param {string} text text containing variables
  -- @param {table} replacements key/value pairs for replacements
  -- @return {string} string containing replaced values
  -- @example
  -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  -- -- goop contains "HELLO LUA WORLD!"
  --]]
  exports.template = function(tmpl, keys)
    local P, R = lpeg.P, lpeg.R
    -- A variable name is a run of [a-zA-Z0-9_]; it is looked up in `keys`
    local name_char = R("az", "AZ", "09") + P("_")
    local name = (name_char ^ 1) / keys
    -- Matches a literal and replaces it with the empty string
    local function drop(literal)
      return P(literal) / ""
    end
    local plain_var = drop("$") * name
    local braced_var = drop("${") * name * drop("}")
    -- Substitute every variable, copy every other character unchanged
    local grammar = lpeg.Cs((plain_var + braced_var + 1) ^ 0)
    return lpeg.match(grammar, tmpl)
  end
  -- Returns a shallow copy of `env` with the `rspamd_paths` and `rspamd_env`
  -- globals attached as `paths` and `env`; the caller's table is left untouched
  local function enrich_template_with_globals(env)
    local enriched = exports.shallowcopy(env)
    enriched.paths, enriched.env = rspamd_paths, rspamd_env
    return enriched
  end
  --[[[
  -- @function lua_util.jinja_template(text, env[, skip_global_env])
  -- Expands a text template written in jinja2-like syntax. Variables use the
  -- `{=`/`=}` delimiters this module configures for lupa.
  -- @param {string} text text containing variables
  -- @param {table} env key/value pairs for replacements
  -- @param {boolean} skip_global_env don't export Rspamd superglobals (`paths` and `env`)
  -- @return {string} string containing replaced values
  -- @example
  -- lua_util.jinja_template("HELLO {=FOO=} {=BAR=}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  -- "HELLO LUA WORLD!"
  --]]
  exports.jinja_template = function(text, env, skip_global_env)
    local render_env = env
    if not skip_global_env then
      render_env = enrich_template_with_globals(env)
    end
    return lupa.expand(text, render_env)
  end
  --[[[
  -- @function lua_util.jinja_template_file(filename, env[, skip_global_env])
  -- Expands a template file written in jinja2-like syntax. Variables use the
  -- `{=`/`=}` delimiters this module configures for lupa.
  -- @param {string} filename name of file to expand
  -- @param {table} env key/value pairs for replacements
  -- @param {boolean} skip_global_env don't export Rspamd superglobals (`paths` and `env`)
  -- @return {string} string containing replaced values
  -- @example
  -- -- greeting.tpl contains "HELLO {=FOO=} {=BAR=}!"
  -- lua_util.jinja_template_file("greeting.tpl", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  -- "HELLO LUA WORLD!"
  --]]
  exports.jinja_template_file = function(filename, env, skip_global_env)
    local render_env = env
    if not skip_global_env then
      render_env = enrich_template_with_globals(env)
    end
    return lupa.expand_file(filename, render_env)
  end
  --[[[
  -- @function lua_util.remove_email_aliases(email_addr)
  -- Strips `+tag` aliases from the user part of an address table; for gmail.com
  -- and googlemail.com dots in the user part are removed as well, and
  -- googlemail.com is rewritten to gmail.com. When something changes, the
  -- `user`, `domain`, `addr` and `raw` fields of `email_addr` are updated in place.
  -- @param {table} email_addr address table with `user`, `domain` and optional `name` fields
  -- @return {string|nil, table|nil} new user part and the list of removed tags, or nil if nothing was changed
  --]]
  exports.remove_email_aliases = function(email_addr)
    -- Splits "user+tag..." into the bare user and the split tag string;
    -- returns nil when there is no plus part
    local function split_plus_tags(user)
      local base, plus_part = string.match(user, '^([^%+][^%+]*)(%+.*)$')
      if base then
        return base, (rspamd_str_split(plus_part, '+'))
      end
      return nil
    end

    -- Gmail flavour: dots in the user part are insignificant
    local function gmail_user(addr)
      local dotless = string.gsub(addr.user, '%.', '')
      local base, tags = split_plus_tags(dotless)
      if base then
        return base, tags
      end
      if dotless ~= addr.user then
        -- Only dots were removed, no tags
        return dotless, {}
      end
      return nil
    end

    -- Generic flavour: only plus tags are stripped
    local function generic_user(addr)
      if addr.user then
        return split_plus_tags(addr.user)
      end
      return nil
    end

    -- Per-domain handlers; each returns new_user, tags, new_domain
    local domain_handlers = {
      ['gmail.com'] = function(addr)
        local new_user, tags = gmail_user(addr)
        if new_user then
          return new_user, tags, nil
        end
        return nil
      end,
      ['googlemail.com'] = function(addr)
        -- The domain is canonicalised even when the user part stays as is
        local new_user, tags = gmail_user(addr)
        return new_user, tags, 'gmail.com'
      end,
    }

    -- Writes the new parts back and rebuilds the derived `addr` and `raw` fields
    local function apply_changes(addr, new_user, new_domain)
      if new_user then
        addr.user = new_user
      end
      if new_domain then
        addr.domain = new_domain
      end

      if addr.domain then
        addr.addr = string.format('%s@%s', addr.user, addr.domain)
      else
        addr.addr = string.format('%s@', addr.user)
      end

      if addr.name and #addr.name > 0 then
        addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
      else
        addr.raw = string.format('<%s>', addr.addr)
      end
    end

    if not email_addr then
      return
    end

    local handler = email_addr.domain and domain_handlers[email_addr.domain]
    local new_user, tags, new_domain
    if handler then
      new_user, tags, new_domain = handler(email_addr)
    else
      new_user, tags, new_domain = generic_user(email_addr)
    end

    if new_user or new_domain then
      apply_changes(email_addr, new_user, new_domain)
      return new_user, tags
    end

    return nil
  end
  --[[[
  -- @function lua_util.is_rspamc_or_controller(task)
  -- Checks the request headers of a task: the request is attributed to rspamc or
  -- the controller when `User-Agent` is exactly `rspamc` or a `Password` header is present
  -- @param {task} task task object
  -- @return {boolean} true if the request looks like an rspamc/controller one
  --]]
  exports.is_rspamc_or_controller = function(task)
    local user_agent = task:get_request_header('User-Agent') or ''
    local password = task:get_request_header('Password')

    if tostring(user_agent) == 'rspamc' or password then
      return true
    end
    return false
  end
  --[[[
  -- @function lua_util.unpack(table)
  -- Converts numeric table to varargs
  -- Uses `table.unpack` when the Lua runtime provides it and falls back to the
  -- global `unpack` otherwise
  -- @param {table} table numerically indexed table to unpack
  -- @return {varargs} unpacked table elements
  --]]
  local unpack_function = table.unpack or unpack

  exports.unpack = function(t)
    -- Only the table is forwarded, so the whole sequence is always unpacked
    return unpack_function(t)
  end
  --[[[
  -- @function lua_util.flatten(table)
  -- Flatten underlying tables in a single table (one level deep)
  -- @param {table} table table of tables
  -- @return {table} flattened table
  --]]
  exports.flatten = function(t)
    local flat = {}

    for _, inner in fun.iter(t) do
      for _, item in fun.iter(inner) do
        flat[#flat + 1] = item
      end
    end

    return flat
  end
  --[[[
  -- @function lua_util.spairs(table[, order[, lim]])
  -- Like `pairs` but keys are visited in sorted order
  -- @param {table} table table containing key/value pairs
  -- @param {function} order optional comparator called as `order(table, key_a, key_b)`; keys are sorted with the default `<` when omitted
  -- @param {number} lim optional maximum number of pairs to return
  -- @return {function} generator function returning key/value pairs
  -- @example
  -- for k, v in lua_util.spairs(t) do ... end
  -- -- sorted by value:
  -- for k, v in lua_util.spairs(t, function(t, a, b) return t[a] < t[b] end) do ... end
  --]]
  local function spairs(t, order, lim)
    -- Snapshot the keys so they can be sorted independently of the table
    local sorted_keys = {}
    for key in pairs(t) do
      sorted_keys[#sorted_keys + 1] = key
    end

    if not order then
      table.sort(sorted_keys)
    else
      -- The user comparator also receives the table itself
      table.sort(sorted_keys, function(a, b)
        return order(t, a, b)
      end)
    end

    local pos = 0
    return function()
      pos = pos + 1
      if lim and not (pos <= lim) then
        -- Limit reached
        return
      end
      local key = sorted_keys[pos]
      if key then
        return key, t[key]
      end
    end
  end
  324. exports.spairs = spairs
  325. local lua_cfg_utils = require "lua_cfg_utils"
  326. exports.config_utils = lua_cfg_utils
  327. exports.disable_module = lua_cfg_utils.disable_module
  --[[[
  -- @function lua_util.check_experimental(modname)
  -- Checks experimental plugins state and disables the plugin if experimental
  -- plugins are not enabled in the configuration
  -- @param {string} modname name of plugin to check
  -- @return {boolean} true if plugin should be enabled, false otherwise
  --]]
  local function check_experimental(modname)
    if not rspamd_config:experimental_enabled() then
      lua_cfg_utils.disable_module(modname, 'experimental')
      return false
    end

    return true
  end
  342. exports.check_experimental = check_experimental
  --[[[
  -- @function lua_util.list_to_hash(list)
  -- Converts numerically-indexed table to table indexed by values
  -- A table without a first element is returned unchanged; values that are
  -- neither a table nor a string produce no result
  -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  -- @return {table} table indexed by values
  -- @example
  -- local h = lua_util.list_to_hash({"a", "b"})
  -- -- h contains {a = true, b = true}
  --]]
  local function list_to_hash(list)
    local kind = type(list)

    if kind == 'string' then
      return { [list] = true }
    end
    if kind ~= 'table' then
      return
    end
    if not list[1] then
      -- No array part: the table is handed back as is
      return list
    end

    local set = {}
    for _, item in ipairs(list) do
      set[item] = true
    end
    return set
  end
  369. exports.list_to_hash = list_to_hash
  --[[[
  -- @function lua_util.nkeys(table|gen, param, state)
  -- Returns number of keys in a table (i.e. from both the array and hash parts combined)
  -- or, when `param` is given, the number of elements produced by the
  -- `gen, param, state` iterator triplet
  -- @param {table|function} gen table to count, or the generator of an iterator triplet
  -- @param {any} param iterator param (switches to iterator mode when set)
  -- @param {any} state iterator state
  -- @return {number} number of keys
  -- @example
  -- print(lua_util.nkeys({})) -- 0
  -- print(lua_util.nkeys({ "a", nil, "b" })) -- 2
  -- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2
  -- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3
  --
  --]]
  local function nkeys(gen, param, state)
    local count = 0

    if param then
      -- Iterator triplet: count whatever it yields
      for _ in fun.iter(gen, param, state) do
        count = count + 1
      end
    else
      -- Plain table
      for _ in pairs(gen) do
        count = count + 1
      end
    end

    return count
  end
  395. exports.nkeys = nkeys
  --[[[
  -- @function lua_util.parse_time_interval(str)
  -- Parses human readable time interval
  -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
  -- 'w' for weeks, 'y' for years (365 days); a number without suffix is taken as is
  -- @param {string} str input string
  -- @return {number|nil} parsed interval as seconds (might be fractional)
  --]]
  local function parse_time_interval(str)
    -- Seconds per unit for every accepted suffix
    local unit_seconds = {
      s = 1,
      m = 60,
      h = 3600,
      d = 86400,
      w = 86400 * 7,
      y = 365 * 86400,
    }

    local sign = lpeg.S("+-")
    local digits = lpeg.R("09") ^ 1
    local fraction = lpeg.P(".") * digits
    -- [sign]digits[.digits] or sign.digits
    local number = (sign ^ -1 * digits * fraction ^ -1) + (sign * fraction)
    local unit = lpeg.S("smhdwy") / unit_seconds

    -- Fold "1 * number [* unit multiplier]" into a single value
    local interval = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) * (unit ^ -1),
        function(acc, factor)
          return acc * factor
        end)

    local seconds = lpeg.match(interval, str)
    return seconds
  end
  438. exports.parse_time_interval = parse_time_interval
  --[[[
  -- @function lua_util.dehumanize_number(str)
  -- Parses human readable number
  -- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier,
  -- e.g. `10mb` equal to `10 * 1024 * 1024`
  -- @param {string} str input string
  -- @return {number|nil} parsed number
  --]]
  local function dehumanize_number(str)
    -- Multiplier for every accepted suffix: decimal for k/m/g, binary for kb/mb/gb
    local suffix_multipliers = {
      k = 1000,
      m = 1000000,
      g = 1e9,
      kb = 1024,
      mb = 1024 * 1024,
      -- was `1024 * 1024`, which made `gb` parse exactly like `mb`
      gb = 1024 * 1024 * 1024,
    }

    local sign = lpeg.S("+-")
    local digits = lpeg.R("09") ^ 1
    local fraction = lpeg.P(".") * digits
    -- [sign]digits[.digits] or sign.digits
    local number = (sign ^ -1 * digits * fraction ^ -1) + (sign * fraction)
    -- The whole suffix ("k", "kb", ...) is looked up in the multipliers table
    local suffix = (lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / suffix_multipliers

    -- Fold "1 * number [* suffix multiplier]" into a single value
    local humanized_number = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) * (suffix ^ -1),
        function(acc, factor)
          return acc * factor
        end)

    local parsed = lpeg.match(humanized_number, str)
    return parsed
  end
  481. exports.dehumanize_number = dehumanize_number
  --[[[
  -- @function lua_util.table_cmp(t1, t2)
  -- Compare two tables deeply
  -- Non-table values are compared with `==`. Table keys are matched by deep
  -- comparison as well, and self-referencing tables are supported.
  -- @param {any} t1 first value
  -- @param {any} t2 second value
  -- @return {boolean} true if both values are deeply equal
  --]]
  local function table_cmp(table1, table2)
    -- Maps every table of the first argument that is being compared to its
    -- counterpart, so that cyclic references terminate
    local seen = {}

    local function deep_equal(lhs, rhs)
      local lhs_type = type(lhs)
      if lhs_type ~= type(rhs) then
        return false
      end
      if lhs_type ~= "table" then
        return lhs == rhs
      end

      local counterpart = seen[lhs]
      if counterpart then
        return counterpart == rhs
      end
      seen[lhs] = rhs

      -- Keys of rhs that are not matched yet; table keys are listed separately
      -- as they have to be matched by content, not by identity
      local pending = {}
      local pending_table_keys = {}
      for key in pairs(rhs) do
        if type(key) == "table" then
          pending_table_keys[#pending_table_keys + 1] = key
        end
        pending[key] = true
      end

      for key, lhs_val in pairs(lhs) do
        local rhs_val = rhs[key]
        if type(key) == "table" then
          -- Find an equivalent, still unmatched table key in rhs
          local matched = false
          for idx, candidate in ipairs(pending_table_keys) do
            if table_cmp(key, candidate) and deep_equal(lhs_val, rhs[candidate]) then
              table.remove(pending_table_keys, idx)
              pending[candidate] = nil
              matched = true
              break
            end
          end
          if not matched then
            return false
          end
        else
          -- lhs has a key which rhs doesn't have, fail
          if rhs_val == nil then
            return false
          end
          pending[key] = nil
          if not deep_equal(lhs_val, rhs_val) then
            return false
          end
        end
      end

      -- Anything left belongs to rhs only. Compare against nil explicitly:
      -- a leftover key `false` is still a key.
      return next(pending) == nil
    end

    return deep_equal(table1, table2)
  end
  544. exports.table_cmp = table_cmp
  --[[[
  -- @function lua_util.table_merge(t1, t2)
  -- Merge two tables into a new one
  -- Values stored under numeric keys are appended one after another (and thus
  -- renumbered); for any other key a value from `t2` replaces the one from `t1`
  -- @param {table} t1 first table
  -- @param {table} t2 second table
  -- @return {table} new merged table
  --]]
  local function table_merge(t1, t2)
    local merged = {}
    local next_idx = 1 -- next free slot for values with numeric keys

    local function absorb(src)
      for key, value in pairs(src) do
        if type(key) ~= 'number' then
          merged[key] = value
        else
          merged[next_idx] = value
          next_idx = next_idx + 1
        end
      end
    end

    absorb(t1)
    absorb(t2)

    return merged
  end
  568. exports.table_merge = table_merge
  --[[[
  -- @function lua_util.fold_header(task, name, value, stop_chars)
  -- Performs header folding
  -- Tasks with the `milter` flag are always folded with `lf` newlines, other
  -- tasks use their own newlines type
  -- @param {task} task task object
  -- @param {string} name header name
  -- @param {string} value header value
  -- @param {string} stop_chars passed as is to `rspamd_util.fold_header`
  -- @return result of `rspamd_util.fold_header`
  --]]
  exports.fold_header = function(task, name, value, stop_chars)
    local newlines = task:has_flag("milter") and "lf" or task:get_newlines_type()

    return rspamd_util.fold_header(name, value, newlines, stop_chars)
  end
  --[[[
  -- @function lua_util.override_defaults(defaults, override)
  -- Overrides values from defaults with override
  -- Nested tables present on both sides are merged recursively. A new table is
  -- returned unless one of the arguments is not a table: then `defaults` is
  -- returned for a non-table `override`, and `override` for a non-table `defaults`.
  -- @param {table} defaults default values
  -- @param {table} override values that take precedence
  -- @return {table} resulting table
  --]]
  local function override_defaults(def, override)
    -- Corner cases
    if type(override) ~= 'table' then
      return def
    end
    if type(def) ~= 'table' then
      return override
    end

    -- Start from the defaults ...
    local res = {}
    for key, value in pairs(def) do
      res[key] = value
    end

    -- ... and let every override value replace (or merge into) its default
    for key, value in pairs(override) do
      if type(value) == 'table' and type(def[key]) == 'table' then
        res[key] = override_defaults(def[key], value)
      else
        res[key] = value
      end
    end

    return res
  end
  614. exports.override_defaults = override_defaults
  --[[[
  -- @function lua_util.filter_specific_urls(urls, params)
  -- params: {
  --   task - if needed to save in the cache
  --   limit <int> (default = 9999)
  --   esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain),
  --     works only if number of unique eSLD less than `limit`
  --   need_emails, need_images, need_content <bool> (default = false), part of the cache key;
  --     urls flagged as images are dropped unless `need_images` is set
  --   filter <callback> (default = nil)
  --   prefix <string> cache prefix (default = nil)
  --   ignore_ip <bool> skip urls flagged as numeric
  --   ignore_redirected <bool> use the redirection target instead of the redirecting url
  --   no_cache <bool> do not read or write the task cache
  -- }
  -- Apply heuristic in extracting of urls from `urls` table, this function
  -- tries its best to extract specific number of urls from a task based on
  -- their characteristics
  -- @return {table} list of selected urls
  --]]
  exports.filter_specific_urls = function(urls, params)
    -- Apply the documented defaults locally (the caller's table is not modified),
    -- so a missing limit no longer fails in string.format or arithmetic
    local limit = params.limit or 9999
    local esld_limit = params.esld_limit or 9999
    local use_cache = params.task and not params.no_cache
    local cache_key

    if use_cache then
      if params.prefix then
        cache_key = params.prefix
      else
        cache_key = string.format('sp_urls_%d%s%s%s', limit,
            tostring(params.need_emails or false),
            tostring(params.need_images or false),
            tostring(params.need_content or false))
      end
      local cached = params.task:cache_get(cache_key)

      if cached then
        return cached
      end
    end

    if not urls then
      return {}
    end

    if params.filter then
      urls = fun.totable(fun.filter(params.filter, urls))
    end

    -- Candidates grouped by eSLD and by tld; every candidate is a
    -- { string form, url, priority } triplet
    local tlds = {}
    local eslds = {}
    -- Selected urls keyed by their string form (deduplication)
    local res = {}

    local function insert_url(str, u)
      if not res[str] then
        res[str] = u
      end
    end

    -- Converts the selection to a list, stores it in the cache if requested
    local function finish()
      local selected = exports.values(res)
      if use_cache then
        params.task:cache_set(cache_key, selected)
      end
      return selected
    end

    local function process_single_url(u, default_priority)
      local priority = default_priority or 1 -- Normal priority
      local flags = u:get_flags()
      if params.ignore_ip and flags.numeric then
        return
      end
      if flags.redirected then
        local redir = u:get_redirected() -- get the real url

        if params.ignore_redirected then
          -- Replace `u` with redir
          u = redir
          priority = 2
        else
          -- Process both redirected url and the original one
          process_single_url(redir, 2)
        end
      end
      if flags.image then
        if not params.need_images then
          -- Ignore url
          return
        else
          -- Penalise images in urls
          priority = 0
        end
      end

      local esld = u:get_tld()
      local str_hash = tostring(u)

      if esld then
        -- Special cases
        if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
          if flags.obscured then
            priority = 3
          else
            if (flags.has_user or flags.has_port) then
              priority = 2
            elseif (flags.subject or flags.phished) then
              priority = 2
            end
          end
        elseif flags.html_displayed then
          priority = 0
        end

        if not eslds[esld] then
          eslds[esld] = { { str_hash, u, priority } }
        else
          if #eslds[esld] < esld_limit then
            table.insert(eslds[esld], { str_hash, u, priority })
          end
        end

        -- eSLD - 1 part => tld
        local parts = rspamd_str_split(esld, '.')
        local tld = table.concat(fun.totable(fun.tail(parts)), '.')

        if not tlds[tld] then
          tlds[tld] = { { str_hash, u, priority } }
        else
          table.insert(tlds[tld], { str_hash, u, priority })
        end
      end
    end

    for _, u in ipairs(urls) do
      process_single_url(u)
    end

    if limit <= 0 then
      return finish()
    end

    local function by_priority(c1, c2)
      return c1[3] < c2[3]
    end

    -- Sorts every group so that its best candidate is the last element, then
    -- orders the groups themselves. The groups are sorted up front rather than
    -- inside the comparator, so a lone group is sorted too.
    local function sort_groups(groups)
      for _, group in ipairs(groups) do
        table.sort(group, by_priority)
      end
      table.sort(groups, function(g1, g2)
        local best1, best2 = g1[#g1][3], g2[#g2][3]
        if best1 ~= best2 then
          -- Groups with the highest priority come first
          return best1 > best2
        end
        -- Prefer less urls to more urls per group
        return #g1 < #g2
      end)
      return groups
    end

    -- Takes the best remaining candidate of every group in turn until the limit
    -- is exhausted or no group has candidates left
    local function pick_round_robin(groups)
      repeat
        local item_found = false

        for _, group in ipairs(groups) do
          if limit <= 0 then
            break
          end
          if #group > 0 then
            local last = table.remove(group)
            insert_url(last[1], last[2])
            limit = limit - 1
            item_found = true
          end
        end
      until limit <= 0 or not item_found
    end

    local esld_groups = sort_groups(exports.values(eslds))

    if #esld_groups <= limit then
      -- Number of eslds < limit: one round visits every eSLD
      pick_round_robin(esld_groups)
      return finish()
    end

    -- Too many eslds: spread the selection over tlds instead
    pick_round_robin(sort_groups(exports.values(tlds)))

    return finish()
  end
  --[[[
  -- @function lua_util.extract_specific_urls(params)
  -- params: {
  - - task
  - - limit <int> (default = 9999)
  - - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
     works only if number of unique eSLD less than `limit`
  - - need_emails <bool> (default = false)
  - - need_images <bool> (default = false)
  - - need_content <bool> (default = false)
  - - filter <callback> (default = nil)
  - - prefix <string> cache prefix (default = nil)
  - - ignore_ip <bool> (default = false)
  - - ignore_redirected <bool> (default = false)
  - - no_cache <bool> (default = false)
  - - flags, flags_mode (optional) forwarded to `task:get_urls`
  -- }
  -- Apply heuristic in extracting of urls from task, this function
  -- tries its best to extract specific number of urls from a task based on
  -- their characteristics.
  -- The legacy positional form (task, limit, need_emails, filter, prefix) is
  -- still accepted. Unless `no_cache` is set, a previously cached result is
  -- returned straight from the task cache; otherwise the urls are fetched and
  -- handed to `filter_specific_urls`. Missing options are written into the
  -- `params` table that was passed in.
  --]]
  exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
    local params
    if type(params_or_task) ~= 'table' or lim ~= nil then
      -- Deprecated positional call: wrap the arguments into a params table
      params = {
        task = params_or_task,
        limit = lim,
        need_emails = need_emails,
        filter = filter,
        prefix = prefix
      }
    else
      params = params_or_task
    end

    -- Fill in every option the caller left unset (filter/prefix have no default)
    local defaults = {
      limit = 9999,
      esld_limit = 9999,
      need_emails = false,
      need_images = false,
      need_content = false,
      ignore_ip = false,
      ignore_redirected = false,
      no_cache = false,
    }
    for name, value in pairs(defaults) do
      if params[name] == nil then
        params[name] = value
      end
    end

    local url_params = {
      emails = params.need_emails,
      images = params.need_images,
      content = params.need_content,
      flags = params.flags, -- maybe nil
      flags_mode = params.flags_mode, -- maybe nil
    }

    -- Shortcut for cached stuff
    if params.task and not params.no_cache then
      local cache_key = params.prefix
      if not cache_key then
        -- No explicit prefix: derive the key from the limit and url selectors
        local suffix
        if params.flags then
          suffix = table.concat(params.flags) .. (params.flags_mode or '')
        else
          suffix = tostring(params.need_emails or false)
              .. tostring(params.need_images or false)
              .. tostring(params.need_content or false)
        end
        cache_key = string.format('sp_urls_%d%s', params.limit, suffix)
      end

      local hit = params.task:cache_get(cache_key)
      if hit then
        return hit
      end
    end

    -- No cache version
    local urls = params.task:get_urls(url_params)
    return exports.filter_specific_urls(urls, params)
  end
  --[[[
  -- @function lua_util.deepcopy(table)
  -- params: {
  - - table
  -- }
  -- Performs deep copy of the table. Including metatables.
  -- Keys, values and metatables are copied recursively; non-table values are
  -- returned as is. Every source table is copied exactly once per call, so
  -- cyclic references and self-referencing metatables (`mt.__index = mt`)
  -- are reproduced in the copy instead of recursing forever.
  --]]
  local function deepcopy(orig)
    -- Source table -> its copy, shared by the whole traversal of this call
    local seen = {}

    local function copy_value(value)
      if type(value) ~= 'table' then
        -- number, string, boolean, etc
        return value
      end

      local known = seen[value]
      if known then
        return known
      end

      local copy = {}
      -- Register before descending so that back-references resolve to `copy`
      seen[value] = copy

      for orig_key, orig_value in next, value, nil do
        copy[copy_value(orig_key)] = copy_value(orig_value)
      end

      local mt = getmetatable(value)
      -- getmetatable may return a non-table `__metatable` guard value, which
      -- cannot be installed with setmetatable
      if type(mt) == 'table' then
        setmetatable(copy, copy_value(mt))
      end

      return copy
    end

    return copy_value(orig)
  end
  913. exports.deepcopy = deepcopy
  --[[[
  -- @function lua_util.deepsort(table)
  -- params: {
  - - table
  -- }
  -- Performs recursive in-place sort of a table.
  -- The optional second argument is a comparator used for the top-level table
  -- only; nested tables are always sorted with the default comparator.
  -- Non-table arguments are ignored.
  --]]

  -- Default ordering: values of different types are ordered by type name;
  -- numbers and strings use `<`; `false` sorts before `true`. Other values
  -- (tables, functions, userdata) use `<` only when they carry an `__lt`
  -- metamethod and otherwise compare as equal, so sorting them never raises.
  local function default_sort_cmp(e1, e2)
    local t1, t2 = type(e1), type(e2)
    if t1 ~= t2 then
      return t1 < t2
    end
    if t1 == 'number' or t1 == 'string' then
      return e1 < e2
    end
    if t1 == 'boolean' then
      return (not e1) and e2
    end
    local mt = getmetatable(e1)
    if type(mt) == 'table' and mt.__lt then
      return e1 < e2
    end
    return false
  end

  local function deepsort(tbl, sort_func)
    if type(tbl) ~= 'table' then
      return
    end
    table.sort(tbl, sort_func or default_sort_cmp)
    -- Descend into every value (array and hash part alike)
    for _, orig_value in next, tbl, nil do
      deepsort(orig_value)
    end
  end
  937. exports.deepsort = deepsort
  --[[[
  -- @function lua_util.shallowcopy(tbl)
  -- Performs shallow (and fast) copy of a table or another Lua type.
  -- For a table, a new table with the same key/value pairs is returned
  -- (nested tables are shared, the metatable is not copied); any other value
  -- is returned unchanged.
  --]]
  exports.shallowcopy = function(orig)
    if type(orig) ~= 'table' then
      return orig
    end

    local copy = {}
    for key, value in pairs(orig) do
      copy[key] = value
    end
    return copy
  end
  -- Debugging support
  local logger = require "rspamd_logger"
  -- True when debug output is enabled for every module
  local unconditional_debug = logger.log_level() == 'debug'
  -- Set of module names with debug enabled: name -> true
  local debug_modules = {}
  -- Map alias -> real module name, filled by `add_debug_alias`
  local debug_aliases = {}
  local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)

  --[[[
  -- @function lua_util.init_debug_logging(config)
  -- Reads the `logging` section of the config and enables debug output
  -- either globally (`level = "debug"`) or for the modules listed in
  -- `debug_modules`. Aliases registered earlier via `add_debug_alias` are
  -- enabled when their target module is enabled. Does nothing when debug is
  -- already enabled unconditionally.
  -- @param {rspamd_config} config configuration object
  --]]
  exports.init_debug_logging = function(config)
    if unconditional_debug then
      return
    end

    local log_config = config:get_all_opt('logging')
    if not log_config then
      return
    end

    if log_config.level == 'debug' then
      unconditional_debug = true
    end

    if log_config.debug_modules then
      for _, m in ipairs(log_config.debug_modules) do
        debug_modules[m] = true
        logger.infox(config, 'enable debug for Lua module %s', m)
      end
    end

    -- `debug_aliases` is keyed by alias name, so `#debug_aliases` is always 0;
    -- iterate it directly (an empty map simply yields no iterations)
    for alias, mod in pairs(debug_aliases) do
      if debug_modules[mod] then
        debug_modules[alias] = true
        logger.infox(config, 'enable debug for Lua module %s (%s aliased)',
            alias, mod)
      end
    end
  end
  -- Turns debug output on for every module, regardless of `debug_modules`
  exports.enable_debug_logging = function()
    unconditional_debug = true
  end
  -- Turns debug output on for each module name passed as an argument
  -- (processing stops at the first nil argument)
  exports.enable_debug_modules = function(...)
    local names = { ... }
    for _, name in ipairs(names) do
      debug_modules[name] = true
    end
  end
  -- Clears the global debug flag; modules enabled individually in
  -- `debug_modules` are not affected
  exports.disable_debug_logging = function()
    unconditional_debug = false
  end
  --[[[
  -- @function lua_util.debugm(module, [log_object], format, ...)
  -- Performs fast debug log for a specific module.
  -- Nothing is logged unless debug is enabled globally or for `module`.
  -- When the second argument is a string it is taken as the format and an
  -- empty string is passed as the log object.
  --]]
  exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...)
    if not (unconditional_debug or debug_modules[mod]) then
      return
    end

    -- NOTE(review): the literal 2 is presumably a stack level used to find the
    -- caller, so logx is called directly from here and not as a tail call
    if type(obj_or_fmt) == 'string' then
      -- No log object given: (mod, fmt, ...)
      logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
    else
      -- Log object given: (mod, obj, fmt, ...)
      logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
    end
  end
  --[[[
  -- @function lua_util.add_debug_alias(mod, alias)
  -- Add debugging alias so logging to `alias` will be treated as logging to `mod`.
  -- The alias is always recorded; it is enabled immediately only if debug is
  -- already enabled for `mod`.
  --]]
  exports.add_debug_alias = function(mod, alias)
    debug_aliases[alias] = mod

    if not debug_modules[mod] then
      return
    end

    debug_modules[alias] = true
    logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
        alias, mod)
  end
  ---[[[
  -- @function lua_util.get_task_verdict(task)
  -- Returns verdict for a task + score if certain, must be called from idempotent filters only.
  -- This is a thin wrapper that forwards to `lua_verdict.get_default_verdict`.
  -- Returns string:
  -- * `spam`: if message have over reject threshold and has more than one positive rule
  -- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
  -- * `passthrough`: if a message has been passed through some short-circuit rule
  -- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
  -- * `uncertain`: all other cases
  --]]
  exports.get_task_verdict = function(task)
    -- Required lazily, on each call
    return require("lua_verdict").get_default_verdict(task)
  end
  ---[[[
  -- @function lua_util.maybe_obfuscate_string(subject, settings, prefix)
  -- Obfuscate string if enabled in settings. Also checks utf8 validity - if
  -- string is not valid utf8 then '???' is returned. Empty string returned as is.
  -- Supported settings:
  -- * <prefix>_privacy = false - subject privacy is off
  -- * <prefix>_privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject
  -- * <prefix>_privacy_prefix = 'obf' - prefix to show it's obfuscated
  -- * <prefix>_privacy_length = 16 - cut the length of the hash; if 0 or false full hash is returned
  -- @return obfuscated or validated subject
  --]]
  exports.maybe_obfuscate_string = function(subject, settings, prefix)
    local hash = require 'rspamd_cryptobox_hash'

    -- nil and empty strings pass through untouched
    if not subject or subject == '' then
      return subject
    end

    if not rspamd_util.is_valid_utf8(subject) then
      return '???'
    end

    if not settings[prefix .. '_privacy'] then
      return subject
    end

    -- Privacy enabled: replace the value with the hex digest of its hash
    local digest = hash.create_specific(settings[prefix .. '_privacy_alg'] or 'blake2', subject)
    local cut = settings[prefix .. '_privacy_length']
    local obfuscated
    if cut and cut > 0 then
      obfuscated = digest:hex():sub(1, cut)
    else
      obfuscated = digest:hex()
    end

    local tag = settings[prefix .. '_privacy_prefix']
    if tag and #tag > 0 then
      obfuscated = tag .. ':' .. obfuscated
    end

    return obfuscated
  end
  ---[[[
  -- @function lua_util.callback_from_string(str)
  -- Converts a string like `return function(...) end` to lua function and return true and this function
  -- or returns false + error message.
  -- Three input forms are accepted: `return function ... end`, a bare
  -- `function(...) ... end`, or a plain sequence of statements, which is
  -- wrapped into `function(...) ... end`.
  -- Compilation errors, runtime errors raised while evaluating the chunk and
  -- chunks that do not evaluate to a function are all reported as
  -- `false, message`.
  -- @param {string} str Lua source text
  -- @return status code and function object or an error message
  --]]]
  exports.callback_from_string = function(s)
    local loadstring = loadstring or load

    if type(s) ~= 'string' or #s == 0 then
      return false, 'invalid or empty string'
    end

    s = exports.rspamd_str_trim(s)
    local inp

    if s:match('^return%s*function') then
      -- 'return function', can be evaluated directly
      inp = s
    elseif s:match('^function%s*%(') then
      inp = 'return ' .. s
    else
      -- Just a plain sequence
      inp = 'return function(...)\n' .. s .. '; end'
    end

    -- Compile first: on a syntax error loadstring returns nil + message, and
    -- feeding that nil to pcall would hide the real compiler error
    local chunk, load_err = loadstring(inp)
    if not chunk then
      return false, load_err
    end

    local ok, res_or_err = pcall(chunk)
    if not ok then
      return false, res_or_err
    end

    if type(res_or_err) ~= 'function' then
      return false, string.format('expected a function, got %s', type(res_or_err))
    end

    return true, res_or_err
  end
  ---[[[
  -- @function lua_util.keys(t)
  -- Returns all keys from a specific table
  -- @param {table} t input table (or iterator triplet)
  -- @return array of keys
  --]]]
  exports.keys = function(gen, param, state)
    local collected, count = {}, 0

    local function push(key)
      count = count + 1
      collected[count] = key
    end

    if param then
      -- Iterator triplet: walk it through `fun.iter`
      for key in fun.iter(gen, param, state) do
        push(key)
      end
    else
      -- Plain table
      for key in pairs(gen) do
        push(key)
      end
    end

    return collected
  end
  ---[[[
  -- @function lua_util.values(t)
  -- Returns all values from a specific table
  -- @param {table} t input table (or iterator triplet)
  -- @return array of values
  --]]]
  exports.values = function(gen, param, state)
    local collected, count = {}, 0

    local function push(value)
      count = count + 1
      collected[count] = value
    end

    if param then
      -- Iterator triplet: walk it through `fun.iter`
      for _, value in fun.iter(gen, param, state) do
        push(value)
      end
    else
      -- Plain table
      for _, value in pairs(gen) do
        push(value)
      end
    end

    return collected
  end
  ---[[[
  -- @function lua_util.distance_sorted(t1, t2)
  -- Returns distance between two sorted tables t1 and t2.
  -- Both arrays are walked with one cursor each for at most
  -- max(#t1, #t2) steps: equal heads advance both cursors, otherwise the
  -- cursor at the smaller head advances and the distance grows by one. When a
  -- cursor runs off its array, the length difference (corrected by how far the
  -- cursors have drifted apart) is added and the walk stops.
  -- Elements must support `==` and `<`.
  -- @param {table} t1 input table
  -- @param {table} t2 input table
  -- @return distance between `t1` and `t2`
  --]]]
  exports.distance_sorted = function(t1, t2)
    local len1, len2 = #t1, #t2
    local steps = math.max(len1, len2)
    local distance = 0
    local i, j = 1, 1

    for _ = 1, steps do
      if j > len2 then
        -- t2 is exhausted
        distance = distance + steps - len2
        if i > j then
          distance = distance - (i - j)
        end
        break
      end

      if i > len1 then
        -- t1 is exhausted
        distance = distance + steps - len1
        if j > i then
          distance = distance - (j - i)
        end
        break
      end

      local left, right = t1[i], t2[j]
      if left == right then
        i, j = i + 1, j + 1
      elseif left < right then
        i = i + 1
        distance = distance + 1
      else
        j = j + 1
        distance = distance + 1
      end
    end

    return distance
  end
  ---[[[
  -- @function lua_util.table_digest(t)
  -- Returns hash of all values if t[1] is string or all keys/values otherwise.
  -- Nested tables are digested recursively and their digest is fed into the
  -- parent hash.
  -- NOTE(review): in the key/value branch entries are hashed in `pairs` order,
  -- which Lua does not specify, and values that are neither strings nor tables
  -- contribute only their key. Changing either would change existing digests.
  -- @param {table} t input array or map
  -- @return {string} base32 representation of blake2b hash of all strings
  --]]]
  local function table_digest(t)
    local hasher = require("rspamd_cryptobox_hash").create()

    if t[1] then
      -- Array-like: hash every element in order
      for _, item in ipairs(t) do
        local piece
        if type(item) == 'table' then
          piece = table_digest(item)
        else
          piece = tostring(item)
        end
        hasher:update(piece)
      end
    else
      -- Map-like: hash each key followed by its string/table value
      for key, value in pairs(t) do
        hasher:update(tostring(key))

        local vtype = type(value)
        if vtype == 'table' then
          hasher:update(table_digest(value))
        elseif vtype == 'string' then
          hasher:update(value)
        end
      end
    end

    return hasher:base32()
  end

  exports.table_digest = table_digest
  ---[[[
  -- @function lua_util.toboolean(v)
  -- Converts a string or a number to boolean.
  -- Strings `1`, `true`, `TRUE`, `True` give true; `0`, `false`, `FALSE`,
  -- `False` give false. A number is true unless it equals zero. A boolean is
  -- returned unchanged. Anything else yields `false` plus an error message
  -- (no error is raised, whatever the type of `v`).
  -- @param {string|number|boolean} v
  -- @return {boolean} v converted to boolean
  -- @return {string} error message, only when the conversion failed
  --]]]
  exports.toboolean = function(v)
    local known_strings = {
      ['1'] = true,
      ['true'] = true,
      ['TRUE'] = true,
      ['True'] = true,
      ['0'] = false,
      ['false'] = false,
      ['FALSE'] = false,
      ['False'] = false,
    }

    local vtype = type(v)

    if vtype == 'boolean' then
      return v
    elseif vtype == 'number' then
      return v ~= 0
    elseif vtype == 'string' then
      local converted = known_strings[v]
      if converted ~= nil then
        return converted
      end
      return false, string.format('cannot convert %q to boolean', v)
    end

    -- `%q` only accepts strings on Lua 5.1/LuaJIT, so format other types
    -- through tostring to avoid raising here
    return false, string.format('cannot convert %s to boolean', tostring(v))
  end
  ---[[[
  -- @function lua_util.config_check_local_or_authed(config, modname[, def_local[, def_authed]])
  -- Reads check_local and check_authed from the config as this is used in many modules.
  -- The module's own section is consulted first; the global `options` section
  -- is used only when the module section defines neither option. Only boolean
  -- values are accepted.
  -- @param {rspamd_config} config `rspamd_config` global
  -- @param {string} modname module name
  -- @param {boolean} def_local default for check_local (false if omitted)
  -- @param {boolean} def_authed default for check_authed (false if omitted)
  -- @return {table} array `{ check_local, check_authed }`
  --]]]
  exports.config_check_local_or_authed = function(rspamd_config, modname, def_local, def_authed)
    local flags = {
      check_local = def_local or false,
      check_authed = def_authed or false,
    }
    local option_names = { 'check_local', 'check_authed' }

    -- Applies boolean options found in `section`; true if at least one was set
    local function apply_section(section)
      local opts = rspamd_config:get_all_opt(section)
      if type(opts) ~= 'table' then
        return false
      end

      local found = false
      for _, name in ipairs(option_names) do
        if type(opts[name]) == 'boolean' then
          flags[name] = opts[name]
          found = true
        end
      end
      return found
    end

    if not apply_section(modname) then
      apply_section('options')
    end

    return { flags.check_local, flags.check_authed }
  end
  ---[[[
  -- @function lua_util.is_skip_local_or_authed(task, conf[, ip])
  -- Returns `true` if local or authenticated task should be skipped for this module.
  -- A task is skipped when authenticated checks are off (`conf[2]` is false)
  -- and the task has a user, or when local checks are off (`conf[1]` is false)
  -- and the address is local.
  -- @param {rspamd_task} task
  -- @param {table} conf table returned from `config_check_local_or_authed`
  --   (both checks are treated as disabled when omitted)
  -- @param {rspamd_ip} ip optional ip address (can be obtained from a task)
  -- @return {boolean} true if check should be skipped
  --]]]
  exports.is_skip_local_or_authed = function(task, conf, ip)
    ip = ip or task:get_from_ip()
    conf = conf or { false, false }

    -- Authenticated sender while authed checks are disabled
    if not conf[2] and task:get_user() then
      return true
    end

    -- Local address while local checks are disabled
    if not conf[1] and type(ip) == 'userdata' and ip:is_local() then
      return true
    end

    return false
  end
  ---[[[
  -- @function lua_util.maybe_smtp_quote_value(str)
  -- Checks string for the forbidden elements (tspecials in RFC and quote string if needed).
  -- A value containing any tspecial, space, tab or vertical tab is wrapped in
  -- double quotes; inside the quotes both `"` and `\` are escaped with a
  -- backslash (quoted-pair), so the result is a well-formed quoted-string.
  -- @param {string} str input string
  -- @return {string} original or quoted string
  --]]]
  -- tspecials as listed in RFC 2045 (including `@`) plus whitespace
  local tspecial = lpeg.S "()<>@,;:\\\"/[]?= \t\v"
  -- Matches when at least one tspecial occurs anywhere in the string
  local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1)
  exports.maybe_smtp_quote_value = function(str)
    if special_match:match(str) then
      -- gsub also returns a count: keep only the string
      local escaped = str:gsub('[\\"]', '\\%0')
      return '"' .. escaped .. '"'
    end
    return str
  end
  ---[[[
  -- @function lua_util.shuffle(table)
  -- Performs in-place shuffling of a table (Fisher-Yates): every element at
  -- position `i`, walking from the end, is swapped with a random element at
  -- position 1..i, which gives each permutation the same probability.
  -- Randomness comes from `math.random`.
  -- @param {table} tbl table to shuffle
  -- @return {table} same table
  --]]]
  exports.shuffle = function(tbl)
    for i = #tbl, 2, -1 do
      -- Pick only among the not yet fixed positions 1..i; picking among all
      -- positions on every step would bias the result
      local j = math.random(i)
      tbl[i], tbl[j] = tbl[j], tbl[i]
    end
    return tbl
  end
  -- Lookup table: two-character hex pair (any letter case) -> decoded byte
  local hex_table = {}
  for idx = 0, 255 do
    local byte = string.char(idx)
    local upper = ("%02X"):format(idx)
    local lower = ("%02x"):format(idx)
    hex_table[upper] = byte
    hex_table[lower] = byte
    -- Mixed-case spellings such as "aB" / "Ab" are valid hex as well
    hex_table[upper:sub(1, 1) .. lower:sub(2, 2)] = byte
    hex_table[lower:sub(1, 1) .. upper:sub(2, 2)] = byte
  end

  ---[[[
  -- @function lua_util.unhex(str)
  -- Decode hex encoded string. The input is consumed in two-character pairs;
  -- a trailing single character is left untouched.
  -- @param {string} str string to decode
  -- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is)
  --]]]
  exports.unhex = function(str)
    -- Parentheses drop gsub's second result (the substitution count) so that
    -- exactly one value is returned
    return (str:gsub('(..)', hex_table))
  end
  -- Cache of upstream lists keyed by the url string they were built from
  local http_upstream_lists = {}

  ---[[[
  -- @function lua_util.http_upstreams_by_url(pool, url)
  -- Returns a cached or new upstreams list that corresponds to the specific url.
  -- The list is built from the url's host and port; when the url has no
  -- port, 443 is used for `https` and 80 otherwise. Returns nil when the url
  -- cannot be parsed or the list cannot be created (failures are not cached).
  -- @param {mempool} pool memory pool to use (typically static pool from rspamd_config)
  -- @param {string} url full url
  -- @return {upstreams_list} object to get upstream from an url
  --]]]
  local function http_upstreams_by_url(pool, url)
    local rspamd_url = require "rspamd_url"

    local known = http_upstream_lists[url]
    if known then
      return known
    end

    local parsed = rspamd_url.create(pool, url)
    if not parsed then
      return nil
    end

    local host = parsed:get_host()
    local proto = parsed:get_protocol() or 'http'
    local port = parsed:get_port()
    if not port then
      if proto == 'https' then
        port = 443
      else
        port = 80
      end
    end

    local upstreams = require("rspamd_upstream_list").create(host, port)
    if not upstreams then
      return nil
    end

    http_upstream_lists[url] = upstreams
    return upstreams
  end

  exports.http_upstreams_by_url = http_upstreams_by_url
  ---[[[
  -- @function lua_util.dns_timeout_augmentation(cfg)
  -- Returns an augmentation suitable to define DNS timeout for a module.
  -- The value comes from `cfg:get_dns_timeout()`; 0 is used when that
  -- returns nothing.
  -- @param {rspamd_config} cfg configuration object
  -- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout
  --]]]
  local function dns_timeout_augmentation(cfg)
    local timeout = cfg:get_dns_timeout()
    if not timeout then
      timeout = 0.0
    end
    return ('timeout=%f'):format(timeout)
  end
  1381. exports.dns_timeout_augmentation = dns_timeout_augmentation
  ---[[[
  --- @function lua_util.strip_lua_comments(lua_code)
  -- Strips single-line and multi-line comments from a given Lua code string and removes
  -- any extra spaces or newlines.
  --
  -- Every `--` is classified the way the Lua lexer does it: when it is
  -- immediately followed by a long bracket opener (`[[`, `[=[`, ...) the
  -- comment runs to the matching closer, otherwise to the end of the line.
  -- A block comment is replaced by a single space so the tokens around it do
  -- not get glued together. An unterminated block comment runs to the end of
  -- the input.
  -- NOTE: string literals are not parsed, so a `--` inside a string is still
  -- treated as the start of a comment.
  --
  -- @param lua_code The Lua code string to strip comments from.
  -- @return The resulting Lua code string with comments and extra spaces removed.
  --
  ---]]]
  local function strip_lua_comments(lua_code)
    local pieces = {}
    local pos = 1
    local total = #lua_code

    while pos <= total do
      local dash_start, dash_end = lua_code:find('--', pos, true)
      if not dash_start then
        pieces[#pieces + 1] = lua_code:sub(pos)
        break
      end

      -- Keep the code in front of the comment
      pieces[#pieces + 1] = lua_code:sub(pos, dash_start - 1)

      -- `level` is the run of `=` signs between the brackets (maybe empty)
      local level = lua_code:match('^%[(=*)%[', dash_end + 1)
      if level then
        -- Multi-line comment: skip up to and including the matching closer
        local _, close_end = lua_code:find(']' .. level .. ']', dash_end + 1, true)
        pos = (close_end or total) + 1
        pieces[#pieces + 1] = ' '
      else
        -- Single-line comment: skip up to (not including) the line break
        local eol = lua_code:find('[\r\n]', dash_end + 1)
        pos = eol or (total + 1)
      end
    end

    -- Remove extra spaces and newlines (parentheses drop gsub's count)
    return (table.concat(pieces):gsub("%s+", " "))
  end
  1400. exports.strip_lua_comments = strip_lua_comments
  ---[[[
  -- @function lua_util.join_path(...)
  -- Joins path components into a single path string using the directory
  -- separator reported by `package.config` (its first character).
  -- Components are concatenated as given: separators already present in a
  -- component are neither removed nor deduplicated.
  --
  -- @param ... Any number of path components to join together.
  -- @return A single path string, with components separated by the appropriate separator.
  --
  ---]]]
  local path_sep = string.sub(package.config, 1, 1) or '/'
  local function join_path(...)
    -- Join components using separator
    return table.concat({ ... }, path_sep)
  end
  exports.join_path = join_path

  -- Short unit test for sanity, executed once when the module is loaded:
  -- a wrong separator fails the assertion right away
  if path_sep ~= '/' then
    assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file')
  else
    assert(join_path('/path', 'to', 'file') == '/path/to/file')
  end
  -- Named symbol priorities shared by prefilters/postfilters so that modules
  -- do not hard-code numeric values
  exports.symbols_priorities = {
    low = 0,
    medium = 5, -- Everything should use this as default
    high = 9, -- Example: asn
    top = 10, -- Symbols must be executed first (or last), such as settings
  }

  return exports