--[[
Copyright (c) 2023, Vsevolod Stakhov

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

--[[[
-- @module lua_util
-- This module contains utility functions for working with Lua and/or Rspamd
--]]

local exports = {}
local lpeg = require 'lpeg'
local rspamd_util = require "rspamd_util"
local fun = require "fun"
local lupa = require "lupa"

-- Cache of compiled split grammars, keyed by the separator they were built for
local split_grammar = {}
-- Lazily built grammar used when no separator is given
local spaces_split_grammar
local space = lpeg.S ' \t\n\v\f\r'
local nospace = 1 - space
local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0)
local match = lpeg.match

-- Returns a one-level copy of a table; any non-table value is returned as is
local function shallowcopy(orig)
  if type(orig) ~= 'table' then
    return orig
  end

  local dup = {}
  for key, value in pairs(orig) do
    dup[key] = value
  end

  return dup
end

-- Recursively copies keys, values and the metatable of a table;
-- any non-table value (number, string, boolean, etc) is returned as is
local function deepcopy(orig)
  if type(orig) ~= 'table' then
    return orig
  end

  local dup = {}
  for key, value in next, orig, nil do
    dup[deepcopy(key)] = deepcopy(value)
  end

  local mt = getmetatable(orig)
  if mt then
    setmetatable(dup, deepcopy(mt))
  end

  return dup
end

lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', {
  keep_trailing_newline = true,
  autoescape = false,
})

lupa.filters.pbkdf = function(s)
  local cr = require "rspamd_cryptobox"
  return cr.pbkdf(s)
end

-- Dirty hacks to avoid shared state
package.loaded['lupa'] = nil
local lupa_orig = require "lupa"

-- Builds a grammar that captures the (possibly empty) elements found
-- between occurrences of `delim` into a table
local function build_split_grammar(delim)
  local elem = lpeg.C((1 - delim) ^ 0)

  return lpeg.Ct(elem * (delim * elem) ^ 0)
end

local function rspamd_str_split(s, sep)
  local gr

  if sep then
    gr = split_grammar[sep]

    if not gr then
      -- A string is treated as a set of separator characters,
      -- anything else is assumed to be an lpeg object
      local delim = type(sep) == 'string' and lpeg.S(sep) or sep
      gr = build_split_grammar(delim)
      split_grammar[sep] = gr
    end
  else
    -- No separator: split on any single whitespace character
    spaces_split_grammar = spaces_split_grammar or build_split_grammar(space)
    gr = spaces_split_grammar
  end

  return gr:match(s)
end

--[[[
-- @function lua_util.str_split(text, delimiter)
-- Splits text into a numeric table by delimiter
-- @param {string} text delimited text
-- @param {string} delimiter the delimiter (a set of characters or an lpeg object; whitespace when omitted)
-- @return {table} numeric table containing string parts
--]]

exports.rspamd_str_split = rspamd_str_split
exports.str_split = rspamd_str_split

local function rspamd_str_trim(s)
  return match(ptrim, s)
end
exports.rspamd_str_trim = rspamd_str_trim
--[[[
-- @function lua_util.str_trim(text)
-- Returns a string with no trailing and leading spaces
-- @param {string} text input text
-- @return {string} string with no trailing and leading spaces
--]]
exports.str_trim = rspamd_str_trim

--[[[
-- @function lua_util.str_startswith(text, prefix)
-- @param {string} text
-- @param {string} prefix
-- @return {boolean} true if text starts with the specified prefix, false otherwise
--]]
exports.str_startswith = function(s, prefix)
  local head = s:sub(1, #prefix)

  return head == prefix
end

--[[[
-- @function lua_util.str_endswith(text, suffix)
-- @param {string} text
-- @param {string} suffix
-- @return {boolean} true if text ends with the specified suffix, false otherwise
--]]
exports.str_endswith = function(s, suffix)
  -- Plain (non-pattern) search that starts `#suffix` characters before the end
  local pos = s:find(suffix, -#suffix, true)

  return pos ~= nil
end

--[[[
-- @function lua_util.round(number, decimalPlaces)
-- Round number to fixed number of decimal points
-- @param {number} number number to round
-- @param {number} decimalPlaces number of decimal points
-- @return {number} rounded number
--]]

-- modified version from Robert Jay Gould
-- http://lua-users.org/wiki/SimpleRound
exports.round = function(num, numDecimalPlaces)
  local mult = 10 ^ (numDecimalPlaces or 0)
  -- Halves are rounded away from zero for both signs
  if num >= 0 then
    return math.floor(num * mult + 0.5) / mult
  else
    return math.ceil(num * mult - 0.5) / mult
  end
end

--[[[
-- @function lua_util.template(text, replacements)
-- Replaces values in a text template
-- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or may not use curly braces.
-- @param {string} text text containing variables
-- @param {table} replacements key/value pairs for replacements
-- @return {string} string containing replaced values
-- @example
-- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
-- -- goop contains "HELLO LUA WORLD!"
--]]
exports.template = function(tmpl, keys)
  local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
  -- `$NAME`: the `$` is replaced by an empty string, NAME is looked up in `keys`
  local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) }
  -- `${NAME}`: same as above, the braces are dropped as well
  local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") }

  local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0)

  return lpeg.match(template_grammar, tmpl)
end

-- Returns a shallow copy of `env` extended with the `paths` and `env` fields
-- taken from the `rspamd_paths` and `rspamd_env` globals. A non-table `env`
-- (e.g. nil) is treated as an empty environment instead of raising an error.
local function enrich_template_with_globals(env)
  local newenv = type(env) == 'table' and shallowcopy(env) or {}
  newenv.paths = rspamd_paths
  newenv.env = rspamd_env

  return newenv
end

-- Common implementation of `jinja_template` and `jinja_template_file`.
-- `method` is the name of the lupa function to call (`expand` or `expand_file`),
-- `input` is its first argument (template text or file name).
-- Custom filters are installed only for the duration of the call: they are
-- removed again even if the expansion raises, so a failing template cannot
-- leave them behind in the shared lupa instance.
local function jinja_expand(method, input, env, skip_global_env, is_orig, custom_filters)
  local lupa_to_use = is_orig and lupa_orig or lupa
  if not skip_global_env then
    env = enrich_template_with_globals(env)
  end

  local has_custom_filters = type(custom_filters) == 'table'
  local orig_filters = {}
  if has_custom_filters then
    for k, v in pairs(custom_filters) do
      orig_filters[k] = lupa_to_use.filters[k]
      lupa_to_use.filters[k] = v
    end
  end

  local ok, result = pcall(lupa_to_use[method], input, env)

  -- Restore custom filters
  if has_custom_filters then
    for k, _ in pairs(custom_filters) do
      lupa_to_use.filters[k] = orig_filters[k]
    end
  end

  if not ok then
    -- Re-raise the original error value without adding position information
    error(result, 0)
  end

  return result
end

--[[[
-- @function lua_util.jinja_template(text, env[, skip_global_env][, is_orig][, custom_filters])
-- Replaces values in a text template according to jinja2 syntax
-- @param {string} text text containing variables
-- @param {table} env key/value pairs for replacements
-- @param {boolean} skip_global_env don't export Rspamd superglobals
-- @param {boolean} is_orig use the original lupa configuration with `{{` for variables
-- @param {table} custom_filters custom filters to use (or nil if not needed)
-- @return {string} string containing replaced values
-- @example
-- lua_util.jinja_template("HELLO {=FOO=} {=BAR=}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
-- "HELLO LUA WORLD!"
--]]
exports.jinja_template = function(text, env, skip_global_env, is_orig, custom_filters)
  return jinja_expand('expand', text, env, skip_global_env, is_orig, custom_filters)
end

--[[[
-- @function lua_util.jinja_template_file(filename, env[, skip_global_env][, is_orig][, custom_filters])
-- Replaces values in a template file according to jinja2 syntax
-- @param {string} filename name of file to expand
-- @param {table} env key/value pairs for replacements
-- @param {boolean} skip_global_env don't export Rspamd superglobals
-- @param {boolean} is_orig use the original lupa configuration with `{{` for variables
-- @param {table} custom_filters custom filters to use (or nil if not needed)
-- @return {string} string containing replaced values
-- @example
-- lua_util.jinja_template_file("/path/to/template", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
--]]
exports.jinja_template_file = function(filename, env, skip_global_env, is_orig, custom_filters)
  return jinja_expand('expand_file', filename, env, skip_global_env, is_orig, custom_filters)
end

-- Strips `+tag` aliases from the user part of `email_addr` (a table with
-- `user`, `domain`, `name`, `addr` and `raw` fields) and rewrites the table in
-- place. For gmail.com and googlemail.com dots are removed from the user part
-- as well, and googlemail.com is rewritten to gmail.com.
-- Returns the new user part and the table of tags when something has changed.
exports.remove_email_aliases = function(email_addr)
  local function check_gmail_user(addr)
    -- Nothing to normalise without a user part (same guard as in check_address)
    if not addr.user then
      return nil
    end

    -- Remove all points
    local no_dots_user = string.gsub(addr.user, '%.', '')
    local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$')
    if cap then
      return cap, rspamd_str_split(pluses, '+'), nil
    elseif no_dots_user ~= addr.user then
      return no_dots_user, {}, nil
    end

    return nil
  end

  local function check_address(addr)
    if addr.user then
      local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$')
      if cap then
        return cap, rspamd_str_split(pluses, '+'), nil
      end
    end

    return nil
  end

  -- Updates user/domain and rebuilds the derived `addr` and `raw` fields
  local function set_addr(addr, new_user, new_domain)
    if new_user then
      addr.user = new_user
    end
    if new_domain then
      addr.domain = new_domain
    end

    -- The user part may be absent when only the domain is rewritten
    local user = addr.user or ''
    if addr.domain then
      addr.addr = string.format('%s@%s', user, addr.domain)
    else
      addr.addr = string.format('%s@', user)
    end

    if addr.name and #addr.name > 0 then
      addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
    else
      addr.raw = string.format('<%s>', addr.addr)
    end
  end

  local function check_gmail(addr)
    local nu, tags, nd = check_gmail_user(addr)

    if nu then
      return nu, tags, nd
    end

    return nil
  end

  local function check_googlemail(addr)
    -- The domain is always rewritten, even if the user part is left intact
    local nd = 'gmail.com'
    local nu, tags = check_gmail_user(addr)

    if nu then
      return nu, tags, nd
    end

    return nil, nil, nd
  end

  local specific_domains = {
    ['gmail.com'] = check_gmail,
    ['googlemail.com'] = check_googlemail,
  }

  if email_addr then
    if email_addr.domain and specific_domains[email_addr.domain] then
      local nu, tags, nd = specific_domains[email_addr.domain](email_addr)
      if nu or nd then
        set_addr(email_addr, nu, nd)

        return nu, tags
      end
    else
      local nu, tags, nd = check_address(email_addr)
      if nu or nd then
        set_addr(email_addr, nu, nd)

        return nu, tags
      end
    end

    return nil
  end
end

-- Returns true if the request has the `rspamc` User-Agent header
-- or carries a `Password` header
exports.is_rspamc_or_controller = function(task)
  local ua = task:get_request_header('User-Agent') or ''
  local pwd = task:get_request_header('Password')
  local is_rspamc = false
  if tostring(ua) == 'rspamc' or pwd then
    is_rspamc = true
  end

  return is_rspamc
end

--[[[
-- @function lua_util.unpack(table)
-- Converts numeric table to varargs
-- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
-- @param {table} table numerically indexed table to unpack
-- @return {varargs} unpacked table elements
--]]
local unpack_function = table.unpack or unpack
exports.unpack = function(t)
  return unpack_function(t)
end

--[[[
-- @function lua_util.flatten(table)
-- Flatten underlying tables in a single table
-- @param {table} table table of tables
-- @return {table} flattened table
--]]
exports.flatten = function(t)
  local res = {}
  for _, e in fun.iter(t) do
    for _, v in fun.iter(e) do
      res[#res + 1] = v
    end
  end
  return res
end

--[[[
-- @function lua_util.spairs(table)
-- Like `pairs` but keys are sorted lexicographically
-- @param {table} table table containing key/value pairs
-- @return {function} generator function returning key/value pairs
--]]

-- Sorted iteration:
-- for k,v in spairs(t) do ... end
--
-- or with custom comparison:
-- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
--
-- optional limit is also available (e.g.
-- return top X elements)
local function spairs(t, order, lim)
  -- collect the keys
  local keys = {}
  for k in pairs(t) do
    keys[#keys + 1] = k
  end

  -- if order function given, sort by it by passing the table and keys a, b,
  -- otherwise just sort the keys
  if order then
    table.sort(keys, function(a, b)
      return order(t, a, b)
    end)
  else
    table.sort(keys)
  end

  -- return the iterator function
  local i = 0
  return function()
    i = i + 1
    if not lim or i <= lim then
      if keys[i] then
        return keys[i], t[keys[i]]
      end
    end
  end
end

exports.spairs = spairs

local lua_cfg_utils = require "lua_cfg_utils"

exports.config_utils = lua_cfg_utils
exports.disable_module = lua_cfg_utils.disable_module

--[[[
-- @function lua_util.check_experimental(modname)
-- Checks experimental plugins state and disable if needed
-- @param {string} modname name of plugin to check
-- @return {boolean} true if plugin should be enabled, false otherwise
--]]
local function check_experimental(modname)
  if rspamd_config:experimental_enabled() then
    return true
  else
    lua_cfg_utils.disable_module(modname, 'experimental')
  end

  return false
end

exports.check_experimental = check_experimental

--[[[
-- @function lua_util.list_to_hash(list)
-- Converts numerically-indexed table to table indexed by values
-- @param {table} list numerically-indexed table or string, which is treated as a one-element list
-- @return {table} table indexed by values
-- @example
-- local h = lua_util.list_to_hash({"a", "b"})
-- -- h contains {a = true, b = true}
--]]
local function list_to_hash(list)
  if type(list) == 'table' then
    if list[1] then
      local h = {}
      for _, e in ipairs(list) do
        h[e] = true
      end
      return h
    else
      -- No array part: the table is returned unchanged
      return list
    end
  elseif type(list) == 'string' then
    local h = {}
    h[list] = true
    return h
  end
end

exports.list_to_hash = list_to_hash

--[[[
-- @function lua_util.nkeys(table|gen, param, state)
-- Returns number of keys in a table (i.e. from both the array and hash parts combined)
-- @param {table} gen table to count keys in; when `param` is given, `gen`, `param` and `state` are passed to `fun.iter` and the produced elements are counted instead
-- @return {number} number of keys
-- @example
-- print(lua_util.nkeys({}))  -- 0
-- print(lua_util.nkeys({ "a", nil, "b" })) -- 2
-- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2
-- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3
--
--]]
local function nkeys(gen, param, state)
  local n = 0
  if not param then
    for _, _ in pairs(gen) do
      n = n + 1
    end
  else
    for _, _ in fun.iter(gen, param, state) do
      n = n + 1
    end
  end
  return n
end

exports.nkeys = nkeys

--[[[
-- @function lua_util.parse_time_interval(str)
-- Parses human readable time interval
-- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
-- 'w' for weeks, 'y' for years
-- @param {string} str input string
-- @return {number|nil} parsed interval as seconds (might be fractional)
--]]
local function parse_time_interval(str)
  local function parse_time_suffix(s)
    if s == 's' then
      return 1
    elseif s == 'm' then
      return 60
    elseif s == 'h' then
      return 3600
    elseif s == 'd' then
      return 86400
    elseif s == 'w' then
      return 86400 * 7
    elseif s == 'y' then
      return 365 * 86400;
    end
  end

  local digit = lpeg.R("09")
  local parser = {}
  parser.integer = (lpeg.S("+-") ^ -1) * (digit ^ 1)
  parser.fractional = (lpeg.P(".")) * (digit ^ 1)
  parser.number = (parser.integer * (parser.fractional ^ -1)) + (lpeg.S("+-") * parser.fractional)
  -- Fold: 1 * number * optional suffix multiplier
  parser.time = lpeg.Cf(lpeg.Cc(1) * (parser.number / tonumber) * ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1),
      function(acc, val)
        return acc * val
      end)

  local t = lpeg.match(parser.time, str)

  return t
end

exports.parse_time_interval = parse_time_interval

--[[[
-- @function lua_util.dehumanize_number(str)
-- Parses human readable number
-- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier,
-- e.g. `10mb` equal to `10 * 1024 * 1024`
-- @param {string} str input string
-- @return {number|nil} parsed number
--]]
local function dehumanize_number(str)
  local function parse_suffix(s)
    if s == 'k' then
      return 1000
    elseif s == 'm' then
      return 1000000
    elseif s == 'g' then
      return 1e9
    elseif s == 'kb' then
      return 1024
    elseif s == 'mb' then
      return 1024 * 1024
    elseif s == 'gb' then
      -- One more factor of 1024 than 'mb'
      return 1024 * 1024 * 1024;
    end
  end

  local digit = lpeg.R("09")
  local parser = {}
  parser.integer = (lpeg.S("+-") ^ -1) * (digit ^ 1)
  parser.fractional = (lpeg.P(".")) * (digit ^ 1)
  parser.number = (parser.integer * (parser.fractional ^ -1)) + (lpeg.S("+-") * parser.fractional)
  -- Fold: 1 * number * optional suffix multiplier
  parser.humanized_number = lpeg.Cf(lpeg.Cc(1) * (parser.number / tonumber) *
      (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1),
      function(acc, val)
        return acc * val
      end)

  local t = lpeg.match(parser.humanized_number, str)

  return t
end

exports.dehumanize_number = dehumanize_number

--[[[
-- @function lua_util.table_cmp(t1, t2)
-- Compare two tables deeply
--]]
local function table_cmp(table1, table2)
  -- Maps already visited tables of the first argument to their counterparts,
  -- so self-referencing tables do not recurse forever
  local avoid_loops = {}
  local function recurse(t1, t2)
    if type(t1) ~= type(t2) then
      return false
    end
    if type(t1) ~= "table" then
      return t1 == t2
    end

    if avoid_loops[t1] then
      return avoid_loops[t1] == t2
    end
    avoid_loops[t1] = t2
    -- Copy keys from t2
    local t2keys = {}
    local t2tablekeys = {}
    for k, _ in pairs(t2) do
      if type(k) == "table" then
        table.insert(t2tablekeys, k)
      end
      t2keys[k] = true
    end
    -- Let's iterate keys from t1
    for k1, v1 in pairs(t1) do
      local v2 = t2[k1]
      if type(k1) == "table" then
        -- if key is a table, we need to find an equivalent one.
        local ok = false
        for i, tk in ipairs(t2tablekeys) do
          if table_cmp(k1, tk) and recurse(v1, t2[tk]) then
            table.remove(t2tablekeys, i)
            t2keys[tk] = nil
            ok = true
            break
          end
        end
        if not ok then
          return false
        end
      else
        -- t1 has a key which t2 doesn't have, fail.
        if v2 == nil then
          return false
        end
        t2keys[k1] = nil
        if not recurse(v1, v2) then
          return false
        end
      end
    end
    -- if t2 has a key which t1 doesn't have, fail.
    if next(t2keys) then
      return false
    end
    return true
  end
  return recurse(table1, table2)
end

exports.table_cmp = table_cmp

--[[[
-- @function lua_util.table_merge(t1, t2)
-- Merge two tables
--]]
local function table_merge(t1, t2)
  local res = {}
  local nidx = 1 -- for numeric indices
  -- Values with numeric keys are appended, other keys are copied as is
  local it_func = function(k, v)
    if type(k) == 'number' then
      res[nidx] = v
      nidx = nidx + 1
    else
      res[k] = v
    end
  end
  for k, v in pairs(t1) do
    it_func(k, v)
  end
  for k, v in pairs(t2) do
    it_func(k, v)
  end
  return res
end

exports.table_merge = table_merge

--[[[
-- @function lua_util.fold_header(task, name, value, stop_chars)
-- Performs header folding
--]]
exports.fold_header = function(task, name, value, stop_chars)
  local how

  if task:has_flag("milter") then
    how = "lf"
  else
    how = task:get_newlines_type()
  end

  return rspamd_util.fold_header(name, value, how, stop_chars)
end

--[[[
-- @function lua_util.override_defaults(defaults, override)
-- Overrides values from defaults with override
--]]
local function override_defaults(def, override)
  -- Corner cases
  if not override or type(override) ~= 'table' then
    return def
  end
  if not def or type(def) ~= 'table' then
    return override
  end

  local res = {}

  for k, v in pairs(override) do
    if type(v) == 'table' then
      if def[k] and type(def[k]) == 'table' then
        -- Recursively override elements
        res[k] = override_defaults(def[k], v)
      else
        res[k] = v
      end
    else
      res[k] = v
    end
  end

  -- Fill in everything that has not been overridden
  for k, v in pairs(def) do
    if type(res[k]) == 'nil' then
      res[k] = v
    end
  end

  return res
end

exports.override_defaults = override_defaults

--[[[
-- @function lua_util.filter_specific_urls(urls, params)
-- params: {
- - task - if needed to save in the cache
- - limit (default = 9999)
- - esld_limit (default = 9999) n domains per eSLD (effective second level domain)
   works only if number of unique eSLD less than `limit`
- - need_emails (default = false)
- - filter (default = nil)
- - prefix cache prefix (default = nil)
-- }
-- Apply heuristic in extracting of urls from `urls` table, this function
-- tries its best to extract specific number of urls from a task based on
-- their characteristics
--]]
exports.filter_specific_urls = function(urls, params)
  local cache_key

  if params.task and not params.no_cache then
    if params.prefix then
      cache_key = params.prefix
    else
      -- NOTE(review): unlike extract_specific_urls this key does not include
      -- params.flags / params.flags_mode - confirm that this is intended
      cache_key = string.format('sp_urls_%d%s%s%s', params.limit,
          tostring(params.need_emails or false),
          tostring(params.need_images or false),
          tostring(params.need_content or false))
    end
    local cached = params.task:cache_get(cache_key)

    if cached then
      return cached
    end
  end

  if not urls then
    return {}
  end

  if params.filter then
    urls = fun.totable(fun.filter(params.filter, urls))
  end

  -- Filter by tld:
  local tlds = {}
  local eslds = {}
  local ntlds, neslds = 0, 0

  local res = {}
  local nres = 0

  -- Converts the collected urls to a list, stores it in the task cache
  -- (when caching is enabled) and returns it
  local function finalize()
    res = exports.values(res)
    if params.task and not params.no_cache then
      params.task:cache_set(cache_key, res)
    end
    return res
  end

  local function insert_url(str, u)
    if not res[str] then
      res[str] = u
      nres = nres + 1

      return true
    end

    return false
  end

  local function process_single_url(u, default_priority)
    local priority = default_priority or 1 -- Normal priority
    local flags = u:get_flags()
    if params.ignore_ip and flags.numeric then
      return
    end

    if flags.redirected then
      local redir = u:get_redirected() -- get the real url

      if params.ignore_redirected then
        -- Replace `u` with redir
        u = redir
        priority = 2
      else
        -- Process both redirected url and the original one
        process_single_url(redir, 2)
      end
    end

    if flags.image then
      if not params.need_images then
        -- Ignore url
        return
      else
        -- Penalise images in urls
        priority = 0
      end
    end

    local esld = u:get_tld()
    local str_hash = tostring(u)

    if esld then
      -- Special cases
      if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
        if flags.obscured then
          priority = 3
        else
          if (flags.has_user or flags.has_port) then
            priority = 2
          elseif (flags.subject or flags.phished) then
            priority = 2
          end
        end
      elseif flags.html_displayed then
        priority = 0
      end

      if not eslds[esld] then
        eslds[esld] = { { str_hash, u, priority } }
        neslds = neslds + 1
      else
        if #eslds[esld] < params.esld_limit then
          table.insert(eslds[esld], { str_hash, u, priority })
        end
      end

      -- eSLD - 1 part => tld
      local parts = rspamd_str_split(esld, '.')
      local tld = table.concat(fun.totable(fun.tail(parts)), '.')

      if not tlds[tld] then
        tlds[tld] = { { str_hash, u, priority } }
        ntlds = ntlds + 1
      else
        table.insert(tlds[tld], { str_hash, u, priority })
      end
    end
  end

  for _, u in ipairs(urls) do
    process_single_url(u)
  end

  local limit = params.limit
  limit = limit - nres
  if limit < 0 then
    limit = 0
  end

  if limit == 0 then
    return finalize()
  end

  -- Sort eSLDs and tlds
  local function sort_stuff(tbl)
    -- Sort according to max priority
    table.sort(tbl, function(e1, e2)
      -- Sort by priority so max priority is at the end
      table.sort(e1, function(tr1, tr2)
        return tr1[3] < tr2[3]
      end)
      table.sort(e2, function(tr1, tr2)
        return tr1[3] < tr2[3]
      end)

      if e1[#e1][3] ~= e2[#e2][3] then
        -- Sort by priority so max priority is at the beginning
        return e1[#e1][3] > e2[#e2][3]
      else
        -- Prefer less urls to more urls per esld
        return #e1 < #e2
      end
    end)

    return tbl
  end

  eslds = sort_stuff(exports.values(eslds))
  neslds = #eslds

  if neslds <= limit then
    -- Number of eslds < limit
    repeat
      local item_found = false

      for _, lurls in ipairs(eslds) do
        if #lurls > 0 then
          local last = table.remove(lurls)
          insert_url(last[1], last[2])
          limit = limit - 1
          item_found = true
        end
      end
    until limit <= 0 or not item_found

    return finalize()
  end

  tlds = sort_stuff(exports.values(tlds))
  ntlds = #tlds

  -- Number of tlds < limit
  while limit > 0 do
    for _, lurls in ipairs(tlds) do
      if #lurls > 0 then
        local last = table.remove(lurls)
        insert_url(last[1], last[2])
        limit = limit - 1
      end
      if limit == 0 then
        break
      end
    end
  end

  return finalize()
end

--[[[
-- @function lua_util.extract_specific_urls(params)
-- params: {
- - task
- - limit (default = 9999)
- - esld_limit (default = 9999) n domains per eSLD (effective second level domain)
   works only if number of unique eSLD less than `limit`
- - need_emails (default = false)
- - filter (default = nil)
- - prefix cache prefix (default = nil)
- - ignore_redirected (default = false)
- - need_images (default = false)
- - need_content (default = false)
-- }
-- Apply heuristic in extracting of urls from task, this function
-- tries its best to extract specific number of urls from a task based on
-- their characteristics
--]]
-- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
  local default_params = {
    limit = 9999,
    esld_limit = 9999,
    need_emails = false,
    need_images = false,
    need_content = false,
    filter = nil,
    prefix = nil,
    ignore_ip = false,
    ignore_redirected = false,
    no_cache = false,
  }

  local params
  if type(params_or_task) == 'table' and type(lim) == 'nil' then
    params = params_or_task
  else
    -- Deprecated call
    params = {
      task = params_or_task,
      limit = lim,
      need_emails = need_emails,
      filter = filter,
      prefix = prefix
    }
  end
  -- Fill in the defaults for everything that has not been specified
  for k, v in pairs(default_params) do
    if type(params[k]) == 'nil' and v ~= nil then
      params[k] = v
    end
  end
  local url_params = {
    emails = params.need_emails,
    images = params.need_images,
    content = params.need_content,
    flags = params.flags, -- maybe nil
    flags_mode = params.flags_mode, -- maybe nil
  }

  -- Shortcut for cached stuff
  if params.task and not params.no_cache then
    local cache_key
    if params.prefix then
      cache_key = params.prefix
    else
      local cache_key_suffix
      if params.flags then
        cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '')
      else
        cache_key_suffix = string.format('%s%s%s',
            tostring(params.need_emails or false),
            tostring(params.need_images or false),
            tostring(params.need_content or false))
      end
      cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix)
    end
    local cached = params.task:cache_get(cache_key)

    if cached then
      return cached
    end
  end

  -- No cache version
  local urls = params.task:get_urls(url_params)

  return exports.filter_specific_urls(urls, params)
end

--[[[
-- @function lua_util.deepcopy(table)
-- params: {
- - table
-- }
-- Performs deep copy of the table. Including metatables
--]]
exports.deepcopy = deepcopy

--[[[
-- @function lua_util.deepsort(table)
-- params: {
- - table
-- }
-- Performs recursive in-place sort of a table
--]]
local function default_sort_cmp(e1, e2)
  -- Values of different types are ordered by their type names
  if type(e1) == type(e2) then
    return e1 < e2
  else
    return type(e1) < type(e2)
  end
end

local function deepsort(tbl, sort_func)
  local orig_type = type(tbl)
  if orig_type == 'table' then
    table.sort(tbl, sort_func or default_sort_cmp)
    for _, orig_value in next, tbl, nil do
      -- NOTE(review): nested tables are sorted with the default comparator,
      -- `sort_func` is applied to the top level only - confirm this is intended
      deepsort(orig_value)
    end
  end
end

exports.deepsort = deepsort

--[[[
-- @function lua_util.shallowcopy(tbl)
-- Performs shallow (and fast) copy of a table or another Lua type
--]]
exports.shallowcopy = shallowcopy

-- Debugging support
local logger = require "rspamd_logger"
local unconditional_debug = logger.log_level() == 'debug'
local debug_modules = {}
-- Maps alias name => module name
local debug_aliases = {}
local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)

exports.init_debug_logging = function(config)
  -- Fill debug modules from the config
  if not unconditional_debug then
    local log_config = config:get_all_opt('logging')
    if log_config then
      local log_level_str = log_config.level
      if log_level_str then
        if log_level_str == 'debug' then
          unconditional_debug = true
        end
      end
      if log_config.debug_modules then
        for _, m in ipairs(log_config.debug_modules) do
          debug_modules[m] = true
          logger.infox(config, 'enable debug for Lua module %s', m)
        end
      end

      -- `debug_aliases` is keyed by alias names, so its length is always 0;
      -- use `next` to check whether any alias has been registered
      if next(debug_aliases) then
        for alias, mod in pairs(debug_aliases) do
          if debug_modules[mod] then
            debug_modules[alias] = true
            logger.infox(config, 'enable debug for Lua module %s (%s aliased)',
                alias, mod)
          end
        end
      end
    end
  end
end

exports.enable_debug_logging = function()
  unconditional_debug = true
end

exports.enable_debug_modules = function(...)
  for _, m in ipairs({ ... }) do
    debug_modules[m] = true
  end
end

exports.disable_debug_logging = function()
  unconditional_debug = false
end

--[[[
-- @function lua_util.debugm(module, [log_object], format, ...)
-- Performs fast debug log for a specific module
--]]
exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...)
  if unconditional_debug or debug_modules[mod] then
    if type(obj_or_fmt) == 'string' then
      -- No log object: the second argument is the format string
      logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
    else
      logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
    end
  end
end

--[[[
-- @function lua_util.add_debug_alias(mod, alias)
-- Add debugging alias so logging to `alias` will be treated as logging to `mod`
--]]
exports.add_debug_alias = function(mod, alias)
  debug_aliases[alias] = mod

  if debug_modules[mod] then
    debug_modules[alias] = true
    logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
        alias, mod)
  end
end

---[[[
-- @function lua_util.get_task_verdict(task)
-- Returns verdict for a task + score if certain, must be called from idempotent filters only
-- Returns string:
-- * `spam`: if message have over reject threshold and has more than one positive rule
-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
-- * `passthrough`: if a message has been passed through some short-circuit rule
-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
-- * `uncertain`: all other cases
--]]
exports.get_task_verdict = function(task)
  local lua_verdict = require "lua_verdict"

  return lua_verdict.get_default_verdict(task)
end

---[[[
-- @function lua_util.maybe_obfuscate_string(subject, settings, prefix)
-- Obfuscate string if enabled in settings. Also checks utf8 validity - if
-- string is not valid utf8 then '???' is returned. Empty string returned as is.
-- Supported settings:
-- * _privacy = false - subject privacy is off
-- * _privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject
-- * _privacy_prefix = 'obf' - prefix to show it's obfuscated
-- * _privacy_length = 16 - cut the length of the hash; if 0 or false full hash is returned
-- @return obfuscated or validated subject
--]]
exports.maybe_obfuscate_string = function(subject, settings, prefix)
  local hash = require 'rspamd_cryptobox_hash'
  if not subject or subject == '' then
    return subject
  elseif not rspamd_util.is_valid_utf8(subject) then
    subject = '???'
  elseif settings[prefix .. '_privacy'] then
    local hash_alg = settings[prefix .. '_privacy_alg'] or 'blake2'
    local subject_hash = hash.create_specific(hash_alg, subject)

    local strip_len = settings[prefix .. '_privacy_length']
    if strip_len and strip_len > 0 then
      subject = subject_hash:hex():sub(1, strip_len)
    else
      subject = subject_hash:hex()
    end

    local privacy_prefix = settings[prefix .. '_privacy_prefix']
    if privacy_prefix and #privacy_prefix > 0 then
      subject = privacy_prefix .. ':' .. subject
    end
  end

  return subject
end

---[[[
-- @function lua_util.callback_from_string(str)
-- Converts a string like `return function(...) end` to lua function and return true and this function
-- or returns false + error message
-- @return status code and function object or an error message
--]]]
exports.callback_from_string = function(s)
  local loadstring = loadstring or load

  if not s or #s == 0 then
    return false, 'invalid or empty string'
  end

  s = exports.rspamd_str_trim(s)
  local inp

  if s:match('^return%s*function') then
    -- 'return function', can be evaluated directly
    inp = s
  elseif s:match('^function%s*%(') then
    inp = 'return ' ..
s else -- Just a plain sequence inp = 'return function(...)\n' .. s .. '; end' end local ret, res_or_err = pcall(loadstring(inp)) if not ret or type(res_or_err) ~= 'function' then return false, res_or_err end return ret, res_or_err end ---[[[ -- @function lua_util.keys(t) -- Returns all keys from a specific table -- @param {table} t input table (or iterator triplet) -- @return array of keys --]]] exports.keys = function(gen, param, state) local keys = {} local i = 1 if param then for k, _ in fun.iter(gen, param, state) do rawset(keys, i, k) i = i + 1 end else for k, _ in pairs(gen) do rawset(keys, i, k) i = i + 1 end end return keys end ---[[[ -- @function lua_util.values(t) -- Returns all values from a specific table -- @param {table} t input table -- @return array of values --]]] exports.values = function(gen, param, state) local values = {} local i = 1 if param then for _, v in fun.iter(gen, param, state) do rawset(values, i, v) i = i + 1 end else for _, v in pairs(gen) do rawset(values, i, v) i = i + 1 end end return values end ---[[[ -- @function lua_util.distance_sorted(t1, t2) -- Returns distance between two sorted tables t1 and t2 -- @param {table} t1 input table -- @param {table} t2 input table -- @return distance between `t1` and `t2` --]]] exports.distance_sorted = function(t1, t2) local ncomp = #t1 local ndiff = 0 local i, j = 1, 1 if ncomp < #t2 then ncomp = #t2 end for _ = 1, ncomp do if j > #t2 then ndiff = ndiff + ncomp - #t2 if i > j then ndiff = ndiff - (i - j) end break elseif i > #t1 then ndiff = ndiff + ncomp - #t1 if j > i then ndiff = ndiff - (j - i) end break end if t1[i] == t2[j] then i = i + 1 j = j + 1 elseif t1[i] < t2[j] then i = i + 1 ndiff = ndiff + 1 else j = j + 1 ndiff = ndiff + 1 end end return ndiff end ---[[[ -- @function lua_util.table_digest(t) -- Returns hash of all values if t[1] is string or all keys/values otherwise -- @param {table} t input array or map -- @return {string} base32 representation of blake2b hash of all 
-- strings
--]]]
local function table_digest(t)
  local cr = require "rspamd_cryptobox_hash"
  local h = cr.create()

  if t[1] then
    -- Array-like table: hash elements in order, recursing into nested tables
    for _, e in ipairs(t) do
      if type(e) == 'table' then
        h:update(table_digest(e))
      else
        h:update(tostring(e))
      end
    end
  else
    -- Map-like table: hash each key, then string and table values.
    -- NOTE(review): values that are neither strings nor tables do not
    -- contribute to the digest, and the result follows `pairs` traversal
    -- order, which Lua does not specify - confirm callers are fine with that.
    for k, v in pairs(t) do
      h:update(tostring(k))

      if type(v) == 'string' then
        h:update(v)
      elseif type(v) == 'table' then
        h:update(table_digest(v))
      end
    end
  end
  return h:base32()
end

exports.table_digest = table_digest

---[[[
-- @function lua_util.toboolean(v)
-- Converts a string or a number to boolean
-- @param {string|number} v
-- @return {boolean} v converted to boolean; on failure `false` plus an error message
--]]]
exports.toboolean = function(v)
  local true_t = {
    ['1'] = true,
    ['true'] = true,
    ['TRUE'] = true,
    ['True'] = true,
  };
  local false_t = {
    ['0'] = false,
    ['false'] = false,
    ['FALSE'] = false,
    ['False'] = false,
  };

  if type(v) == 'string' then
    if true_t[v] == true then
      return true;
    elseif false_t[v] == false then
      return false;
    else
      return false, string.format('cannot convert %q to boolean', v);
    end
  elseif type(v) == 'number' then
    -- Any non-zero number is true
    return v ~= 0
  else
    -- `%q` requires a string on some Lua versions, so stringify explicitly
    return false, string.format('cannot convert %q to boolean', tostring(v));
  end
end

---[[[
-- @function lua_util.config_check_local_or_authed(config, modname)
-- Reads check_local and check_authed from the config as this is used in many modules
-- The module section is tried first; if it defines neither option, the
-- `options` section is consulted.
-- @param {rspamd_config} config `rspamd_config` global
-- @param {name} module name
-- @param {boolean} def_local optional default for `check_local` (false if omitted)
-- @param {boolean} def_authed optional default for `check_authed` (false if omitted)
-- @return {table} array of two booleans: `{ check_local, check_authed }`
--]]]
exports.config_check_local_or_authed = function(rspamd_config, modname, def_local, def_authed)
  local check_local = def_local or false
  local check_authed = def_authed or false

  -- Reads both options from section `where`; returns true if at least one
  -- of them was present there as a boolean
  local function try_section(where)
    local ret = false
    local opts = rspamd_config:get_all_opt(where)
    if type(opts) == 'table' then
      if type(opts['check_local']) == 'boolean' then
        check_local = opts['check_local']
        ret = true
      end
      if type(opts['check_authed']) == 'boolean' then
        check_authed = opts['check_authed']
        ret = true
      end
    end

    return ret
  end

  if not try_section(modname) then
    try_section('options')
  end

  return { check_local, check_authed }
end

---[[[
-- @function lua_util.is_skip_local_or_authed(task, conf[, ip])
-- Returns `true` if local or authenticated task should be skipped for this module
-- @param {rspamd_task} task
-- @param {table} conf table returned from `config_check_local_or_authed`
-- @param {rspamd_ip} ip optional ip address (can be obtained from a task)
-- @return {boolean} true if check should be skipped
--]]]
exports.is_skip_local_or_authed = function(task, conf, ip)
  if not ip then
    ip = task:get_from_ip()
  end
  if not conf then
    -- Neither local nor authenticated checks are enabled by default
    conf = { false, false }
  end

  -- conf[1] is check_local, conf[2] is check_authed
  if ((not conf[2] and task:get_user()) or
      (not conf[1] and type(ip) == 'userdata' and ip:is_local())) then
    return true
  end

  return false
end

---[[[
-- @function lua_util.maybe_smtp_quote_value(str)
-- Checks string for the forbidden elements (tspecials in RFC and quote string if needed)
-- @param {string} str input string
-- @return {string} original or quoted string
--]]]
local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v"
local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1)
exports.maybe_smtp_quote_value = function(str)
  if special_match:match(str) then
    -- Both `\` and `"` must be backslash-escaped inside a quoted string;
    -- the parentheses drop the substitution count returned by gsub
    local escaped = (str:gsub('[\\"]', '\\%0'))
    return string.format('"%s"', escaped)
  end

  return str
end

---[[[
-- @function lua_util.shuffle(table)
-- Performs in-place shuffling of a table
-- @param {table} tbl table to shuffle
-- @return {table} same table
--]]]
exports.shuffle = function(tbl)
  -- Fisher-Yates: element i is swapped with a random element from 1..i, which
  -- makes every permutation equally likely (picking from the whole range on
  -- each step does not)
  for i = #tbl, 2, -1 do
    local rand = math.random(i)
    tbl[i], tbl[rand] = tbl[rand], tbl[i]
  end
  return tbl
end

-- Lookup table for `unhex`: two hex digits (all upper or all lower case) -> byte
local hex_table = {}
for idx = 0, 255 do
  hex_table[("%02X"):format(idx)] = string.char(idx)
  hex_table[("%02x"):format(idx)] = string.char(idx)
end

---[[[
-- @function lua_util.unhex(str)
-- Decode hex encoded string
-- @param {string} str string to decode
-- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is)
--]]]
exports.unhex = function(str)
  -- Parentheses keep only the resulting string and drop gsub's match count
  return (str:gsub('(..)', hex_table))
end

-- Cache of upstream lists keyed by the url string they were created from
local http_upstream_lists = {}
local function http_upstreams_by_url(pool, url)
  local rspamd_url = require "rspamd_url"

  local cached = http_upstream_lists[url]
  if cached then
    return cached
  end

  local real_url = rspamd_url.create(pool, url)

  if not real_url then
    return nil
  end

  local host = real_url:get_host()
  local proto = real_url:get_protocol() or 'http'
  -- Fall back to the protocol's default port when the url has none
  local port = real_url:get_port() or (proto == 'https' and 443 or 80)
  local upstream_list = require "rspamd_upstream_list"
  local upstreams = upstream_list.create(host, port)

  if upstreams then
    http_upstream_lists[url] = upstreams
    return upstreams
  end

  return nil
end
---[[[
-- @function lua_util.http_upstreams_by_url(pool, url)
-- Returns a cached or new upstreams list that corresponds to the specific url
-- @param {mempool} pool memory pool to use (typically static pool from rspamd_config)
-- @param {string} url full url
-- @return {upstreams_list} object to get upstream from an url
--]]]
exports.http_upstreams_by_url = http_upstreams_by_url

---[[[
-- @function lua_util.dns_timeout_augmentation(cfg)
-- Returns an augmentation suitable to define DNS timeout for a module
-- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout
--]]]
local function dns_timeout_augmentation(cfg)
  return string.format('timeout=%f', cfg:get_dns_timeout() or 0.0)
end

exports.dns_timeout_augmentation = dns_timeout_augmentation

---[[[
--- @function lua_util.strip_lua_comments(lua_code)
-- Strips single-line and multi-line comments from a given Lua code string and removes
-- any extra spaces or newlines.
--
-- @param lua_code The Lua code string to strip comments from.
-- @return The resulting Lua code string with comments and extra spaces removed.
--
---]]]
local function strip_lua_comments(lua_code)
  -- Remove multi-line comments first (including levelled `--[=[ ... ]=]`
  -- forms): if single-line stripping ran first it would consume the `--[[`
  -- opener and leave the rest of the block comment in the output
  lua_code = lua_code:gsub("%-%-%[(=*)%[.-%]%1%]", "")

  -- Remove single-line comments
  lua_code = lua_code:gsub("%-%-[^\r\n]*", "")

  -- Remove extra spaces and newlines
  -- (purely pattern based: text inside string literals is not protected)
  lua_code = lua_code:gsub("%s+", " ")

  return lua_code
end

exports.strip_lua_comments = strip_lua_comments

---[[[
-- @function lua_util.join_path(...)
-- Joins path components into a single path string using the appropriate separator
-- for the current operating system.
--
-- @param ... Any number of path components to join together.
-- @return A single path string, with components separated by the appropriate separator.
--
---]]]
-- The first character of `package.config` is the directory separator
local path_sep = package.config:sub(1, 1) or '/'
local function join_path(...)
  local components = { ... }

  -- Join components using separator
  return table.concat(components, path_sep)
end
exports.join_path = join_path

-- Short unit test for sanity
if path_sep == '/' then
  assert(join_path('/path', 'to', 'file') == '/path/to/file')
else
  assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file')
end

-- Defines symbols priorities for common usage in prefilters/postfilters
exports.symbols_priorities = {
  top = 10, -- Symbols must be executed first (or last), such as settings
  high = 9, -- Example: asn
  medium = 5, -- Everything should use this as default
  low = 0,
}

return exports