You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_util.lua 18KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  -- Cache of compiled split grammars, keyed by the separator string
  local split_grammar = {}
  -- Grammar used when no separator is given; built lazily on first use
  local spaces_split_grammar
  -- Whitespace set: space, tab, newline, vertical tab, form feed, carriage return
  local space = lpeg.S(' \t\n\v\f\r')
  -- Any single character that is not whitespace
  local nospace = lpeg.P(1) - space
  -- Skips leading whitespace, then captures everything up to the last
  -- non-whitespace character
  local ptrim = space^0 * lpeg.C((space^0 * nospace^1)^0)
  local match = lpeg.match
  -- Splits `s` into a table of pieces. When `sep` is not given the string is
  -- split on single whitespace characters; otherwise it is split on the
  -- literal `sep`. Compiled grammars are cached per separator.
  local function rspamd_str_split(s, sep)
    local grammar
    if sep then
      grammar = split_grammar[sep]
      if not grammar then
        -- First use of this separator: compile and remember the grammar
        local delim = lpeg.P(sep)
        local piece = lpeg.C((1 - delim)^0)
        grammar = lpeg.Ct(piece * (delim * piece)^0)
        split_grammar[sep] = grammar
      end
    else
      if not spaces_split_grammar then
        local piece = lpeg.C((1 - space)^0)
        spaces_split_grammar = lpeg.Ct(piece * (space * piece)^0)
      end
      grammar = spaces_split_grammar
    end
    return grammar:match(s)
  end
  49. --[[[
  50. -- @function lua_util.str_split(text, delimiter)
  51. -- Splits text into a numeric table by delimiter
  52. -- @param {string} text delimited text
  53. -- @param {string} delimiter the delimiter
  54. -- @return {table} numeric table containing string parts
  55. --]]
  56. exports.rspamd_str_split = rspamd_str_split
  57. exports.str_split = rspamd_str_split
  --[[[
  -- @function lua_util.rspamd_str_trim(text)
  -- Returns the text with leading and trailing whitespace removed
  -- @param {string} text input string
  -- @return {string} trimmed string
  --]]
  exports.rspamd_str_trim = function(s)
    return ptrim:match(s)
  end
  61. --[[[
  62. -- @function lua_util.round(number, decimalPlaces)
  63. -- Round number to fixed number of decimal points
  64. -- @param {number} number number to round
  65. -- @param {number} decimalPlaces number of decimal points
  66. -- @return {number} rounded number
  67. --]]
  68. -- Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  exports.round = function(num, numDecimalPlaces)
    -- Scale so that the requested number of decimals becomes the integer part
    local mult = 10^(numDecimalPlaces or 0)
    -- Adding 0.5 before flooring rounds to the nearest value (half up);
    -- a bare math.floor would always round down instead of rounding.
    return math.floor(num * mult + 0.5) / mult
  end
  73. --[[[
  74. -- @function lua_util.template(text, replacements)
  75. -- Replaces values in a text template
  76. -- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or may not use curly braces.
  77. -- @param {string} text text containing variables
  78. -- @param {table} replacements key/value pairs for replacements
  79. -- @return {string} string containing replaced values
  80. -- @example
  81. -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  82. -- -- goop contains "HELLO LUA WORLD!"
  83. --]]
  exports.template = function(tmpl, keys)
    -- One character of a variable name: letter, digit or underscore
    local name_char = lpeg.R("az", "AZ", "09") + lpeg.P("_")
    -- A variable name, replaced by its lookup in `keys`
    local lookup = (name_char^1) / keys
    -- `$NAME`: the dollar sign is dropped
    local plain = (lpeg.P("$") / "") * lookup
    -- `${NAME}`: the `${` and `}` markers are dropped
    local braced = (lpeg.P("${") / "") * lookup * (lpeg.P("}") / "")
    -- Everything that is not a variable is copied through unchanged
    local grammar = lpeg.Cs((plain + braced + 1)^0)
    return grammar:match(tmpl)
  end
  --[[[
  -- @function lua_util.remove_email_aliases(email_addr)
  -- Removes `+tag` aliases from the user part of an address. For `gmail.com`
  -- and `googlemail.com` dots in the user part are removed as well, and
  -- `googlemail.com` is rewritten to `gmail.com`.
  -- The `user`, `domain`, `addr` and `raw` fields of `email_addr` are updated
  -- in place when something has changed.
  -- @param {table} email_addr address table with `user`, `domain` and optionally `name`
  -- @return {string|nil, table|nil} new user part and the table of tags, or nil
  --]]
  exports.remove_email_aliases = function(email_addr)
    -- user part followed by one or more `+tag` suffixes
    local plus_pattern = '^([^%+][^%+]*)(%+.*)$'

    -- Gmail rules: dots are insignificant, `+tag` suffixes are aliases
    local function gmail_user(addr)
      local undotted = string.gsub(addr.user, '%.', '')
      local base, tail = string.match(undotted, plus_pattern)
      if base then
        local tags = rspamd_str_split(tail, '+')
        return base, tags
      end
      if undotted ~= addr.user then
        -- Only dots were removed, there are no tags
        return undotted, {}
      end
      return nil
    end

    -- Generic rules: only `+tag` suffixes are treated as aliases
    local function generic_user(addr)
      if addr.user then
        local base, tail = string.match(addr.user, plus_pattern)
        if base then
          local tags = rspamd_str_split(tail, '+')
          return base, tags, nil
        end
      end
      return nil
    end

    -- Writes the new user/domain back and rebuilds the derived fields
    local function apply(addr, new_user, new_domain)
      if new_user then
        addr.user = new_user
      end
      if new_domain then
        addr.domain = new_domain
      end

      if addr.domain then
        addr.addr = string.format('%s@%s', addr.user, addr.domain)
      else
        addr.addr = string.format('%s@', addr.user)
      end

      if addr.name and #addr.name > 0 then
        addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
      else
        addr.raw = string.format('<%s>', addr.addr)
      end
    end

    -- Each handler returns: new user, tags, new domain
    local specific_domains = {
      ['gmail.com'] = function(addr)
        local user, tags = gmail_user(addr)
        if user then
          return user, tags, nil
        end
        return nil
      end,
      ['googlemail.com'] = function(addr)
        local canonical = 'gmail.com'
        local user, tags = gmail_user(addr)
        if user then
          return user, tags, canonical
        end
        -- The domain is still rewritten even when the user part is unchanged
        return nil, nil, canonical
      end,
    }

    if not email_addr then
      return
    end

    local handler = email_addr.domain and specific_domains[email_addr.domain]
    local nu, tags, nd
    if handler then
      nu, tags, nd = handler(email_addr)
    else
      nu, tags, nd = generic_user(email_addr)
    end

    if nu or nd then
      apply(email_addr, nu, nd)
      return nu, tags
    end
    return nil
  end
  --[[[
  -- @function lua_util.is_rspamc_or_controller(task)
  -- Checks the request headers of a task: the result is true when the
  -- `User-Agent` header is `rspamc` or when a `Password` header is present
  -- @param {task} task task object providing `get_request_header`
  -- @return {boolean} true if either condition holds, false otherwise
  --]]
  exports.is_rspamc_or_controller = function(task)
    local ua = task:get_request_header('User-Agent') or ''
    local pwd = task:get_request_header('Password')
    -- `and true or false` turns the header value into a real boolean
    return (tostring(ua) == 'rspamc' or pwd) and true or false
  end
  173. --[[[
  174. -- @function lua_util.unpack(table)
  175. -- Converts numeric table to varargs
  176. -- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
  177. -- @param {table} table numerically indexed table to unpack
  178. -- @return {varargs} unpacked table elements
  179. --]]
  -- Prefer `table.unpack` and fall back to the global `unpack` when it is absent
  local unpack_function = table.unpack
  if not unpack_function then
    unpack_function = unpack
  end

  -- Only the table is forwarded, so the whole sequence is always unpacked
  exports.unpack = function(t)
    return unpack_function(t)
  end
  184. --[[[
  185. -- @function lua_util.spairs(table)
  186. -- Like `pairs` but keys are sorted lexicographically
  187. -- @param {table} table table containing key/value pairs
  188. -- @return {function} generator function returning key/value pairs
  189. --]]
  190. -- Sorted iteration:
  191. -- for k,v in spairs(t) do ... end
  192. --
  193. -- or with custom comparison:
  194. -- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
  195. --
  196. -- optional limit is also available (e.g. return top X elements)
  -- Iterates `t` in sorted key order. `order(t, a, b)` is an optional
  -- comparator over keys; `lim` optionally caps the number of pairs returned.
  local function spairs(t, order, lim)
    -- Snapshot the keys so they can be sorted
    local sorted_keys = {}
    for key in pairs(t) do
      table.insert(sorted_keys, key)
    end

    if order then
      -- The comparator also receives the table, so it can compare by value
      table.sort(sorted_keys, function(a, b)
        return order(t, a, b)
      end)
    else
      table.sort(sorted_keys)
    end

    local pos = 0
    return function()
      pos = pos + 1
      if lim and pos > lim then
        return
      end
      local key = sorted_keys[pos]
      if key then
        -- The value is read at iteration time, not at snapshot time
        return key, t[key]
      end
    end
  end
  219. exports.spairs = spairs
  220. --[[[
  221. -- @function lua_util.disable_module(modname, how)
  222. -- Disables a plugin
  223. -- @param {string} modname name of plugin to disable
  224. -- @param {string} how 'redis' to disable redis, 'config' to disable startup
  225. --]]
  local function disable_module(modname, how)
    local state = rspamd_plugins_state

    -- Drop the module from the enabled set first
    if state.enabled[modname] then
      state.enabled[modname] = nil
    end

    -- Map the reason to the state table that records it; anything
    -- unrecognised (including nil) is recorded as a failure
    local buckets = {
      redis = 'disabled_redis',
      config = 'disabled_unconfigured',
      experimental = 'disabled_experimental',
    }
    local bucket = buckets[how] or 'disabled_failed'
    state[bucket][modname] = {}
  end
  240. exports.disable_module = disable_module
  241. --[[[
  242. -- @function lua_util.check_experimental(modname)
  243. -- Checks experimental plugins state and disable if needed
  244. -- @param {string} modname name of plugin to check
  245. -- @return {boolean} true if plugin should be enabled, false otherwise
  246. --]]
  local function check_experimental(modname)
    if not rspamd_config:experimental_enabled() then
      -- Experimental plugins are switched off: record that and refuse
      disable_module(modname, 'experimental')
      return false
    end
    return true
  end
  255. exports.check_experimental = check_experimental
  256. --[[[
  257. -- @function lua_util.list_to_hash(list)
  258. -- Converts numerically-indexed table to table indexed by values
  259. -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  260. -- @return {table} table indexed by values
  261. -- @example
  262. -- local h = lua_util.list_to_hash({"a", "b"})
  263. -- -- h contains {a = true, b = true}
  264. --]]
  local function list_to_hash(list)
    local kind = type(list)

    if kind == 'string' then
      -- A single string is treated as a one-element list
      return { [list] = true }
    end

    if kind ~= 'table' then
      -- Any other type yields no result
      return
    end

    if not list[1] then
      -- No first array element: the table is returned as is
      return list
    end

    local h = {}
    for _, e in ipairs(list) do
      h[e] = true
    end
    return h
  end
  282. exports.list_to_hash = list_to_hash
  283. --[[[
  284. -- @function lua_util.parse_time_interval(str)
  285. -- Parses human readable time interval
  286. -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
  287. -- 'w' for weeks, 'y' for years
  288. -- @param {string} str input string
  289. -- @return {number|nil} parsed interval as seconds (might be fractional)
  290. --]]
  local function parse_time_interval(str)
    -- Number of seconds represented by each accepted suffix
    local multipliers = {
      s = 1,
      m = 60,
      h = 3600,
      d = 86400,
      w = 86400 * 7,
      y = 365 * 86400,
    }

    local digits = lpeg.R("09")^1
    local sign = lpeg.S("+-")
    local fraction = lpeg.P(".") * digits
    -- Either `[sign]digits[.digits]` or `sign.digits`
    local number = (sign^-1 * digits * fraction^-1) + (sign * fraction)

    -- Fold: 1 * <number> [* <suffix multiplier>]
    local interval = lpeg.Cf(
      lpeg.Cc(1) *
        (number / tonumber) *
        ((lpeg.S("smhdwy") / multipliers)^-1),
      function(acc, val) return acc * val end)

    local seconds = lpeg.match(interval, str)
    return seconds
  end
  326. exports.parse_time_interval = parse_time_interval
  327. --[[[
  328. -- @function lua_util.table_cmp(t1, t2)
  329. -- Compare two tables deeply
  330. --]]
  local function table_cmp(table1, table2)
    -- Tables from the left side already being compared, mapped to the
    -- right-side table they were paired with (guards against cycles)
    local visited = {}

    local function same(lhs, rhs)
      local lhs_type = type(lhs)
      if lhs_type ~= type(rhs) then
        return false
      end
      if lhs_type ~= "table" then
        return lhs == rhs
      end

      local paired = visited[lhs]
      if paired then
        return paired == rhs
      end
      visited[lhs] = rhs

      -- Keys of rhs that have not been matched yet
      local pending = {}
      -- Table-typed keys of rhs; these need a structural comparison
      local pending_table_keys = {}
      for rk in pairs(rhs) do
        if type(rk) == "table" then
          pending_table_keys[#pending_table_keys + 1] = rk
        end
        pending[rk] = true
      end

      for lk, lv in pairs(lhs) do
        if type(lk) == "table" then
          -- Look for a structurally equal key in rhs with an equal value
          local matched = false
          for idx, rk in ipairs(pending_table_keys) do
            if table_cmp(lk, rk) and same(lv, rhs[rk]) then
              table.remove(pending_table_keys, idx)
              pending[rk] = nil
              matched = true
              break
            end
          end
          if not matched then
            return false
          end
        else
          local rv = rhs[lk]
          -- lhs has a key that rhs lacks
          if rv == nil then
            return false
          end
          pending[lk] = nil
          if not same(lv, rv) then
            return false
          end
        end
      end

      -- Anything still pending is a key that only rhs has
      if next(pending) then
        return false
      end
      return true
    end

    return same(table1, table2)
  end
  373. exports.table_cmp = table_cmp
  374. --[[[
  375. -- @function lua_util.fold_header(task, name, value, stop_chars)
  376. -- Performs header folding
  377. --]]
  exports.fold_header = function(task, name, value, stop_chars)
    -- Tasks flagged as "milter" always fold with "lf"; otherwise the
    -- newline type reported by the task is used
    local how = task:has_flag("milter") and "lf" or task:get_newlines_type()
    return rspamd_util.fold_header(name, value, how, stop_chars)
  end
  387. --[[[
  388. -- @function lua_util.override_defaults(defaults, override)
  389. -- Overrides values from defaults with override
  390. --]]
  -- Merges `override` on top of `def` and returns a new table; neither input
  -- is modified. Nested tables present on both sides are merged recursively.
  -- If `override` is not a table, `def` is returned; if `def` is not a table,
  -- `override` is returned.
  local function override_defaults(def, override)
    -- Corner cases
    if not override or type(override) ~= 'table' then
      return def
    end
    if not def or type(def) ~= 'table' then
      return override
    end

    local res = {}

    -- pairs() is used on purpose: it yields key/value for both hash-like and
    -- array-like tables, so list-valued options are merged rather than lost.
    for k, v in pairs(override) do
      if type(v) == 'table' and type(def[k]) == 'table' then
        -- Recursively override elements
        res[k] = override_defaults(def[k], v)
      else
        res[k] = v
      end
    end

    for k, v in pairs(def) do
      -- Compare against nil explicitly: an override of `false` is a real
      -- value and must not be replaced by the default.
      if res[k] == nil then
        res[k] = v
      end
    end

    return res
  end
  419. exports.override_defaults = override_defaults
  420. --[[[
  421. -- @function lua_util.extract_specific_urls(params)
  422. -- params: {
  423. - - task
  424. - - limit <int> (default = 9999)
  425. - - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
  426. works only if number of unique eSLD less than `limit`
  427. - - need_emails <bool> (default = false)
  428. - - filter <callback> (default = nil)
  429. - - prefix <string> cache prefix (default = nil)
  430. -- }
  431. -- Apply heuristic in extracting of urls from task, this function
  432. -- tries its best to extract specific number of urls from a task based on
  433. -- their characteristics
  434. --]]
  435. -- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
  exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
    local default_params = {
      limit = 9999,
      esld_limit = 9999,
      need_emails = false,
      filter = nil,
      prefix = nil
    }

    local params
    if type(params_or_task) == 'table' and type(lim) == 'nil' then
      params = params_or_task
    else
      -- Deprecated positional call: (task, limit, need_emails, filter, prefix)
      params = {
        task = params_or_task,
        limit = lim,
        need_emails = need_emails,
        filter = filter,
        prefix = prefix
      }
    end
    -- Fill in defaults for everything the caller left unset
    for k,v in pairs(default_params) do
      if not params[k] then params[k] = v end
    end

    local cache_key
    if params.prefix then
      cache_key = params.prefix
    else
      -- tostring() is needed: plain Lua 5.1 rejects a boolean for '%s'
      cache_key = string.format('sp_urls_%d%s', params.limit,
          tostring(params.need_emails))
    end

    local cached = params.task:cache_get(cache_key)
    if cached then
      return cached
    end

    local urls = params.task:get_urls(params.need_emails)
    if not urls then return {} end

    if params.filter then urls = fun.totable(fun.filter(params.filter, urls)) end

    -- Few enough urls: no selection is needed at all
    if #urls <= params.limit and #urls <= params.esld_limit then
      params.task:cache_set(cache_key, urls)
      return urls
    end

    -- Group urls by eSLD and by tld, collecting priority urls on the way
    local tlds = {}
    local eslds = {}
    local ntlds, neslds = 0, 0
    local res = {}

    for _,u in ipairs(urls) do
      local esld = u:get_tld()

      if esld then
        if not eslds[esld] then
          eslds[esld] = {u}
          neslds = neslds + 1
        else
          if #eslds[esld] < params.esld_limit then
            table.insert(eslds[esld], u)
          end
        end

        -- Everything after the first label of the eSLD
        local parts = rspamd_str_split(esld, '.')
        local tld = table.concat(fun.totable(fun.tail(parts)), '.')

        if not tlds[tld] then
          tlds[tld] = {u}
          ntlds = ntlds + 1
        else
          table.insert(tlds[tld], u)
        end

        -- Extract priority urls that are proven to be malicious
        if not u:is_html_displayed() then
          if u:is_obscured() then
            table.insert(res, u)
          else
            if u:get_user() then
              table.insert(res, u)
            elseif u:is_subject() or u:is_phished() then
              table.insert(res, u)
            end
          end
        end
      end
    end

    -- Remaining budget after the priority urls; always at least one
    local limit = params.limit
    limit = limit - #res
    if limit <= 0 then limit = 1 end

    if neslds <= limit then
      -- We can get urls based on their eslds
      repeat
        local item_found = false

        for _,lurls in pairs(eslds) do
          if #lurls > 0 then
            table.insert(res, table.remove(lurls))
            limit = limit - 1
            item_found = true
          end
        end
      until limit <= 0 or not item_found

      -- Cache what is actually returned, so a cache hit yields the same list
      params.task:cache_set(cache_key, res)
      return res
    end

    if ntlds <= limit then
      while limit > 0 do
        for _,lurls in pairs(tlds) do
          if #lurls > 0 then
            table.insert(res, table.remove(lurls))
            limit = limit - 1
          end
        end
      end

      params.task:cache_set(cache_key, res)
      return res
    end

    -- We need to sort tlds table first (fewest urls first)
    local tlds_keys = {}
    for k,_ in pairs(tlds) do table.insert(tlds_keys, k) end
    table.sort(tlds_keys, function (t1, t2)
      return #tlds[t1] < #tlds[t2]
    end)

    -- Take one url from each end of the sorted list, moving inwards
    ntlds = #tlds_keys
    for i=1,ntlds / 2 do
      local tld1 = tlds[tlds_keys[i]]
      -- `+ 1` pairs the i-th smallest with the i-th largest bucket; without
      -- it the largest bucket is never used and, for two buckets, the same
      -- bucket would be picked twice
      local tld2 = tlds[tlds_keys[ntlds - i + 1]]
      table.insert(res, table.remove(tld1))
      table.insert(res, table.remove(tld2))
      limit = limit - 2

      if limit <= 0 then
        break
      end
    end

    params.task:cache_set(cache_key, res)
    return res
  end
  565. --[[[
  566. -- @function lua_util.deepcopy(table)
  567. -- params: {
  568. - - table
  569. -- }
  570. -- Performs deep copy of the table. Including metatables
  571. --]]
  -- Recursively copies `orig`: keys, values and metatables are all copied.
  -- Non-table values are returned unchanged.
  -- `seen` is an optional memo table (original -> copy) used internally so
  -- that shared and self-referencing tables (for example a metatable whose
  -- `__index` points to itself) are copied once instead of recursing forever.
  local function deepcopy(orig, seen)
    if type(orig) ~= 'table' then
      -- number, string, boolean, etc
      return orig
    end

    seen = seen or {}
    local known = seen[orig]
    if known then
      return known
    end

    local copy = {}
    -- Register before descending, so cycles resolve to this copy
    seen[orig] = copy

    for orig_key, orig_value in next, orig, nil do
      copy[deepcopy(orig_key, seen)] = deepcopy(orig_value, seen)
    end
    -- The metatable is set last, so it does not interfere with the writes above
    setmetatable(copy, deepcopy(getmetatable(orig), seen))

    return copy
  end
  586. exports.deepcopy = deepcopy
  587. --[[[
  588. -- @function lua_util.shallowcopy(tbl)
  589. -- Performs shallow (and fast) copy of a table or another Lua type
  590. --]]
  exports.shallowcopy = function(orig)
    -- Non-table values are returned unchanged
    if type(orig) ~= 'table' then
      return orig
    end

    -- Only the top level is duplicated; nested tables stay shared
    local copy = {}
    for k, v in pairs(orig) do
      copy[k] = v
    end
    return copy
  end
  604. -- Debugging support
  -- True when the configured log level is 'debug': every module is logged
  local unconditional_debug = false
  -- Set of module names with debug logging enabled
  local debug_modules = {}
  local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)

  -- The configuration is only consulted when rspamd_config is a userdata
  if type(rspamd_config) == 'userdata' then
    local logger = require "rspamd_logger"
    local logging = rspamd_config:get_all_opt('logging')

    if logging then
      if logging.level == 'debug' then
        unconditional_debug = true
      end

      -- Per-module debugging only matters when debug is not global
      if not unconditional_debug and logging.debug_modules then
        for _, modname in ipairs(logging.debug_modules) do
          debug_modules[modname] = true
          logger.infox(rspamd_config, 'enable debug for Lua module %s', modname)
        end
      end
    end
  end
  --[[[
  -- @function lua_util.debugm(module, ...)
  -- Logs the arguments through `rspamd_logger.logx` when debug logging is
  -- enabled globally or for the given module; does nothing otherwise
  -- @param {string} module module name the message belongs to
  --]]
  exports.debugm = function(mod, ...)
    local logger = require "rspamd_logger"
    if not (unconditional_debug or debug_modules[mod]) then
      return
    end
    logger.logx(log_level, mod, ...)
  end
  633. return exports