You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_util.lua 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  21. local split_grammar = {}
  22. local spaces_split_grammar
  23. local space = lpeg.S' \t\n\v\f\r'
  24. local nospace = 1 - space
  25. local ptrim = space^0 * lpeg.C((space^0 * nospace^1)^0)
  26. local match = lpeg.match
  27. local function rspamd_str_split(s, sep)
  28. local gr
  29. if not sep then
  30. if not spaces_split_grammar then
  31. local _sep = space
  32. local elem = lpeg.C((1 - _sep)^0)
  33. local p = lpeg.Ct(elem * (_sep * elem)^0)
  34. spaces_split_grammar = p
  35. end
  36. gr = spaces_split_grammar
  37. else
  38. gr = split_grammar[sep]
  39. if not gr then
  40. local _sep
  41. if type(sep) == 'string' then
  42. _sep = lpeg.S(sep) -- Assume set
  43. else
  44. _sep = sep -- Assume lpeg object
  45. end
  46. local elem = lpeg.C((1 - _sep)^0)
  47. local p = lpeg.Ct(elem * (_sep * elem)^0)
  48. gr = p
  49. split_grammar[sep] = gr
  50. end
  51. end
  52. return gr:match(s)
  53. end
  54. --[[[
  55. -- @function lua_util.str_split(text, deliminator)
  56. -- Splits text into a numeric table by deliminator
  57. -- @param {string} text deliminated text
  58. -- @param {string} deliminator the deliminator
  59. -- @return {table} numeric table containing string parts
  60. --]]
  61. exports.rspamd_str_split = rspamd_str_split
  62. exports.str_split = rspamd_str_split
  63. exports.rspamd_str_trim = function(s)
  64. return match(ptrim, s)
  65. end
  66. --[[[
  67. -- @function lua_util.round(number, decimalPlaces)
  68. -- Round number to fixed number of decimal points
  69. -- @param {number} number number to round
  70. -- @param {number} decimalPlaces number of decimal points
  71. -- @return {number} rounded number
  72. --]]
  73. -- Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  74. exports.round = function(num, numDecimalPlaces)
  75. local mult = 10^(numDecimalPlaces or 0)
  76. return math.floor(num * mult) / mult
  77. end
  78. --[[[
  79. -- @function lua_util.template(text, replacements)
  80. -- Replaces values in a text template
  81. -- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces.
  82. -- @param {string} text text containing variables
  83. -- @param {table} replacements key/value pairs for replacements
  84. -- @return {string} string containing replaced values
  85. -- @example
  86. -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  87. -- -- goop contains "HELLO LUA WORLD!"
  88. --]]
  89. exports.template = function(tmpl, keys)
  90. local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
  91. local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
  92. local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
  93. local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
  94. return lpeg.match(template_grammar, tmpl)
  95. end
  96. exports.remove_email_aliases = function(email_addr)
  97. local function check_gmail_user(addr)
  98. -- Remove all points
  99. local no_dots_user = string.gsub(addr.user, '%.', '')
  100. local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$')
  101. if cap then
  102. return cap, rspamd_str_split(pluses, '+'), nil
  103. elseif no_dots_user ~= addr.user then
  104. return no_dots_user,{},nil
  105. end
  106. return nil
  107. end
  108. local function check_address(addr)
  109. if addr.user then
  110. local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$')
  111. if cap then
  112. return cap, rspamd_str_split(pluses, '+'), nil
  113. end
  114. end
  115. return nil
  116. end
  117. local function set_addr(addr, new_user, new_domain)
  118. if new_user then
  119. addr.user = new_user
  120. end
  121. if new_domain then
  122. addr.domain = new_domain
  123. end
  124. if addr.domain then
  125. addr.addr = string.format('%s@%s', addr.user, addr.domain)
  126. else
  127. addr.addr = string.format('%s@', addr.user)
  128. end
  129. if addr.name and #addr.name > 0 then
  130. addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
  131. else
  132. addr.raw = string.format('<%s>', addr.addr)
  133. end
  134. end
  135. local function check_gmail(addr)
  136. local nu, tags, nd = check_gmail_user(addr)
  137. if nu then
  138. return nu, tags, nd
  139. end
  140. return nil
  141. end
  142. local function check_googlemail(addr)
  143. local nd = 'gmail.com'
  144. local nu, tags = check_gmail_user(addr)
  145. if nu then
  146. return nu, tags, nd
  147. end
  148. return nil, nil, nd
  149. end
  150. local specific_domains = {
  151. ['gmail.com'] = check_gmail,
  152. ['googlemail.com'] = check_googlemail,
  153. }
  154. if email_addr then
  155. if email_addr.domain and specific_domains[email_addr.domain] then
  156. local nu, tags, nd = specific_domains[email_addr.domain](email_addr)
  157. if nu or nd then
  158. set_addr(email_addr, nu, nd)
  159. return nu, tags
  160. end
  161. else
  162. local nu, tags, nd = check_address(email_addr)
  163. if nu or nd then
  164. set_addr(email_addr, nu, nd)
  165. return nu, tags
  166. end
  167. end
  168. return nil
  169. end
  170. end
  171. exports.is_rspamc_or_controller = function(task)
  172. local ua = task:get_request_header('User-Agent') or ''
  173. local pwd = task:get_request_header('Password')
  174. local is_rspamc = false
  175. if tostring(ua) == 'rspamc' or pwd then is_rspamc = true end
  176. return is_rspamc
  177. end
  178. --[[[
  179. -- @function lua_util.unpack(table)
  180. -- Converts numeric table to varargs
  181. -- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
  182. -- @param {table} table numerically indexed table to unpack
  183. -- @return {varargs} unpacked table elements
  184. --]]
  185. local unpack_function = table.unpack or unpack
  186. exports.unpack = function(t)
  187. return unpack_function(t)
  188. end
  189. --[[[
  190. -- @function lua_util.spairs(table)
  191. -- Like `pairs` but keys are sorted lexicographically
  192. -- @param {table} table table containing key/value pairs
  193. -- @return {function} generator function returning key/value pairs
  194. --]]
  195. -- Sorted iteration:
  196. -- for k,v in spairs(t) do ... end
  197. --
  198. -- or with custom comparison:
  199. -- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
  200. --
  201. -- optional limit is also available (e.g. return top X elements)
  202. local function spairs(t, order, lim)
  203. -- collect the keys
  204. local keys = {}
  205. for k in pairs(t) do keys[#keys+1] = k end
  206. -- if order function given, sort by it by passing the table and keys a, b,
  207. -- otherwise just sort the keys
  208. if order then
  209. table.sort(keys, function(a,b) return order(t, a, b) end)
  210. else
  211. table.sort(keys)
  212. end
  213. -- return the iterator function
  214. local i = 0
  215. return function()
  216. i = i + 1
  217. if not lim or i <= lim then
  218. if keys[i] then
  219. return keys[i], t[keys[i]]
  220. end
  221. end
  222. end
  223. end
  224. exports.spairs = spairs
  225. --[[[
  226. -- @function lua_util.disable_module(modname, how)
  227. -- Disables a plugin
  228. -- @param {string} modname name of plugin to disable
  229. -- @param {string} how 'redis' to disable redis, 'config' to disable startup
  230. --]]
  231. local function disable_module(modname, how)
  232. if rspamd_plugins_state.enabled[modname] then
  233. rspamd_plugins_state.enabled[modname] = nil
  234. end
  235. if how == 'redis' then
  236. rspamd_plugins_state.disabled_redis[modname] = {}
  237. elseif how == 'config' then
  238. rspamd_plugins_state.disabled_unconfigured[modname] = {}
  239. elseif how == 'experimental' then
  240. rspamd_plugins_state.disabled_experimental[modname] = {}
  241. else
  242. rspamd_plugins_state.disabled_failed[modname] = {}
  243. end
  244. end
  245. exports.disable_module = disable_module
  246. --[[[
  247. -- @function lua_util.disable_module(modname)
  248. -- Checks experimental plugins state and disable if needed
  249. -- @param {string} modname name of plugin to check
  250. -- @return {boolean} true if plugin should be enabled, false otherwise
  251. --]]
  252. local function check_experimental(modname)
  253. if rspamd_config:experimental_enabled() then
  254. return true
  255. else
  256. disable_module(modname, 'experimental')
  257. end
  258. return false
  259. end
  260. exports.check_experimental = check_experimental
  261. --[[[
  262. -- @function lua_util.list_to_hash(list)
  263. -- Converts numerically-indexed table to table indexed by values
  264. -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  265. -- @return {table} table indexed by values
  266. -- @example
  267. -- local h = lua_util.list_to_hash({"a", "b"})
  268. -- -- h contains {a = true, b = true}
  269. --]]
  270. local function list_to_hash(list)
  271. if type(list) == 'table' then
  272. if list[1] then
  273. local h = {}
  274. for _, e in ipairs(list) do
  275. h[e] = true
  276. end
  277. return h
  278. else
  279. return list
  280. end
  281. elseif type(list) == 'string' then
  282. local h = {}
  283. h[list] = true
  284. return h
  285. end
  286. end
  287. exports.list_to_hash = list_to_hash
  288. --[[[
  289. -- @function lua_util.parse_time_interval(str)
  290. -- Parses human readable time interval
  291. -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
  292. -- 'w' for weeks, 'y' for years
  293. -- @param {string} str input string
  294. -- @return {number|nil} parsed interval as seconds (might be fractional)
  295. --]]
  296. local function parse_time_interval(str)
  297. local function parse_time_suffix(s)
  298. if s == 's' then
  299. return 1
  300. elseif s == 'm' then
  301. return 60
  302. elseif s == 'h' then
  303. return 3600
  304. elseif s == 'd' then
  305. return 86400
  306. elseif s == 'w' then
  307. return 86400 * 7
  308. elseif s == 'y' then
  309. return 365 * 86400;
  310. end
  311. end
  312. local digit = lpeg.R("09")
  313. local parser = {}
  314. parser.integer =
  315. (lpeg.S("+-") ^ -1) *
  316. (digit ^ 1)
  317. parser.fractional =
  318. (lpeg.P(".") ) *
  319. (digit ^ 1)
  320. parser.number =
  321. (parser.integer *
  322. (parser.fractional ^ -1)) +
  323. (lpeg.S("+-") * parser.fractional)
  324. parser.time = lpeg.Cf(lpeg.Cc(1) *
  325. (parser.number / tonumber) *
  326. ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1),
  327. function (acc, val) return acc * val end)
  328. local t = lpeg.match(parser.time, str)
  329. return t
  330. end
  331. exports.parse_time_interval = parse_time_interval
  332. --[[[
  333. -- @function lua_util.dehumanize_number(str)
  334. -- Parses human readable number
  335. -- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier,
  336. -- e.g. `10mb` equal to `10 * 1024 * 1024`
  337. -- @param {string} str input string
  338. -- @return {number|nil} parsed number
  339. --]]
  340. local function dehumanize_number(str)
  341. local function parse_suffix(s)
  342. if s == 'k' then
  343. return 1000
  344. elseif s == 'm' then
  345. return 1000000
  346. elseif s == 'g' then
  347. return 1e9
  348. elseif s == 'kb' then
  349. return 1024
  350. elseif s == 'mb' then
  351. return 1024 * 1024
  352. elseif s == 'gb' then
  353. return 1024 * 1024;
  354. end
  355. end
  356. local digit = lpeg.R("09")
  357. local parser = {}
  358. parser.integer =
  359. (lpeg.S("+-") ^ -1) *
  360. (digit ^ 1)
  361. parser.fractional =
  362. (lpeg.P(".") ) *
  363. (digit ^ 1)
  364. parser.number =
  365. (parser.integer *
  366. (parser.fractional ^ -1)) +
  367. (lpeg.S("+-") * parser.fractional)
  368. parser.humanized_number = lpeg.Cf(lpeg.Cc(1) *
  369. (parser.number / tonumber) *
  370. (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1),
  371. function (acc, val) return acc * val end)
  372. local t = lpeg.match(parser.humanized_number, str)
  373. return t
  374. end
  375. exports.dehumanize_number = dehumanize_number
  376. --[[[
  377. -- @function lua_util.table_cmp(t1, t2)
  378. -- Compare two tables deeply
  379. --]]
  380. local function table_cmp(table1, table2)
  381. local avoid_loops = {}
  382. local function recurse(t1, t2)
  383. if type(t1) ~= type(t2) then return false end
  384. if type(t1) ~= "table" then return t1 == t2 end
  385. if avoid_loops[t1] then return avoid_loops[t1] == t2 end
  386. avoid_loops[t1] = t2
  387. -- Copy keys from t2
  388. local t2keys = {}
  389. local t2tablekeys = {}
  390. for k, _ in pairs(t2) do
  391. if type(k) == "table" then table.insert(t2tablekeys, k) end
  392. t2keys[k] = true
  393. end
  394. -- Let's iterate keys from t1
  395. for k1, v1 in pairs(t1) do
  396. local v2 = t2[k1]
  397. if type(k1) == "table" then
  398. -- if key is a table, we need to find an equivalent one.
  399. local ok = false
  400. for i, tk in ipairs(t2tablekeys) do
  401. if table_cmp(k1, tk) and recurse(v1, t2[tk]) then
  402. table.remove(t2tablekeys, i)
  403. t2keys[tk] = nil
  404. ok = true
  405. break
  406. end
  407. end
  408. if not ok then return false end
  409. else
  410. -- t1 has a key which t2 doesn't have, fail.
  411. if v2 == nil then return false end
  412. t2keys[k1] = nil
  413. if not recurse(v1, v2) then return false end
  414. end
  415. end
  416. -- if t2 has a key which t1 doesn't have, fail.
  417. if next(t2keys) then return false end
  418. return true
  419. end
  420. return recurse(table1, table2)
  421. end
  422. exports.table_cmp = table_cmp
  423. --[[[
  424. -- @function lua_util.table_cmp(task, name, value, stop_chars)
  425. -- Performs header folding
  426. --]]
  427. exports.fold_header = function(task, name, value, stop_chars)
  428. local how
  429. if task:has_flag("milter") then
  430. how = "lf"
  431. else
  432. how = task:get_newlines_type()
  433. end
  434. return rspamd_util.fold_header(name, value, how, stop_chars)
  435. end
  436. --[[[
  437. -- @function lua_util.override_defaults(defaults, override)
  438. -- Overrides values from defaults with override
  439. --]]
  440. local function override_defaults(def, override)
  441. -- Corner cases
  442. if not override or type(override) ~= 'table' then
  443. return def
  444. end
  445. if not def or type(def) ~= 'table' then
  446. return override
  447. end
  448. local res = {}
  449. for k,v in pairs(override) do
  450. if type(v) == 'table' then
  451. if def[k] and type(def[k]) == 'table' then
  452. -- Recursively override elements
  453. res[k] = override_defaults(def[k], v)
  454. else
  455. res[k] = v
  456. end
  457. else
  458. res[k] = v
  459. end
  460. end
  461. for k,v in pairs(def) do
  462. if type(res[k]) == 'nil' then
  463. res[k] = v
  464. end
  465. end
  466. return res
  467. end
  468. exports.override_defaults = override_defaults
  469. --[[[
  470. -- @function lua_util.extract_specific_urls(params)
  471. -- params: {
  472. - - task
  473. - - limit <int> (default = 9999)
  474. - - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
  475. works only if number of unique eSLD less than `limit`
  476. - - need_emails <bool> (default = false)
  477. - - filter <callback> (default = nil)
  478. - - prefix <string> cache prefix (default = nil)
  479. -- }
  480. -- Apply heuristic in extracting of urls from task, this function
  481. -- tries its best to extract specific number of urls from a task based on
  482. -- their characteristics
  483. --]]
  484. -- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
  485. exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
  486. local default_params = {
  487. limit = 9999,
  488. esld_limit = 9999,
  489. need_emails = false,
  490. filter = nil,
  491. prefix = nil
  492. }
  493. local params
  494. if type(params_or_task) == 'table' and type(lim) == 'nil' then
  495. params = params_or_task
  496. else
  497. -- Deprecated call
  498. params = {
  499. task = params_or_task,
  500. limit = lim,
  501. need_emails = need_emails,
  502. filter = filter,
  503. prefix = prefix
  504. }
  505. end
  506. for k,v in pairs(default_params) do
  507. if not params[k] then params[k] = v end
  508. end
  509. local cache_key
  510. if params.prefix then
  511. cache_key = params.prefix
  512. else
  513. cache_key = string.format('sp_urls_%d%s', params.limit, params.need_emails)
  514. end
  515. local cached = params.task:cache_get(cache_key)
  516. if cached then
  517. return cached
  518. end
  519. local urls = params.task:get_urls(params.need_emails)
  520. if not urls then return {} end
  521. if params.filter then urls = fun.totable(fun.filter(params.filter, urls)) end
  522. if #urls <= params.limit and #urls <= params.esld_limit then
  523. params.task:cache_set(cache_key, urls)
  524. return urls
  525. end
  526. -- Filter by tld:
  527. local tlds = {}
  528. local eslds = {}
  529. local ntlds, neslds = 0, 0
  530. local res = {}
  531. for _,u in ipairs(urls) do
  532. local esld = u:get_tld()
  533. if esld then
  534. if not eslds[esld] then
  535. eslds[esld] = {u}
  536. neslds = neslds + 1
  537. else
  538. if #eslds[esld] < params.esld_limit then
  539. table.insert(eslds[esld], u)
  540. end
  541. end
  542. local parts = rspamd_str_split(esld, '.')
  543. local tld = table.concat(fun.totable(fun.tail(parts)), '.')
  544. if not tlds[tld] then
  545. tlds[tld] = {u}
  546. ntlds = ntlds + 1
  547. else
  548. table.insert(tlds[tld], u)
  549. end
  550. -- Extract priority urls that are proven to be malicious
  551. if not u:is_html_displayed() then
  552. if u:is_obscured() then
  553. table.insert(res, u)
  554. else
  555. if u:get_user() then
  556. table.insert(res, u)
  557. elseif u:is_subject() or u:is_phished() then
  558. table.insert(res, u)
  559. end
  560. end
  561. end
  562. end
  563. end
  564. local limit = params.limit
  565. limit = limit - #res
  566. if limit <= 0 then limit = 1 end
  567. if neslds <= limit then
  568. -- We can get urls based on their eslds
  569. repeat
  570. local item_found = false
  571. for _,lurls in pairs(eslds) do
  572. if #lurls > 0 then
  573. table.insert(res, table.remove(lurls))
  574. limit = limit - 1
  575. item_found = true
  576. end
  577. end
  578. until limit <= 0 or not item_found
  579. params.task:cache_set(cache_key, urls)
  580. return res
  581. end
  582. if ntlds <= limit then
  583. while limit > 0 do
  584. for _,lurls in pairs(tlds) do
  585. if #lurls > 0 then
  586. table.insert(res, table.remove(lurls))
  587. limit = limit - 1
  588. end
  589. end
  590. end
  591. params.task:cache_set(cache_key, urls)
  592. return res
  593. end
  594. -- We need to sort tlds table first
  595. local tlds_keys = {}
  596. for k,_ in pairs(tlds) do table.insert(tlds_keys, k) end
  597. table.sort(tlds_keys, function (t1, t2)
  598. return #tlds[t1] < #tlds[t2]
  599. end)
  600. ntlds = #tlds_keys
  601. for i=1,ntlds / 2 do
  602. local tld1 = tlds[tlds_keys[i]]
  603. local tld2 = tlds[tlds_keys[ntlds - i]]
  604. if #tld1 > 0 then
  605. table.insert(res, table.remove(tld1))
  606. limit = limit - 1
  607. end
  608. if #tld2 > 0 then
  609. table.insert(res, table.remove(tld2))
  610. limit = limit - 1
  611. end
  612. if limit <= 0 then
  613. break
  614. end
  615. end
  616. params.task:cache_set(cache_key, urls)
  617. return res
  618. end
  619. --[[[
  620. -- @function lua_util.deepcopy(table)
  621. -- params: {
  622. - - table
  623. -- }
  624. -- Performs deep copy of the table. Including metatables
  625. --]]
  626. local function deepcopy(orig)
  627. local orig_type = type(orig)
  628. local copy
  629. if orig_type == 'table' then
  630. copy = {}
  631. for orig_key, orig_value in next, orig, nil do
  632. copy[deepcopy(orig_key)] = deepcopy(orig_value)
  633. end
  634. setmetatable(copy, deepcopy(getmetatable(orig)))
  635. else -- number, string, boolean, etc
  636. copy = orig
  637. end
  638. return copy
  639. end
  640. exports.deepcopy = deepcopy
  641. --[[[
  642. -- @function lua_util.shallowcopy(tbl)
  643. -- Performs shallow (and fast) copy of a table or another Lua type
  644. --]]
  645. exports.shallowcopy = function(orig)
  646. local orig_type = type(orig)
  647. local copy
  648. if orig_type == 'table' then
  649. copy = {}
  650. for orig_key, orig_value in pairs(orig) do
  651. copy[orig_key] = orig_value
  652. end
  653. else
  654. copy = orig
  655. end
  656. return copy
  657. end
  658. -- Debugging support
  659. local unconditional_debug = false
  660. local debug_modules = {}
  661. local debug_aliases = {}
  662. local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)
  663. if type(rspamd_config) == 'userdata' then
  664. local logger = require "rspamd_logger"
  665. -- Fill debug modules from the config
  666. local logging = rspamd_config:get_all_opt('logging')
  667. if logging then
  668. local log_level_str = logging.level
  669. if log_level_str then
  670. if log_level_str == 'debug' then
  671. unconditional_debug = true
  672. end
  673. end
  674. if not unconditional_debug then
  675. if logging.debug_modules then
  676. for _,m in ipairs(logging.debug_modules) do
  677. debug_modules[m] = true
  678. logger.infox(rspamd_config, 'enable debug for Lua module %s', m)
  679. end
  680. end
  681. if #debug_aliases > 0 then
  682. for alias,mod in pairs(debug_aliases) do
  683. if debug_modules[mod] then
  684. debug_modules[alias] = true
  685. logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
  686. alias, mod)
  687. end
  688. end
  689. end
  690. end
  691. end
  692. end
  693. --[[[
  694. -- @function lua_util.debugm(module, [log_object], format, ...)
  695. -- Performs fast debug log for a specific module
  696. --]]
  697. exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...)
  698. local logger = require "rspamd_logger"
  699. if unconditional_debug or debug_modules[mod] then
  700. if type(obj_or_fmt) == 'string' then
  701. logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
  702. else
  703. logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
  704. end
  705. end
  706. end
  707. --[[[
  708. -- @function lua_util.add_debug_alias(mod, alias)
  709. -- Add debugging alias so logging to `alias` will be treated as logging to `mod`
  710. --]]
  711. exports.add_debug_alias = function(mod, alias)
  712. local logger = require "rspamd_logger"
  713. debug_aliases[alias] = mod
  714. if debug_modules[mod] then
  715. debug_modules[alias] = true
  716. logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
  717. alias, mod)
  718. end
  719. end
  720. ---[[[
  721. -- @function lua_util.get_task_verdict(task)
  722. -- Returns verdict for a task, must be called from idempotent filters only
  723. -- Returns string:
  724. -- * `spam`: if message have over reject threshold and has more than one positive rule
  725. -- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
  726. -- * `passthrough`: if a message has been passed through some short-circuit rule
  727. -- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
  728. -- * `uncertain`: all other cases
  729. --]]
  730. exports.get_task_verdict = function(task)
  731. local result = task:get_metric_result()
  732. if result then
  733. if result.passthrough then
  734. return 'passthrough'
  735. end
  736. local action = result.action
  737. if action == 'reject' and result.npositive > 1 then
  738. return 'spam'
  739. elseif action == 'no action' then
  740. if result.score < 0 or result.nnegative > 3 then
  741. return 'ham'
  742. end
  743. else
  744. -- All colors of junk
  745. if action == 'add header' or action == 'rewrite subject' then
  746. if result.npositive > 2 then
  747. return 'junk'
  748. end
  749. end
  750. end
  751. end
  752. return 'uncertain'
  753. end
  754. return exports