You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_util.lua 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. --[[
  2. Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_util
  15. -- This module contains utility functions for working with Lua and/or Rspamd
  16. --]]
  17. local exports = {}
  18. local lpeg = require 'lpeg'
  19. local rspamd_util = require "rspamd_util"
  20. local fun = require "fun"
  21. local split_grammar = {}
  22. local function rspamd_str_split(s, sep)
  23. local gr = split_grammar[sep]
  24. if not gr then
  25. local _sep = lpeg.P(sep)
  26. local elem = lpeg.C((1 - _sep)^0)
  27. local p = lpeg.Ct(elem * (_sep * elem)^0)
  28. gr = p
  29. split_grammar[sep] = gr
  30. end
  31. return gr:match(s)
  32. end
  33. --[[[
  34. -- @function lua_util.str_split(text, deliminator)
  35. -- Splits text into a numeric table by deliminator
  36. -- @param {string} text deliminated text
  37. -- @param {string} deliminator the deliminator
  38. -- @return {table} numeric table containing string parts
  39. --]]
  40. exports.rspamd_str_split = rspamd_str_split
  41. exports.str_split = rspamd_str_split
  42. local space = lpeg.S' \t\n\v\f\r'
  43. local nospace = 1 - space
  44. local ptrim = space^0 * lpeg.C((space^0 * nospace^1)^0)
  45. local match = lpeg.match
  46. exports.rspamd_str_trim = function(s)
  47. return match(ptrim, s)
  48. end
  49. --[[[
  50. -- @function lua_util.round(number, decimalPlaces)
  51. -- Round number to fixed number of decimal points
  52. -- @param {number} number number to round
  53. -- @param {number} decimalPlaces number of decimal points
  54. -- @return {number} rounded number
  55. --]]
  56. -- Robert Jay Gould http://lua-users.org/wiki/SimpleRound
  57. exports.round = function(num, numDecimalPlaces)
  58. local mult = 10^(numDecimalPlaces or 0)
  59. return math.floor(num * mult) / mult
  60. end
  61. --[[[
  62. -- @function lua_util.template(text, replacements)
  63. -- Replaces values in a text template
  64. -- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces.
  65. -- @param {string} text text containing variables
  66. -- @param {table} replacements key/value pairs for replacements
  67. -- @return {string} string containing replaced values
  68. -- @example
  69. -- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
  70. -- -- goop contains "HELLO LUA WORLD!"
  71. --]]
  72. exports.template = function(tmpl, keys)
  73. local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
  74. local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) }
  75. local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") }
  76. local template_grammar = lpeg.Cs((var + var_braced + 1)^0)
  77. return lpeg.match(template_grammar, tmpl)
  78. end
  79. exports.remove_email_aliases = function(email_addr)
  80. local function check_gmail_user(addr)
  81. -- Remove all points
  82. local no_dots_user = string.gsub(addr.user, '%.', '')
  83. local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$')
  84. if cap then
  85. return cap, rspamd_str_split(pluses, '+'), nil
  86. elseif no_dots_user ~= addr.user then
  87. return no_dots_user,{},nil
  88. end
  89. return nil
  90. end
  91. local function check_address(addr)
  92. if addr.user then
  93. local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$')
  94. if cap then
  95. return cap, rspamd_str_split(pluses, '+'), nil
  96. end
  97. end
  98. return nil
  99. end
  100. local function set_addr(addr, new_user, new_domain)
  101. if new_user then
  102. addr.user = new_user
  103. end
  104. if new_domain then
  105. addr.domain = new_domain
  106. end
  107. if addr.domain then
  108. addr.addr = string.format('%s@%s', addr.user, addr.domain)
  109. else
  110. addr.addr = string.format('%s@', addr.user)
  111. end
  112. if addr.name and #addr.name > 0 then
  113. addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
  114. else
  115. addr.raw = string.format('<%s>', addr.addr)
  116. end
  117. end
  118. local function check_gmail(addr)
  119. local nu, tags, nd = check_gmail_user(addr)
  120. if nu then
  121. return nu, tags, nd
  122. end
  123. return nil
  124. end
  125. local function check_googlemail(addr)
  126. local nd = 'gmail.com'
  127. local nu, tags = check_gmail_user(addr)
  128. if nu then
  129. return nu, tags, nd
  130. end
  131. return nil, nil, nd
  132. end
  133. local specific_domains = {
  134. ['gmail.com'] = check_gmail,
  135. ['googlemail.com'] = check_googlemail,
  136. }
  137. if email_addr then
  138. if email_addr.domain and specific_domains[email_addr.domain] then
  139. local nu, tags, nd = specific_domains[email_addr.domain](email_addr)
  140. if nu or nd then
  141. set_addr(email_addr, nu, nd)
  142. return nu, tags
  143. end
  144. else
  145. local nu, tags, nd = check_address(email_addr)
  146. if nu or nd then
  147. set_addr(email_addr, nu, nd)
  148. return nu, tags
  149. end
  150. end
  151. return nil
  152. end
  153. end
  154. exports.is_rspamc_or_controller = function(task)
  155. local ua = task:get_request_header('User-Agent') or ''
  156. local pwd = task:get_request_header('Password')
  157. local is_rspamc = false
  158. if tostring(ua) == 'rspamc' or pwd then is_rspamc = true end
  159. return is_rspamc
  160. end
  161. --[[[
  162. -- @function lua_util.unpack(table)
  163. -- Converts numeric table to varargs
  164. -- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
  165. -- @param {table} table numerically indexed table to unpack
  166. -- @return {varargs} unpacked table elements
  167. --]]
  168. local unpack_function = table.unpack or unpack
  169. exports.unpack = function(t)
  170. return unpack_function(t)
  171. end
  172. --[[[
  173. -- @function lua_util.spairs(table)
  174. -- Like `pairs` but keys are sorted lexicographically
  175. -- @param {table} table table containing key/value pairs
  176. -- @return {function} generator function returning key/value pairs
  177. --]]
  178. -- Sorted iteration:
  179. -- for k,v in spairs(t) do ... end
  180. --
  181. -- or with custom comparison:
  182. -- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
  183. --
  184. -- optional limit is also available (e.g. return top X elements)
  185. local function spairs(t, order, lim)
  186. -- collect the keys
  187. local keys = {}
  188. for k in pairs(t) do keys[#keys+1] = k end
  189. -- if order function given, sort by it by passing the table and keys a, b,
  190. -- otherwise just sort the keys
  191. if order then
  192. table.sort(keys, function(a,b) return order(t, a, b) end)
  193. else
  194. table.sort(keys)
  195. end
  196. -- return the iterator function
  197. local i = 0
  198. return function()
  199. i = i + 1
  200. if not lim or i <= lim then
  201. if keys[i] then
  202. return keys[i], t[keys[i]]
  203. end
  204. end
  205. end
  206. end
  207. exports.spairs = spairs
  208. --[[[
  209. -- @function lua_util.disable_module(modname, how)
  210. -- Disables a plugin
  211. -- @param {string} modname name of plugin to disable
  212. -- @param {string} how 'redis' to disable redis, 'config' to disable startup
  213. --]]
  214. local function disable_module(modname, how)
  215. if rspamd_plugins_state.enabled[modname] then
  216. rspamd_plugins_state.enabled[modname] = nil
  217. end
  218. if how == 'redis' then
  219. rspamd_plugins_state.disabled_redis[modname] = {}
  220. elseif how == 'config' then
  221. rspamd_plugins_state.disabled_unconfigured[modname] = {}
  222. elseif how == 'experimental' then
  223. rspamd_plugins_state.disabled_experimental[modname] = {}
  224. else
  225. rspamd_plugins_state.disabled_failed[modname] = {}
  226. end
  227. end
  228. exports.disable_module = disable_module
  229. --[[[
  230. -- @function lua_util.disable_module(modname)
  231. -- Checks experimental plugins state and disable if needed
  232. -- @param {string} modname name of plugin to check
  233. -- @return {boolean} true if plugin should be enabled, false otherwise
  234. --]]
  235. local function check_experimental(modname)
  236. if rspamd_config:experimental_enabled() then
  237. return true
  238. else
  239. disable_module(modname, 'experimental')
  240. end
  241. return false
  242. end
  243. exports.check_experimental = check_experimental
  244. --[[[
  245. -- @function lua_util.list_to_hash(list)
  246. -- Converts numerically-indexed table to table indexed by values
  247. -- @param {table} list numerically-indexed table or string, which is treated as a one-element list
  248. -- @return {table} table indexed by values
  249. -- @example
  250. -- local h = lua_util.list_to_hash({"a", "b"})
  251. -- -- h contains {a = true, b = true}
  252. --]]
  253. local function list_to_hash(list)
  254. if type(list) == 'table' then
  255. if list[1] then
  256. local h = {}
  257. for _, e in ipairs(list) do
  258. h[e] = true
  259. end
  260. return h
  261. else
  262. return list
  263. end
  264. elseif type(list) == 'string' then
  265. local h = {}
  266. h[list] = true
  267. return h
  268. end
  269. end
  270. exports.list_to_hash = list_to_hash
  271. --[[[
  272. -- @function lua_util.parse_time_interval(str)
  273. -- Parses human readable time interval
  274. -- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
  275. -- 'w' for weeks, 'y' for years
  276. -- @param {string} str input string
  277. -- @return {number|nil} parsed interval as seconds (might be fractional)
  278. --]]
  279. local function parse_time_interval(str)
  280. local function parse_time_suffix(s)
  281. if s == 's' then
  282. return 1
  283. elseif s == 'm' then
  284. return 60
  285. elseif s == 'h' then
  286. return 3600
  287. elseif s == 'd' then
  288. return 86400
  289. elseif s == 'w' then
  290. return 86400 * 7
  291. elseif s == 'y' then
  292. return 365 * 86400;
  293. end
  294. end
  295. local digit = lpeg.R("09")
  296. local parser = {}
  297. parser.integer =
  298. (lpeg.S("+-") ^ -1) *
  299. (digit ^ 1)
  300. parser.fractional =
  301. (lpeg.P(".") ) *
  302. (digit ^ 1)
  303. parser.number =
  304. (parser.integer *
  305. (parser.fractional ^ -1)) +
  306. (lpeg.S("+-") * parser.fractional)
  307. parser.time = lpeg.Cf(lpeg.Cc(1) *
  308. (parser.number / tonumber) *
  309. ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1),
  310. function (acc, val) return acc * val end)
  311. local t = lpeg.match(parser.time, str)
  312. return t
  313. end
  314. exports.parse_time_interval = parse_time_interval
  315. --[[[
  316. -- @function lua_util.table_cmp(t1, t2)
  317. -- Compare two tables deeply
  318. --]]
  319. local function table_cmp(table1, table2)
  320. local avoid_loops = {}
  321. local function recurse(t1, t2)
  322. if type(t1) ~= type(t2) then return false end
  323. if type(t1) ~= "table" then return t1 == t2 end
  324. if avoid_loops[t1] then return avoid_loops[t1] == t2 end
  325. avoid_loops[t1] = t2
  326. -- Copy keys from t2
  327. local t2keys = {}
  328. local t2tablekeys = {}
  329. for k, _ in pairs(t2) do
  330. if type(k) == "table" then table.insert(t2tablekeys, k) end
  331. t2keys[k] = true
  332. end
  333. -- Let's iterate keys from t1
  334. for k1, v1 in pairs(t1) do
  335. local v2 = t2[k1]
  336. if type(k1) == "table" then
  337. -- if key is a table, we need to find an equivalent one.
  338. local ok = false
  339. for i, tk in ipairs(t2tablekeys) do
  340. if table_cmp(k1, tk) and recurse(v1, t2[tk]) then
  341. table.remove(t2tablekeys, i)
  342. t2keys[tk] = nil
  343. ok = true
  344. break
  345. end
  346. end
  347. if not ok then return false end
  348. else
  349. -- t1 has a key which t2 doesn't have, fail.
  350. if v2 == nil then return false end
  351. t2keys[k1] = nil
  352. if not recurse(v1, v2) then return false end
  353. end
  354. end
  355. -- if t2 has a key which t1 doesn't have, fail.
  356. if next(t2keys) then return false end
  357. return true
  358. end
  359. return recurse(table1, table2)
  360. end
  361. exports.table_cmp = table_cmp
  362. --[[[
  363. -- @function lua_util.table_cmp(task, name, value, stop_chars)
  364. -- Performs header folding
  365. --]]
  366. exports.fold_header = function(task, name, value, stop_chars)
  367. local how
  368. if task:has_flag("milter") then
  369. how = "lf"
  370. else
  371. how = task:get_newlines_type()
  372. end
  373. return rspamd_util.fold_header(name, value, how, stop_chars)
  374. end
  375. --[[[
  376. -- @function lua_util.override_defaults(defaults, override)
  377. -- Overrides values from defaults with override
  378. --]]
  379. local function override_defaults(def, override)
  380. -- Corner cases
  381. if not override or type(override) ~= 'table' then
  382. return def
  383. end
  384. if not def or type(def) ~= 'table' then
  385. return override
  386. end
  387. local res = {}
  388. fun.each(function(k, v)
  389. if type(v) == 'table' then
  390. if def[k] and type(def[k]) == 'table' then
  391. -- Recursively override elements
  392. res[k] = override_defaults(def[k], v)
  393. else
  394. res[k] = v
  395. end
  396. else
  397. res[k] = v
  398. end
  399. end, override)
  400. fun.each(function(k, v)
  401. if not res[k] then
  402. res[k] = v
  403. end
  404. end, def)
  405. return res
  406. end
  407. exports.override_defaults = override_defaults
  408. return exports