You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_util.lua 28KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  21. local lupa = require "lupa"
  -- Cache of compiled split grammars, keyed by separator
  local split_grammar = {}
  -- Grammar splitting on whitespace characters, compiled lazily on first use
  local spaces_split_grammar
  -- Whitespace characters and any single character that is not whitespace
  local space = lpeg.S(' \t\n\v\f\r')
  local nospace = lpeg.P(1) - space
  -- Trim grammar: skip leading whitespace, then capture up to the end of the
  -- last run of non-whitespace characters (trailing whitespace is excluded)
  local ptrim = space^0 * lpeg.C((space^0 * nospace^1)^0)
  local match = lpeg.match
  -- Template syntax: the delimiter pairs are passed in lupa.configure's
  -- (block, variable, comment) order, i.e. `{% %}`, `{= =}` and `{# #}`;
  -- output is not auto-escaped and a trailing newline is preserved
  lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', {
    autoescape = false,
    keep_trailing_newline = true,
  })

  -- Template filter `pbkdf`: returns rspamd_cryptobox.pbkdf of its argument
  lupa.filters.pbkdf = function(s)
    local cryptobox = require "rspamd_cryptobox"
    return cryptobox.pbkdf(s)
  end
  -- Builds a table-capturing grammar that splits a string on every match of
  -- `delim`; empty fields (e.g. between two adjacent separators) are kept
  local function build_split_grammar(delim)
    local field = lpeg.C((1 - delim)^0)
    return lpeg.Ct(field * (delim * field)^0)
  end

  -- Splits `s` into a numerically indexed table of parts.
  -- `sep` may be a string (treated as a SET of separator characters), an lpeg
  -- pattern, or nil/false (every single whitespace character separates).
  -- Compiled grammars are cached per separator.
  local function rspamd_str_split(s, sep)
    local grammar
    if sep then
      grammar = split_grammar[sep]
      if not grammar then
        local delim = sep -- assume an lpeg object
        if type(sep) == 'string' then
          delim = lpeg.S(sep) -- a set of characters
        end
        grammar = build_split_grammar(delim)
        split_grammar[sep] = grammar
      end
    else
      spaces_split_grammar = spaces_split_grammar or build_split_grammar(space)
      grammar = spaces_split_grammar
    end
    return grammar:match(s)
  end
  --[[[
  -- @function lua_util.str_split(text, delimiter)
  -- Splits text into a numerically indexed table by a delimiter
  -- @param {string} text delimited text
  -- @param {string} delimiter set of delimiter characters (any one of them splits),
  --   or an lpeg pattern; when omitted every whitespace character is a delimiter
  -- @return {table} numeric table containing string parts (empty parts are kept)
  --]]
  exports.rspamd_str_split = rspamd_str_split
  exports.str_split = rspamd_str_split

  --[[[
  -- @function lua_util.rspamd_str_trim(str)
  -- Removes leading and trailing whitespace from a string
  -- @param {string} str input string
  -- @return {string} trimmed string
  --]]
  exports.rspamd_str_trim = function(s)
    local trimmed = lpeg.match(ptrim, s)
    return trimmed
  end
  --[[[
  -- @function lua_util.round(number, decimalPlaces)
  -- Round number to a fixed number of decimal places; halves are rounded away
  -- from zero (2.5 -> 3, -2.5 -> -3)
  -- @param {number} number number to round
  -- @param {number} decimalPlaces number of decimal places (default 0)
  -- @return {number} rounded number
  --]]
  -- Based on Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  exports.round = function(num, numDecimalPlaces)
    local mult = 10^(numDecimalPlaces or 0)
    -- A plain math.floor(num * mult) truncates (2.7 -> 2) instead of rounding;
    -- shift by half a unit in the direction of the sign first
    if num >= 0 then
      return math.floor(num * mult + 0.5) / mult
    end
    return math.ceil(num * mult - 0.5) / mult
  end
  --[[[
  -- @function lua_util.template(text, replacements)
  -- Replaces variables in a text template.
  -- Variable names consist of letters, digits and underscores, are prefixed
  -- with `$` and may or may not be wrapped in curly braces.
  -- @param {string} text text containing variables
  -- @param {table} replacements key/value pairs for replacements
  -- @return {string} string containing replaced values
  -- @example
  -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  -- -- goop contains "HELLO LUA WORLD!"
  --]]
  exports.template = function(tmpl, keys)
    -- Matches `p` literally and substitutes it with nothing
    local function drop(p)
      return lpeg.P(p) / ""
    end
    local ident_char = lpeg.R("az", "AZ", "09") + "_"
    -- A variable name is looked up in `keys`
    local name = (ident_char^1) / keys
    local plain_var = drop("$") * name
    local braced_var = drop("${") * name * drop("}")
    local grammar = lpeg.Cs((plain_var + braced_var + 1)^0)
    return lpeg.match(grammar, tmpl)
  end
  -- Returns a shallow copy of `env` extended with the `paths` (rspamd_paths)
  -- and `env` (rspamd_env) superglobals; the caller's table is not modified.
  -- A nil `env` is treated as an empty environment.
  local function enrich_template_with_globals(env)
    local newenv = exports.shallowcopy(env or {})
    newenv.paths = rspamd_paths
    newenv.env = rspamd_env
    return newenv
  end
  --[[[
  -- @function lua_util.jinja_template(text, env[, skip_global_env])
  -- Expands a text template with lupa (jinja2-like syntax). With this module's
  -- lupa configuration variables are written as `{= name =}`.
  -- @param {string} text template text
  -- @param {table} env key/value pairs available to the template
  -- @param {boolean} skip_global_env don't export Rspamd superglobals (`paths`, `env`)
  -- @return {string} string containing replaced values
  -- @example
  -- lua_util.jinja_template("HELLO {=FOO=} {=BAR=}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  -- "HELLO LUA WORLD!"
  --]]
  exports.jinja_template = function(text, env, skip_global_env)
    local expand_env = env
    if not skip_global_env then
      expand_env = enrich_template_with_globals(env)
    end
    return lupa.expand(text, expand_env)
  end
  --[[[
  -- @function lua_util.jinja_template_file(filename, env[, skip_global_env])
  -- Expands a template file with lupa (jinja2-like syntax)
  -- @param {string} filename name of file to expand
  -- @param {table} env key/value pairs available to the template
  -- @param {boolean} skip_global_env don't export Rspamd superglobals (`paths`, `env`)
  -- @return {string} string containing replaced values
  -- @example
  -- lua_util.jinja_template_file("/path/to/template", {['FOO'] = 'LUA'})
  --]]
  exports.jinja_template_file = function(filename, env, skip_global_env)
    local expand_env = env
    if not skip_global_env then
      expand_env = enrich_template_with_globals(env)
    end
    return lupa.expand_file(filename, expand_env)
  end
  --[[[
  -- @function lua_util.remove_email_aliases(email_addr)
  -- Normalises an address table in place by removing "+tag" suffixes from the
  -- user part; for gmail.com/googlemail.com dots in the user are removed too and
  -- googlemail.com is rewritten to gmail.com. The `addr` and `raw` fields are
  -- rebuilt after a change.
  -- @param {table} email_addr address table (fields user, domain, name, addr, raw)
  -- @return new user and table of tags (as split by '+', first element is "")
  --   when the address was changed, nil otherwise
  --]]
  exports.remove_email_aliases = function(email_addr)
    -- Splits "user+tag1+tag2" into "user" and rspamd_str_split("+tag1+tag2", '+')
    local function split_plus_tags(user)
      local base, tail = string.match(user, '^([^%+][^%+]*)(%+.*)$')
      if base then
        return base, rspamd_str_split(tail, '+')
      end
      return nil
    end

    -- Gmail: strip dots and "+tag" suffixes; nil when nothing changes
    local function gmail_user(addr)
      -- Remove all points
      local undotted = string.gsub(addr.user, '%.', '')
      local base, tags = split_plus_tags(undotted)
      if base then
        return base, tags
      end
      if undotted ~= addr.user then
        return undotted, {}
      end
      return nil
    end

    -- Any other domain: only "+tag" suffixes are stripped
    local function generic_user(addr)
      if addr.user then
        return split_plus_tags(addr.user)
      end
      return nil
    end

    -- Applies the new user/domain and rebuilds the derived `addr`/`raw` fields
    local function rebuild_address(addr, new_user, new_domain)
      if new_user then addr.user = new_user end
      if new_domain then addr.domain = new_domain end
      addr.addr = string.format('%s@%s', addr.user, addr.domain or '')
      if addr.name and #addr.name > 0 then
        addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
      else
        addr.raw = string.format('<%s>', addr.addr)
      end
    end

    local domain_handlers = {
      ['gmail.com'] = gmail_user,
      -- googlemail.com addresses are also moved to gmail.com
      ['googlemail.com'] = function(addr)
        local user, tags = gmail_user(addr)
        return user, tags, 'gmail.com'
      end,
    }

    if not email_addr then
      return
    end

    local handler = email_addr.domain and domain_handlers[email_addr.domain] or generic_user
    local new_user, tags, new_domain = handler(email_addr)
    if new_user or new_domain then
      rebuild_address(email_addr, new_user, new_domain)
      return new_user, tags
    end
    return nil
  end
  -- Returns true when the request's User-Agent is exactly 'rspamc' or when a
  -- `Password` request header is present (presumably a controller request)
  exports.is_rspamc_or_controller = function(task)
    local user_agent = task:get_request_header('User-Agent') or ''
    local password = task:get_request_header('Password')
    return tostring(user_agent) == 'rspamc' or not not password
  end
  --[[[
  -- @function lua_util.unpack(table[, i[, j]])
  -- Converts a numerically indexed table to varargs.
  -- Uses `table.unpack` when available (Lua 5.2+) and `unpack` otherwise (Lua 5.1/LuaJIT)
  -- @param {table} table numerically indexed table to unpack
  -- @param {number} i optional first index (default 1)
  -- @param {number} j optional last index (default #table); pass it explicitly
  --   for tables containing nil holes
  -- @return {varargs} unpacked table elements
  --]]
  local unpack_function = table.unpack or unpack
  exports.unpack = function(t, i, j)
    -- nil bounds are treated as absent by both unpack implementations
    return unpack_function(t, i, j)
  end
  --[[[
  -- @function lua_util.spairs(table[, order[, limit]])
  -- Like `pairs` but iterates keys in sorted order: by default `table.sort`
  -- order (`<`), or by `order(t, a, b)` when given. Iteration stops after
  -- `limit` elements when a limit is given.
  -- @param {table} table table containing key/value pairs
  -- @param {function} order optional comparator receiving (table, key_a, key_b)
  -- @param {number} limit optional maximum number of elements to return
  -- @return {function} generator function returning key/value pairs
  -- @example
  -- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end, 10) do ... end
  --]]
  local function spairs(t, order, lim)
    local sorted_keys = {}
    for key in pairs(t) do
      sorted_keys[#sorted_keys + 1] = key
    end

    if order then
      table.sort(sorted_keys, function(a, b) return order(t, a, b) end)
    else
      table.sort(sorted_keys)
    end

    local pos = 0
    return function()
      pos = pos + 1
      if lim and pos > lim then
        return
      end
      local key = sorted_keys[pos]
      if key then
        return key, t[key]
      end
    end
  end
  exports.spairs = spairs
  -- rspamd_plugins_state bucket used for each `how` value of disable_module;
  -- any other value lands in 'disabled_failed'
  local disabled_bucket_by_reason = {
    redis = 'disabled_redis',
    config = 'disabled_unconfigured',
    experimental = 'disabled_experimental',
  }

  --[[[
  -- @function lua_util.disable_module(modname, how)
  -- Disables a plugin: removes it from the enabled set and records it in the
  -- state bucket matching the reason
  -- @param {string} modname name of plugin to disable
  -- @param {string} how 'redis' (no redis), 'config' (not configured),
  --   'experimental' (experimental plugins disabled); anything else marks it failed
  --]]
  local function disable_module(modname, how)
    local state = rspamd_plugins_state
    if state.enabled[modname] then
      state.enabled[modname] = nil
    end
    local bucket = disabled_bucket_by_reason[how] or 'disabled_failed'
    state[bucket][modname] = {}
  end
  exports.disable_module = disable_module
  --[[[
  -- @function lua_util.check_experimental(modname)
  -- Checks whether experimental plugins are enabled and disables the plugin
  -- (reason 'experimental') when they are not
  -- @param {string} modname name of plugin to check
  -- @return {boolean} true if plugin should be enabled, false otherwise
  --]]
  local function check_experimental(modname)
    if not rspamd_config:experimental_enabled() then
      disable_module(modname, 'experimental')
      return false
    end
    return true
  end
  exports.check_experimental = check_experimental
  --[[[
  -- @function lua_util.list_to_hash(list)
  -- Converts numerically-indexed table to table indexed by values.
  -- A table without element [1] is assumed to be a hash already and is returned
  -- unchanged; a string becomes a one-element hash; other types yield nothing.
  -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  -- @return {table} table indexed by values
  -- @example
  -- local h = lua_util.list_to_hash({"a", "b"})
  -- -- h contains {a = true, b = true}
  --]]
  local function list_to_hash(list)
    local kind = type(list)
    if kind == 'string' then
      return { [list] = true }
    end
    if kind ~= 'table' then
      return
    end
    if not list[1] then
      return list
    end
    local hash = {}
    for _, elt in ipairs(list) do
      hash[elt] = true
    end
    return hash
  end
  exports.list_to_hash = list_to_hash
  --[[[
  -- @function lua_util.parse_time_interval(str)
  -- Parses a human readable time interval: a (possibly signed, possibly
  -- fractional) number followed by an optional unit suffix:
  -- 's' seconds, 'm' minutes, 'h' hours, 'd' days, 'w' weeks, 'y' 365-day years.
  -- Without a suffix the number is taken as seconds; text after it is ignored.
  -- @param {string} str input string
  -- @return {number|nil} parsed interval as seconds (might be fractional)
  --]]
  local function parse_time_interval(str)
    local seconds_per_unit = {
      s = 1,
      m = 60,
      h = 3600,
      d = 86400,
      w = 86400 * 7,
      y = 365 * 86400,
    }
    local digit = lpeg.R("09")
    local sign = lpeg.S("+-")
    local integer = sign^-1 * digit^1
    local fractional = lpeg.P(".") * digit^1
    local number = (integer * fractional^-1) + (sign * fractional)
    local unit = lpeg.S("smhdwy") / seconds_per_unit
    -- Multiply 1 * number * unit multiplier (if present)
    local grammar = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) * unit^-1,
        function(acc, val) return acc * val end)
    return lpeg.match(grammar, str)
  end
  exports.parse_time_interval = parse_time_interval
  --[[[
  -- @function lua_util.dehumanize_number(str)
  -- Parses a human readable number: a (possibly signed, possibly fractional)
  -- number followed by an optional suffix. 'k', 'm', 'g' multiply by 1e3, 1e6,
  -- 1e9; adding 'b' makes the multiplier binary ('kb' = 1024, 'mb' = 1024^2,
  -- 'gb' = 1024^3), e.g. `10mb` equals `10 * 1024 * 1024`.
  -- Text after the number/suffix is ignored.
  -- @param {string} str input string
  -- @return {number|nil} parsed number, nil if `str` does not start with a number
  --]]
  local function dehumanize_number(str)
    local suffix_multipliers = {
      k = 1000,
      m = 1000000,
      g = 1e9,
      kb = 1024,
      mb = 1024 * 1024,
      gb = 1024 * 1024 * 1024, -- was 1024 * 1024, i.e. 'gb' behaved like 'mb'
    }
    local digit = lpeg.R("09")
    local sign = lpeg.S("+-")
    local integer = sign^-1 * digit^1
    local fractional = lpeg.P(".") * digit^1
    local number = (integer * fractional^-1) + (sign * fractional)
    -- The whole suffix text ("k", "kb", ...) is looked up in the table above
    local suffix = (lpeg.S("kmg") * lpeg.P("b")^-1) / suffix_multipliers
    local humanized_number = lpeg.Cf(lpeg.Cc(1) * (number / tonumber) * suffix^-1,
        function(acc, val) return acc * val end)
    return lpeg.match(humanized_number, str)
  end
  exports.dehumanize_number = dehumanize_number
  --[[[
  -- @function lua_util.table_cmp(t1, t2)
  -- Compares two values deeply. Tables are equal when they have the same keys
  -- with deeply equal values; a table key is matched against a structurally
  -- equal table key of the other table. Reference cycles are handled.
  -- @param {any} t1 first value
  -- @param {any} t2 second value
  -- @return {boolean} true if the values are deeply equal
  --]]
  local function table_cmp(table1, table2)
    -- Pairs every table of table1 already entered with the table of table2 it is
    -- compared with; breaks cycles and enforces a consistent pairing
    local avoid_loops = {}

    -- Shallow copy of the pairing map, used to undo pairings of a failed attempt
    local function snapshot(map)
      local copy = {}
      for k, v in pairs(map) do copy[k] = v end
      return copy
    end

    local function recurse(t1, t2)
      if type(t1) ~= type(t2) then return false end
      if type(t1) ~= "table" then return t1 == t2 end
      if avoid_loops[t1] then return avoid_loops[t1] == t2 end
      avoid_loops[t1] = t2
      -- Copy keys from t2
      local t2keys = {}
      local t2tablekeys = {}
      for k, _ in pairs(t2) do
        if type(k) == "table" then table.insert(t2tablekeys, k) end
        t2keys[k] = true
      end
      -- Let's iterate keys from t1
      for k1, v1 in pairs(t1) do
        local v2 = t2[k1]
        if type(k1) == "table" then
          -- if key is a table, we need to find an equivalent one.
          local ok = false
          for i, tk in ipairs(t2tablekeys) do
            if table_cmp(k1, tk) then
              -- A failed value comparison leaves its pairings in avoid_loops,
              -- which would wrongly reject the next (correct) candidate: roll
              -- them back before trying another key
              local saved = snapshot(avoid_loops)
              if recurse(v1, t2[tk]) then
                table.remove(t2tablekeys, i)
                t2keys[tk] = nil
                ok = true
                break
              end
              avoid_loops = saved
            end
          end
          if not ok then return false end
        else
          -- t1 has a key which t2 doesn't have, fail.
          if v2 == nil then return false end
          t2keys[k1] = nil
          if not recurse(v1, v2) then return false end
        end
      end
      -- if t2 has a key which t1 doesn't have, fail.
      if next(t2keys) then return false end
      return true
    end

    return recurse(table1, table2)
  end
  exports.table_cmp = table_cmp
  --[[[
  -- @function lua_util.fold_header(task, name, value, stop_chars)
  -- Performs header folding via rspamd_util.fold_header, using LF line endings
  -- for milter tasks and the task's own newline type otherwise
  -- @param {task} task task object
  -- @param {string} name header name
  -- @param {string} value header value
  -- @param {string} stop_chars optional, passed through to rspamd_util.fold_header
  -- @return the result of rspamd_util.fold_header
  --]]
  exports.fold_header = function(task, name, value, stop_chars)
    local how = task:has_flag("milter") and "lf" or task:get_newlines_type()
    return rspamd_util.fold_header(name, value, how, stop_chars)
  end
  --[[[
  -- @function lua_util.override_defaults(defaults, override)
  -- Returns a new table with values from `override` taking precedence over
  -- `defaults`; nested tables present in both are merged recursively.
  -- If `override` is not a table `defaults` is returned, and if `defaults` is
  -- not a table `override` is returned. Non-merged nested tables are shared,
  -- not copied.
  --]]
  local function override_defaults(def, override)
    -- Corner cases
    if type(override) ~= 'table' then
      return def
    end
    if type(def) ~= 'table' then
      return override
    end

    local merged = {}
    for key, value in pairs(override) do
      if type(value) == 'table' and type(def[key]) == 'table' then
        -- Recursively override elements
        merged[key] = override_defaults(def[key], value)
      else
        merged[key] = value
      end
    end
    for key, value in pairs(def) do
      if merged[key] == nil then
        merged[key] = value
      end
    end
    return merged
  end
  exports.override_defaults = override_defaults
  --[[[
  -- @function lua_util.extract_specific_urls(params)
  -- Applies heuristics to extract a limited set of URLs from a task.
  -- params: {
  --   task        - task to take URLs from
  --   limit       <int> (default = 9999) number of URLs to select
  --   esld_limit  <int> (default = 9999) max URLs kept per eSLD (effective second
  --               level domain); effective only when the number of unique eSLDs
  --               does not exceed `limit`
  --   need_emails <bool> (default = false) passed to `task:get_urls`
  --   filter      <callback> (default = nil) predicate applied to every URL
  --   prefix      <string> (default = nil) cache key to use
  -- }
  -- The deprecated positional form (task, limit, need_emails, filter, prefix)
  -- is still accepted. URLs not displayed in HTML that are obscured, have a user
  -- part, come from the subject or are phished are selected first; the rest of
  -- the budget is spread over eSLDs or TLDs. Each URL appears at most once in the
  -- result, and the result is cached on the task.
  -- @return {table} array of selected URLs
  --]]
  -- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
  exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
    local default_params = {
      limit = 9999,
      esld_limit = 9999,
      need_emails = false,
      filter = nil,
      prefix = nil
    }

    local params = {}
    if type(params_or_task) == 'table' and type(lim) == 'nil' then
      -- Copy, so that filling in defaults does not modify the caller's table
      for k, v in pairs(params_or_task) do params[k] = v end
    else
      -- Deprecated call
      params = {
        task = params_or_task,
        limit = lim,
        need_emails = need_emails,
        filter = filter,
        prefix = prefix
      }
    end

    for k, v in pairs(default_params) do
      if not params[k] then params[k] = v end
    end

    local task = params.task
    local cache_key
    if params.prefix then
      cache_key = params.prefix
    else
      -- esld_limit is part of the key: it changes which URLs are selected
      cache_key = string.format('sp_urls_%d%s%d', params.limit,
          tostring(params.need_emails), params.esld_limit)
    end

    local cached = task:cache_get(cache_key)
    if cached then
      return cached
    end

    local urls = task:get_urls(params.need_emails)
    if not urls then return {} end
    if params.filter then urls = fun.totable(fun.filter(params.filter, urls)) end

    if #urls <= params.limit and #urls <= params.esld_limit then
      task:cache_set(cache_key, urls)
      return urls
    end

    local res = {}
    local selected = {}
    -- Appends a URL unless it is already in the result; returns true if appended
    local function select_url(u)
      if selected[u] then return false end
      selected[u] = true
      res[#res + 1] = u
      return true
    end

    -- Group URLs by eSLD and by TLD
    local tlds = {}
    local eslds = {}
    local ntlds, neslds = 0, 0
    for _, u in ipairs(urls) do
      local esld = u:get_tld()
      if esld then
        if not eslds[esld] then
          eslds[esld] = {u}
          neslds = neslds + 1
        elseif #eslds[esld] < params.esld_limit then
          table.insert(eslds[esld], u)
        end

        local parts = rspamd_str_split(esld, '.')
        local tld = table.concat(fun.totable(fun.tail(parts)), '.')
        if not tlds[tld] then
          tlds[tld] = {u}
          ntlds = ntlds + 1
        else
          table.insert(tlds[tld], u)
        end

        -- Extract priority urls that are proven to be malicious
        if not u:is_html_displayed() then
          if u:is_obscured() or u:get_user() or u:is_subject() or u:is_phished() then
            select_url(u)
          end
        end
      end
    end

    local limit = params.limit - #res
    if limit <= 0 then limit = 1 end

    -- Round-robin over the groups taking one URL from each non-empty group per
    -- pass, until the budget is spent or all groups are exhausted
    local function take_round_robin(groups)
      repeat
        local item_found = false
        for _, lurls in pairs(groups) do
          if #lurls > 0 then
            item_found = true
            if select_url(table.remove(lurls)) then
              limit = limit - 1
              if limit <= 0 then break end
            end
          end
        end
      until limit <= 0 or not item_found
    end

    if neslds <= limit then
      -- We can get urls based on their eslds
      take_round_robin(eslds)
      task:cache_set(cache_key, res)
      return res
    end

    if ntlds <= limit then
      take_round_robin(tlds)
      task:cache_set(cache_key, res)
      return res
    end

    -- Too many TLDs: sort them by size and take one URL from the smallest and
    -- one from the largest remaining group, moving towards the middle
    local tlds_keys = {}
    for k, _ in pairs(tlds) do table.insert(tlds_keys, k) end
    table.sort(tlds_keys, function(t1, t2)
      return #tlds[t1] < #tlds[t2]
    end)

    local nkeys = #tlds_keys
    for i = 1, math.floor((nkeys + 1) / 2) do
      local j = nkeys - i + 1 -- mirror index of i
      local tld1 = tlds[tlds_keys[i]]
      if #tld1 > 0 and select_url(table.remove(tld1)) then
        limit = limit - 1
      end
      if j ~= i then
        local tld2 = tlds[tlds_keys[j]]
        if #tld2 > 0 and select_url(table.remove(tld2)) then
          limit = limit - 1
        end
      end
      if limit <= 0 then
        break
      end
    end

    task:cache_set(cache_key, res)
    return res
  end
  --[[[
  -- @function lua_util.deepcopy(table)
  -- Performs a deep copy of a value; tables are copied recursively, including
  -- keys and metatables. Shared and cyclic references (such as a metatable with
  -- `mt.__index = mt`) are preserved in the copy instead of recursing forever.
  -- Non-table values are returned as is.
  -- @param {any} table value to copy
  -- @return {any} copy
  --]]
  local function deepcopy(orig)
    -- original table -> its copy, shared by the whole traversal
    local copies = {}

    local function copy_value(value)
      if type(value) ~= 'table' then
        return value -- number, string, boolean, etc
      end
      local existing = copies[value]
      if existing then
        return existing
      end
      local copy = {}
      -- Register before recursing so that cycles resolve to this copy
      copies[value] = copy
      for orig_key, orig_value in next, value, nil do
        copy[copy_value(orig_key)] = copy_value(orig_value)
      end
      -- Set the metatable after filling so __newindex is not triggered
      setmetatable(copy, copy_value(getmetatable(value)))
      return copy
    end

    return copy_value(orig)
  end
  exports.deepcopy = deepcopy
  --[[[
  -- @function lua_util.shallowcopy(tbl)
  -- Copies the top level of a table (nested values are shared, the metatable is
  -- not copied); any other Lua value is returned as is
  --]]
  exports.shallowcopy = function(orig)
    if type(orig) ~= 'table' then
      return orig
    end
    local copy = {}
    for key, value in pairs(orig) do
      copy[key] = value
    end
    return copy
  end
  -- Debugging support
  -- Debug every module (set when logging.level is 'debug')
  local unconditional_debug = false
  -- Modules with debug enabled: module name -> true
  local debug_modules = {}
  -- Registered debug aliases: alias name -> module name
  local debug_aliases = {}
  local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)

  --[[[
  -- @function lua_util.init_debug_logging(config)
  -- Reads the `logging` section of the config and enables Lua debug logging:
  -- for everything when `level` is 'debug', otherwise for each module listed in
  -- `debug_modules` and for every registered alias of such a module.
  -- @param {config} config rspamd config object
  --]]
  exports.init_debug_logging = function(config)
    local logger = require "rspamd_logger"
    -- Fill debug modules from the config
    local logging = config:get_all_opt('logging')
    if not logging then
      return
    end

    if logging.level == 'debug' then
      unconditional_debug = true
    end
    if unconditional_debug then
      return
    end

    if logging.debug_modules then
      for _, m in ipairs(logging.debug_modules) do
        debug_modules[m] = true
        logger.infox(config, 'enable debug for Lua module %s', m)
      end
    end

    -- debug_aliases is a hash (alias -> module), so `#debug_aliases` is always 0
    -- and previously prevented this loop from ever running; iterate it directly
    for alias, mod in pairs(debug_aliases) do
      if debug_modules[mod] then
        debug_modules[alias] = true
        logger.infox(config, 'enable debug for Lua module %s (%s aliased)',
            alias, mod)
      end
    end
  end
  --[[[
  -- @function lua_util.debugm(module, [log_object], format, ...)
  -- Writes a debug log line for `module` when debugging is enabled for it (or
  -- for everything). If the second argument is a string it is the format and no
  -- log object is used; otherwise it is the log object.
  --]]
  exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...)
    local logger = require "rspamd_logger"
    if not unconditional_debug and not debug_modules[mod] then
      return
    end
    -- logx is deliberately not tail-called: stack level 2 must be our caller
    if type(obj_or_fmt) ~= 'string' then
      logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
    else
      logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
    end
  end
  --[[[
  -- @function lua_util.add_debug_alias(mod, alias)
  -- Registers `alias` as a debug alias of `mod`; if debugging is already enabled
  -- for `mod` it is enabled for `alias` right away
  --]]
  exports.add_debug_alias = function(mod, alias)
    local logger = require "rspamd_logger"
    debug_aliases[alias] = mod
    if not debug_modules[mod] then
      return
    end
    debug_modules[alias] = true
    logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
        alias, mod)
  end
  --[[[
  -- @function lua_util.get_task_verdict(task)
  -- Returns a verdict for a task plus its score; must be called from
  -- idempotent filters only. Verdicts:
  -- * `spam`: action is 'reject' and more than one positive rule matched
  -- * `junk`: action is 'add header' or 'rewrite subject' and more than two
  --   positive rules matched
  -- * `passthrough`: the result is marked passthrough (score is nil)
  -- * `ham`: action is 'no action' and the score is negative or more than
  --   three negative rules matched
  -- * `uncertain`: all other cases
  -- Nothing is returned when the task has no metric result.
  --]]
  exports.get_task_verdict = function(task)
    local result = task:get_metric_result()
    if not result then
      return
    end
    if result.passthrough then
      return 'passthrough', nil
    end

    local score, action = result.score, result.action
    if action == 'reject' then
      if result.npositive > 1 then
        return 'spam', score
      end
    elseif action == 'no action' then
      if score < 0 or result.nnegative > 3 then
        return 'ham', score
      end
    elseif action == 'add header' or action == 'rewrite subject' then
      -- All colors of junk
      if result.npositive > 2 then
        return 'junk', score
      end
    end
    return 'uncertain', score
  end
  --[[[
  -- @function lua_util.maybe_obfuscate_string(subject, settings, prefix)
  -- Obfuscates a string if enabled in settings. Strings that are not valid
  -- UTF-8 are replaced by '???'; nil and empty strings are returned as is.
  -- Settings read (`<prefix>` is the `prefix` argument):
  -- * <prefix>_privacy - when true, the string is replaced by its hash
  -- * <prefix>_privacy_alg - hash algorithm, 'blake2' when unset
  -- * <prefix>_privacy_prefix - if non-empty (e.g. 'obf'), prepended as `<value>:`
  -- * <prefix>_privacy_length - if greater than 0 (e.g. 16), the hex hash is cut
  --   to this length; if 0 or false the full hash is returned
  -- @return obfuscated or validated string
  --]]
  exports.maybe_obfuscate_string = function(subject, settings, prefix)
    local hash = require 'rspamd_cryptobox_hash'
    if not subject or subject == '' then
      return subject
    end
    if not rspamd_util.is_valid_utf8(subject) then
      return '???'
    end
    if not settings[prefix .. '_privacy'] then
      return subject
    end

    local alg = settings[prefix .. '_privacy_alg'] or 'blake2'
    local digest = hash.create_specific(alg, subject)
    local max_len = settings[prefix .. '_privacy_length']
    local obfuscated = digest:hex()
    if max_len and max_len > 0 then
      obfuscated = obfuscated:sub(1, max_len)
    end

    local tag = settings[prefix .. '_privacy_prefix']
    if tag and #tag > 0 then
      obfuscated = tag .. ':' .. obfuscated
    end
    return obfuscated
  end
  838. ---[[[
  839. -- @function lua_util.callback_from_string(str)
  840. -- Converts a string like `return function(...) end` to lua function or emits error using
  841. -- `rspamd_config` superglobal
  842. -- @return function object or nil
  843. --]]]
  844. exports.callback_from_string = function(str)
  845. local loadstring = loadstring or load
  846. local ret, res_or_err = pcall(loadstring(str))
  847. if not ret or type(res_or_err) ~= 'function' then
  848. local rspamd_logger = require "rspamd_logger"
  849. rspamd_logger.errx(rspamd_config, 'invalid callback (%s) - must be a function',
  850. res_or_err)
  851. return nil
  852. end
  853. return res_or_err
  854. end
  855. ---[[[
  856. -- @function lua_util.keys(t)
  857. -- Returns all keys from a specific table
  858. -- @param {table} t input table (or iterator triplet)
  859. -- @return array of keys
  860. --]]]
  861. exports.keys = function(gen, param, state)
  862. local keys = {}
  863. local i = 1
  864. if param then
  865. for k,_ in fun.iter(gen, param, state) do
  866. rawset(keys, i, k)
  867. i = i + 1
  868. end
  869. else
  870. for k,_ in pairs(gen) do
  871. rawset(keys, i, k)
  872. i = i + 1
  873. end
  874. end
  875. return keys
  876. end
  877. ---[[[
  878. -- @function lua_util.values(t)
  879. -- Returns all values from a specific table
  880. -- @param {table} t input table
  881. -- @return array of values
  882. --]]]
  883. exports.values = function(gen, param, state)
  884. local values = {}
  885. local i = 1
  886. if param then
  887. for _,v in fun.iter(gen, param, state) do
  888. rawset(values, i, v)
  889. i = i + 1
  890. end
  891. else
  892. for _,v in pairs(gen) do
  893. rawset(values, i, v)
  894. i = i + 1
  895. end
  896. end
  897. return values
  898. end
  899. ---[[[
  900. -- @function lua_util.distance_sorted(t1, t2)
  901. -- Returns distance between two sorted tables t1 and t2
  902. -- @param {table} t1 input table
  903. -- @param {table} t2 input table
  904. -- @return distance between `t1` and `t2`
  905. --]]]
  906. exports.distance_sorted = function(t1, t2)
  907. local ncomp = #t1
  908. local ndiff = 0
  909. local i,j = 1,1
  910. if ncomp < #t2 then
  911. ncomp = #t2
  912. end
  913. for _=1,ncomp do
  914. if j > #t2 then
  915. ndiff = ndiff + ncomp - #t2
  916. if i > j then
  917. ndiff = ndiff - (i - j)
  918. end
  919. break
  920. elseif i > #t1 then
  921. ndiff = ndiff + ncomp - #t1
  922. if j > i then
  923. ndiff = ndiff - (j - i)
  924. end
  925. break
  926. end
  927. if t1[i] == t2[j] then
  928. i = i + 1
  929. j = j + 1
  930. elseif t1[i] < t2[j] then
  931. i = i + 1
  932. ndiff = ndiff + 1
  933. else
  934. j = j + 1
  935. ndiff = ndiff + 1
  936. end
  937. end
  938. return ndiff
  939. end
  940. ---[[[
  941. -- @function lua_util.table_digest(t)
  942. -- Returns hash of all values if t[1] is string or all keys otherwise
  943. -- @param {table} t input array or map
  944. -- @return {string} base32 representation of blake2b hash of all strings
  945. --]]]
  946. exports.table_digest = function(t)
  947. local cr = require "rspamd_cryptobox_hash"
  948. local h = cr.create()
  949. if t[1] then
  950. for _,e in ipairs(t) do
  951. h:update(tostring(e))
  952. end
  953. else
  954. for k,_ in pairs(t) do
  955. h:update(k)
  956. end
  957. end
  958. return h:base32()
  959. end
-- Module result: the table of exported helpers handed back to `require`
return exports