diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-08-07 11:25:52 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2023-08-07 11:25:52 +0100 |
commit | bbd88232db43d18f5e0de5a6502848d4074621c5 (patch) | |
tree | 32682e9f044704d0456575d6058735c60fa960ac | |
parent | ffbab4fbf218514845b8e5209aec044621b1f460 (diff) | |
download | rspamd-bbd88232db43d18f5e0de5a6502848d4074621c5.tar.gz rspamd-bbd88232db43d18f5e0de5a6502848d4074621c5.zip |
[Minor] Distinguish failures from unknown errors
-rw-r--r-- | lualib/lua_util.lua | 229 | ||||
-rw-r--r-- | src/lua/lua_common.c | 8 | ||||
-rw-r--r-- | src/plugins/lua/reputation.lua | 221 |
3 files changed, 265 insertions, 193 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua index 289e2ed3a..16d009619 100644 --- a/lualib/lua_util.lua +++ b/lualib/lua_util.lua @@ -1,5 +1,5 @@ --[[ -Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,9 +27,9 @@ local lupa = require "lupa" local split_grammar = {} local spaces_split_grammar -local space = lpeg.S' \t\n\v\f\r' +local space = lpeg.S ' \t\n\v\f\r' local nospace = 1 - space -local ptrim = space^0 * lpeg.C((space^0 * nospace^1)^0) +local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0) local match = lpeg.match lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', { @@ -47,8 +47,8 @@ local function rspamd_str_split(s, sep) if not sep then if not spaces_split_grammar then local _sep = space - local elem = lpeg.C((1 - _sep)^0) - local p = lpeg.Ct(elem * (_sep * elem)^0) + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) spaces_split_grammar = p end @@ -63,8 +63,8 @@ local function rspamd_str_split(s, sep) else _sep = sep -- Assume lpeg object end - local elem = lpeg.C((1 - _sep)^0) - local p = lpeg.Ct(elem * (_sep * elem)^0) + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) gr = p split_grammar[sep] = gr end @@ -126,7 +126,7 @@ end -- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound exports.round = function(num, numDecimalPlaces) - local mult = 10^(numDecimalPlaces or 0) + local mult = 10 ^ (numDecimalPlaces or 0) if num >= 0 then return math.floor(num * mult + 0.5) / mult else @@ -148,10 +148,10 @@ end exports.template = function(tmpl, keys) local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } - local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit^1) / keys) } - local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit^1) / keys) * (lpeg.P("}") / "") } + local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) } + local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") } - local template_grammar = lpeg.Cs((var + var_braced + 1)^0) + local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0) return lpeg.match(template_grammar, tmpl) end @@ -209,7 +209,7 @@ exports.remove_email_aliases = function(email_addr) if cap then return cap, rspamd_str_split(pluses, '+'), nil elseif no_dots_user ~= addr.user then - return no_dots_user,{},nil + return no_dots_user, {}, nil end return nil @@ -298,7 +298,9 @@ exports.is_rspamc_or_controller = function(task) local ua = task:get_request_header('User-Agent') or '' local pwd = task:get_request_header('Password') local is_rspamc = false - if tostring(ua) == 'rspamc' or pwd then is_rspamc = true end + if tostring(ua) == 'rspamc' or pwd then + is_rspamc = true + end return is_rspamc end @@ -324,8 +326,8 @@ end --]] exports.flatten = function(t) local res = {} - for _,e in fun.iter(t) do - for _,v in fun.iter(e) do + for _, e in fun.iter(t) do + for _, v in fun.iter(e) do res[#res + 1] = v end end @@ -350,12 +352,16 @@ end local function spairs(t, order, lim) -- collect the keys local keys = {} - for k in pairs(t) do keys[#keys+1] = k end + for k in pairs(t) do + keys[#keys + 1] = k + end -- if order function given, sort by it by passing the table and keys a, b, -- otherwise just sort the keys if order then - table.sort(keys, function(a,b) return order(t, a, b) end) + table.sort(keys, function(a, b) + return order(t, a, b) + end) else table.sort(keys) end @@ -375,13 +381,13 @@ end exports.spairs = spairs --[[[ --- @function lua_util.disable_module(modname, how) +-- @function lua_util.disable_module(modname, how[, reason]) -- Disables a plugin -- @param {string} modname name of plugin to disable -- @param {string} how 'redis' to disable redis, 'config' to disable startup +-- @param {string} reason optional reason for failure --]] - -local function disable_module(modname, how) +local function disable_module(modname, how, reason) if rspamd_plugins_state.enabled[modname] then rspamd_plugins_state.enabled[modname] = nil end @@ -392,8 +398,10 @@ local function disable_module(modname, how) rspamd_plugins_state.disabled_unconfigured[modname] = {} elseif how == 'experimental' then rspamd_plugins_state.disabled_experimental[modname] = {} + elseif how == 'failed' then + rspamd_plugins_state.disabled_failed[modname] = { reason = reason } else - rspamd_plugins_state.disabled_failed[modname] = {} + rspamd_plugins_state.disabled_unknown[modname] = {} end end @@ -461,9 +469,13 @@ exports.list_to_hash = list_to_hash local function nkeys(gen, param, state) local n = 0 if not param then - for _,_ in pairs(gen) do n = n + 1 end + for _, _ in pairs(gen) do + n = n + 1 + end else - for _,_ in fun.iter(gen, param, state) do n = n + 1 end + for _, _ in fun.iter(gen, param, state) do + n = n + 1 + end end return n end @@ -497,20 +509,19 @@ local function parse_time_interval(str) local digit = lpeg.R("09") local parser = {} - parser.integer = - (lpeg.S("+-") ^ -1) * - (digit ^ 1) - parser.fractional = - (lpeg.P(".") ) * + parser.integer = (lpeg.S("+-") ^ -1) * (digit ^ 1) - parser.number = - (parser.integer * + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * (parser.fractional ^ -1)) + (lpeg.S("+-") * parser.fractional) parser.time = lpeg.Cf(lpeg.Cc(1) * (parser.number / tonumber) * ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1), - function (acc, val) return acc * val end) + function(acc, val) + return acc * val + end) local t = lpeg.match(parser.time, str) @@ -546,20 +557,19 @@ local function dehumanize_number(str) local digit = lpeg.R("09") local parser = {} - parser.integer = - (lpeg.S("+-") ^ -1) * - (digit ^ 1) - parser.fractional = - (lpeg.P(".") ) * + parser.integer = (lpeg.S("+-") ^ -1) * (digit ^ 1) - parser.number = - (parser.integer * + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * (parser.fractional ^ -1)) + (lpeg.S("+-") * parser.fractional) parser.humanized_number = lpeg.Cf(lpeg.Cc(1) * (parser.number / tonumber) * (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1), - function (acc, val) return acc * val end) + function(acc, val) + return acc * val + end) local t = lpeg.match(parser.humanized_number, str) @@ -575,16 +585,24 @@ exports.dehumanize_number = dehumanize_number local function table_cmp(table1, table2) local avoid_loops = {} local function recurse(t1, t2) - if type(t1) ~= type(t2) then return false end - if type(t1) ~= "table" then return t1 == t2 end + if type(t1) ~= type(t2) then + return false + end + if type(t1) ~= "table" then + return t1 == t2 + end - if avoid_loops[t1] then return avoid_loops[t1] == t2 end + if avoid_loops[t1] then + return avoid_loops[t1] == t2 + end avoid_loops[t1] = t2 -- Copy keys from t2 local t2keys = {} local t2tablekeys = {} for k, _ in pairs(t2) do - if type(k) == "table" then table.insert(t2tablekeys, k) end + if type(k) == "table" then + table.insert(t2tablekeys, k) + end t2keys[k] = true end -- Let's iterate keys from t1 @@ -601,16 +619,24 @@ local function table_cmp(table1, table2) break end end - if not ok then return false end + if not ok then + return false + end else -- t1 has a key which t2 doesn't have, fail. - if v2 == nil then return false end + if v2 == nil then + return false + end t2keys[k1] = nil - if not recurse(v1, v2) then return false end + if not recurse(v1, v2) then + return false + end end end -- if t2 has a key which t1 doesn't have, fail. - if next(t2keys) then return false end + if next(t2keys) then + return false + end return true end return recurse(table1, table2) @@ -663,7 +689,7 @@ local function override_defaults(def, override) local res = {} - for k,v in pairs(override) do + for k, v in pairs(override) do if type(v) == 'table' then if def[k] and type(def[k]) == 'table' then -- Recursively override elements @@ -676,7 +702,7 @@ local function override_defaults(def, override) end end - for k,v in pairs(def) do + for k, v in pairs(def) do if type(res[k]) == 'nil' then res[k] = v end @@ -702,7 +728,7 @@ exports.override_defaults = override_defaults -- tries its best to extract specific number of urls from a task based on -- their characteristics --]] -exports.filter_specific_urls = function (urls, params) +exports.filter_specific_urls = function(urls, params) local cache_key if params.task and not params.no_cache then @@ -721,9 +747,13 @@ exports.filter_specific_urls = function (urls, params) end end - if not urls then return {} end + if not urls then + return {} + end - if params.filter then urls = fun.totable(fun.filter(params.filter, urls)) end + if params.filter then + urls = fun.totable(fun.filter(params.filter, urls)) + end -- Filter by tld: local tlds = {} @@ -794,11 +824,11 @@ exports.filter_specific_urls = function (urls, params) end if not eslds[esld] then - eslds[esld] = {{str_hash, u, priority}} + eslds[esld] = { { str_hash, u, priority } } neslds = neslds + 1 else if #eslds[esld] < params.esld_limit then - table.insert(eslds[esld], {str_hash, u, priority}) + table.insert(eslds[esld], { str_hash, u, priority }) end end @@ -808,21 +838,23 @@ exports.filter_specific_urls = function (urls, params) local tld = table.concat(fun.totable(fun.tail(parts)), '.') if not tlds[tld] then - tlds[tld] = {{str_hash, u, priority}} + tlds[tld] = { { str_hash, u, priority } } ntlds = ntlds + 1 else - table.insert(tlds[tld], {str_hash, u, priority}) + table.insert(tlds[tld], { str_hash, u, priority }) end end end - for _,u in ipairs(urls) do + for _, u in ipairs(urls) do process_single_url(u) end local limit = params.limit limit = limit - nres - if limit < 0 then limit = 0 end + if limit < 0 then + limit = 0 + end if limit == 0 then res = exports.values(res) @@ -865,7 +897,7 @@ exports.filter_specific_urls = function (urls, params) repeat local item_found = false - for _,lurls in ipairs(eslds) do + for _, lurls in ipairs(eslds) do if #lurls > 0 then local last = table.remove(lurls) insert_url(last[1], last[2]) @@ -888,13 +920,15 @@ exports.filter_specific_urls = function (urls, params) -- Number of tlds < limit while limit > 0 do - for _,lurls in ipairs(tlds) do + for _, lurls in ipairs(tlds) do if #lurls > 0 then local last = table.remove(lurls) insert_url(last[1], last[2]) limit = limit - 1 end - if limit == 0 then break end + if limit == 0 then + break + end end end @@ -951,8 +985,10 @@ exports.extract_specific_urls = function(params_or_task, lim, need_emails, filte prefix = prefix } end - for k,v in pairs(default_params) do - if type(params[k]) == 'nil' and v ~= nil then params[k] = v end + for k, v in pairs(default_params) do + if type(params[k]) == 'nil' and v ~= nil then + params[k] = v + end end local url_params = { emails = params.need_emails, @@ -973,9 +1009,9 @@ exports.extract_specific_urls = function(params_or_task, lim, need_emails, filte cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '') else cache_key_suffix = string.format('%s%s%s', - tostring(params.need_emails or false), - tostring(params.need_images or false), - tostring(params.need_content or false)) + tostring(params.need_emails or false), + tostring(params.need_images or false), + tostring(params.need_content or false)) end cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix) end @@ -1010,7 +1046,8 @@ local function deepcopy(orig) if getmetatable(orig) then setmetatable(copy, deepcopy(getmetatable(orig))) end - else -- number, string, boolean, etc + else + -- number, string, boolean, etc copy = orig end return copy @@ -1083,14 +1120,14 @@ exports.init_debug_logging = function(config) end end if log_config.debug_modules then - for _,m in ipairs(log_config.debug_modules) do + for _, m in ipairs(log_config.debug_modules) do debug_modules[m] = true logger.infox(config, 'enable debug for Lua module %s', m) end end if #debug_aliases > 0 then - for alias,mod in pairs(debug_aliases) do + for alias, mod in pairs(debug_aliases) do if debug_modules[mod] then debug_modules[alias] = true logger.infox(config, 'enable debug for Lua module %s (%s aliased)', @@ -1107,7 +1144,7 @@ exports.enable_debug_logging = function() end exports.enable_debug_modules = function(...) - for _,m in ipairs({...}) do + for _, m in ipairs({ ... }) do debug_modules[m] = true end end @@ -1207,7 +1244,7 @@ exports.callback_from_string = function(s) local loadstring = loadstring or load if not s or #s == 0 then - return false,'invalid or empty string' + return false, 'invalid or empty string' end s = exports.rspamd_str_trim(s) @@ -1226,10 +1263,10 @@ exports.callback_from_string = function(s) local ret, res_or_err = pcall(loadstring(inp)) if not ret or type(res_or_err) ~= 'function' then - return false,res_or_err + return false, res_or_err end - return ret,res_or_err + return ret, res_or_err end ---[[[ @@ -1243,12 +1280,12 @@ exports.keys = function(gen, param, state) local i = 1 if param then - for k,_ in fun.iter(gen, param, state) do + for k, _ in fun.iter(gen, param, state) do rawset(keys, i, k) i = i + 1 end else - for k,_ in pairs(gen) do + for k, _ in pairs(gen) do rawset(keys, i, k) i = i + 1 end @@ -1268,12 +1305,12 @@ exports.values = function(gen, param, state) local i = 1 if param then - for _,v in fun.iter(gen, param, state) do + for _, v in fun.iter(gen, param, state) do rawset(values, i, v) i = i + 1 end else - for _,v in pairs(gen) do + for _, v in pairs(gen) do rawset(values, i, v) i = i + 1 end @@ -1292,13 +1329,13 @@ end exports.distance_sorted = function(t1, t2) local ncomp = #t1 local ndiff = 0 - local i,j = 1,1 + local i, j = 1, 1 if ncomp < #t2 then ncomp = #t2 end - for _=1,ncomp do + for _ = 1, ncomp do if j > #t2 then ndiff = ndiff + ncomp - #t2 if i > j then @@ -1339,7 +1376,7 @@ local function table_digest(t) local h = cr.create() if t[1] then - for _,e in ipairs(t) do + for _, e in ipairs(t) do if type(e) == 'table' then h:update(table_digest(e)) else @@ -1347,7 +1384,7 @@ local function table_digest(t) end end else - for k,v in pairs(t) do + for k, v in pairs(t) do h:update(tostring(k)) if type(v) == 'string' then @@ -1357,7 +1394,7 @@ local function table_digest(t) end end end - return h:base32() + return h:base32() end exports.table_digest = table_digest @@ -1388,12 +1425,12 @@ exports.toboolean = function(v) elseif false_t[v] == false then return false; else - return false, string.format( 'cannot convert %q to boolean', v); + return false, string.format('cannot convert %q to boolean', v); end elseif type(v) == 'number' then return v ~= 0 else - return false, string.format( 'cannot convert %q to boolean', v); + return false, string.format('cannot convert %q to boolean', v); end end @@ -1429,7 +1466,7 @@ exports.config_check_local_or_authed = function(rspamd_config, modname, def_loca try_section('options') end - return {check_local, check_authed} + return { check_local, check_authed } end ---[[[ @@ -1445,7 +1482,7 @@ exports.is_skip_local_or_authed = function(task, conf, ip) ip = task:get_from_ip() end if not conf then - conf = {false, false} + conf = { false, false } end if ((not conf[2] and task:get_user()) or (not conf[1] and type(ip) == 'userdata' and ip:is_local())) then @@ -1461,8 +1498,8 @@ end -- @param {string} str input string -- @return {string} original or quoted string --]]] -local tspecial = lpeg.S"()<>,;:\\\"/[]?= \t\v" -local special_match = lpeg.P((1 - tspecial)^0 * tspecial^1) +local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v" +local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1) exports.maybe_smtp_quote_value = function(str) if special_match:match(str) then return string.format('"%s"', str:gsub('"', '\\"')) @@ -1499,18 +1536,24 @@ end -- @param {string} str string to decode -- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is) --]]] -exports.unhex = function(str) return str:gsub('(..)', hex_table) end +exports.unhex = function(str) + return str:gsub('(..)', hex_table) +end local http_upstream_lists = {} local function http_upstreams_by_url(pool, url) local rspamd_url = require "rspamd_url" local cached = http_upstream_lists[url] - if cached then return cached end + if cached then + return cached + end local real_url = rspamd_url.create(pool, url) - if not real_url then return nil end + if not real_url then + return nil + end local host = real_url:get_host() local proto = real_url:get_protocol() or 'http' @@ -1578,9 +1621,9 @@ exports.strip_lua_comments = strip_lua_comments -- @return A single path string, with components separated by the appropriate separator. -- ---]]] -local path_sep = package.config:sub(1,1) or '/' +local path_sep = package.config:sub(1, 1) or '/' local function join_path(...) - local components = {...} + local components = { ... } -- Join components using separator return table.concat(components, path_sep) diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index b543ae5db..53473c9dc 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2023 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -1008,6 +1008,7 @@ rspamd_lua_init(bool wipe_mem) * disabled_explicitly = {}, * disabled_failed = {}, * disabled_experimental = {}, + * disabled_unknown = {}, * } */ #define ADD_TABLE(name) \ @@ -1023,6 +1024,7 @@ rspamd_lua_init(bool wipe_mem) ADD_TABLE(disabled_explicitly); ADD_TABLE(disabled_failed); ADD_TABLE(disabled_experimental); + ADD_TABLE(disabled_unknown); #undef ADD_TABLE lua_setglobal(L, rspamd_modules_state_global); diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua index e0530a42e..2fc1c3ad5 100644 --- a/src/plugins/lua/reputation.lua +++ b/src/plugins/lua/reputation.lua @@ -70,7 +70,7 @@ end local function add_symbol_score(task, rule, mult, params) if not params then - params = {tostring(mult)} + params = { tostring(mult) } end if rule.selector.config.split_symbols then @@ -124,9 +124,11 @@ end -- Extracts task score and subtracts score of the rule itself local function extract_task_score(task, rule) local lua_verdict = require "lua_verdict" - local verdict,score = lua_verdict.get_specific_verdict(N, task) + local verdict, score = lua_verdict.get_specific_verdict(N, task) - if not score or verdict == 'passthrough' then return nil end + if not score or verdict == 'passthrough' then + return nil + end return sub_symbol_score(task, rule, score) end @@ -140,23 +142,25 @@ local function gen_dkim_queries(task, rule) if not gr then local semicolon = lpeg.P(':') - local domain = lpeg.C((1 - semicolon)^1) - local res = lpeg.S'+-?~' + local domain = lpeg.C((1 - semicolon) ^ 1) + local res = lpeg.S '+-?~' local function res_to_label(ch) - if ch == '+' then return 'a' - elseif ch == '-' then return 'r' + if ch == '+' then + return 'a' + elseif ch == '-' then + return 'r' end return 'u' end - gr = domain * semicolon * (lpeg.C(res^1) / res_to_label) + gr = domain * semicolon * (lpeg.C(res ^ 1) / res_to_label) end if dkim_trace and dkim_trace.options then - for _,opt in ipairs(dkim_trace.options) do - local dom,res = lpeg.match(gr, opt) + for _, opt in ipairs(dkim_trace.options) do + local dom, res = lpeg.match(gr, opt) if dom and res then local tld = rspamd_util.get_tld(dom) @@ -185,11 +189,11 @@ local function dkim_reputation_filter(task, rule) end if requests_left == 0 then - for k,v in pairs(results) do + for k, v in pairs(results) do -- `k` in results is a prefixed and suffixed tld, so we need to look through -- all requests to find any request with the matching tld local sel_tld - for _,tld in ipairs(dkim_tlds) do + for _, tld in ipairs(dkim_tlds) do if k:find(tld, 1, true) then sel_tld = tld break @@ -212,8 +216,12 @@ local function dkim_reputation_filter(task, rule) rep_accepted_abs) if rep_accepted_abs then local final_rep = rep_accepted - if rep_accepted > 1.0 then final_rep = 1.0 end - if rep_accepted < -1.0 then final_rep = -1.0 end + if rep_accepted > 1.0 then + final_rep = 1.0 + end + if rep_accepted < -1.0 then + final_rep = -1.0 + end add_symbol_score(task, rule, final_rep) -- Store results for future DKIM results adjustments @@ -222,7 +230,7 @@ local function dkim_reputation_filter(task, rule) end end - for dom,res in pairs(requests) do + for dom, res in pairs(requests) do -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs local query = string.format('%s.%s', dom, res) rule.backend.get_token(task, rule, nil, query, tokens_cb, 'string') @@ -234,7 +242,7 @@ local function dkim_reputation_idempotent(task, rule) local sc = extract_task_score(task, rule) if sc then - for dom,res in pairs(requests) do + for dom, res in pairs(requests) do -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs local query = string.format('%s.%s', dom, res) rule.backend.set_token(task, rule, nil, query, sc) @@ -270,7 +278,7 @@ local dkim_selector = { inbound = true, max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score }, - dependencies = {"DKIM_TRACE"}, + dependencies = { "DKIM_TRACE" }, filter = dkim_reputation_filter, -- used to get scores postfilter = dkim_reputation_postfilter, -- used to adjust DKIM scores idempotent = dkim_reputation_idempotent, -- used to set scores @@ -295,14 +303,18 @@ local function gen_url_queries(task, rule) else domains[dom] = domains[dom] + 1 end - end, fun.filter(function(u) return not u:is_html_displayed() end, - task:get_urls(true))) + end, fun.filter(function(u) + return not u:is_html_displayed() + end, + task:get_urls(true))) local results = {} - for k,v in lua_util.spairs(domains, - function(t, a, b) return t[a] > t[b] end, rule.selector.config.max_urls) do + for k, v in lua_util.spairs(domains, + function(t, a, b) + return t[a] > t[b] + end, rule.selector.config.max_urls) do if v > 0 then - table.insert(results, {k,v}) + table.insert(results, { k, v }) end end @@ -326,7 +338,7 @@ local function url_reputation_filter(task, rule) -- Check the url with maximum hits local mhits = 0 - for i,res in ipairs(results) do + for i, res in ipairs(results) do local req = requests[i] if req then local hits = tonumber(res[1]) @@ -341,7 +353,7 @@ local function url_reputation_filter(task, rule) if mhits > 0 then local score = 0 - for i,res in pairs(results) do + for i, res in pairs(results) do local req = requests[i] if req then local url_score = generic_reputation_calc(res, rule, @@ -359,7 +371,7 @@ local function url_reputation_filter(task, rule) end end - for i,req in ipairs(requests) do + for i, req in ipairs(requests) do local function tokens_cb(err, token, values) indexed_tokens_cb(err, i, values) end @@ -373,7 +385,7 @@ local function url_reputation_idempotent(task, rule) local sc = extract_task_score(task, rule) if sc then - for _,tld in ipairs(requests) do + for _, tld in ipairs(requests) do rule.backend.set_token(task, rule, nil, tld[1], sc) end end @@ -400,9 +412,9 @@ local function ip_reputation_init(rule) if cfg.asn_cc_whitelist then cfg.asn_cc_whitelist = lua_maps.map_add('reputation', - 'asn_cc_whitelist', - 'map', - 'IP score whitelisted ASNs/countries') + 'asn_cc_whitelist', + 'map', + 'IP score whitelisted ASNs/countries') end return true @@ -412,8 +424,12 @@ local function ip_reputation_filter(task, rule) local ip = task:get_from_ip() - if not ip or not ip:is_valid() then return end - if lua_util.is_rspamc_or_controller(task) then return end + if not ip or not ip:is_valid() then + return + end + if lua_util.is_rspamc_or_controller(task) then + return + end local cfg = rule.selector.config @@ -451,21 +467,21 @@ local function ip_reputation_filter(task, rule) local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task) score = score + asn_score table.insert(description_t, string.format('asn: %s(%.2f)', - asn, asn_score)) + asn, asn_score)) end if country_stats then local country_score = generic_reputation_calc(country_stats, rule, cfg.scores.country, task) score = score + country_score table.insert(description_t, string.format('country: %s(%.2f)', - country, country_score)) + country, country_score)) end if ip_stats then local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip, - task) + task) score = score + ip_score table.insert(description_t, string.format('ip: %s(%.2f)', - tostring(ip), ip_score)) + tostring(ip), ip_score)) end if math.abs(score) > 0.001 then @@ -505,26 +521,32 @@ local function ip_reputation_filter(task, rule) if asn then rule.backend.get_token(task, rule, cfg.asn_prefix, asn, - gen_token_callback('asn'), 'string') + gen_token_callback('asn'), 'string') end if country then rule.backend.get_token(task, rule, cfg.country_prefix, country, - gen_token_callback('country'), 'string') + gen_token_callback('country'), 'string') end rule.backend.get_token(task, rule, cfg.ip_prefix, ip, - gen_token_callback('ip'), 'ip') + gen_token_callback('ip'), 'ip') end -- Used to set scores local function ip_reputation_idempotent(task, rule) - if not rule.backend.set_token then return end -- Read only backend + if not rule.backend.set_token then + return + end -- Read only backend local ip = task:get_from_ip() local cfg = rule.selector.config - if not ip or not ip:is_valid() then return end + if not ip or not ip:is_valid() then + return + end - if lua_util.is_rspamc_or_controller(task) then return end + if lua_util.is_rspamc_or_controller(task) then + return + end if ip:get_version() == 4 and cfg.ipv4_mask then ip = ip:apply_mask(cfg.ipv4_mask) @@ -592,7 +614,9 @@ local function spf_reputation_filter(task, rule) local spf_allow = task:has_symbol('R_SPF_ALLOW') -- Don't care about bad/missing spf - if not spf_record or not spf_allow then return end + if not spf_record or not spf_allow then + return + end local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) @@ -618,7 +642,9 @@ local function spf_reputation_idempotent(task, rule) local spf_record = task:get_mempool():get_variable('spf_record') local spf_allow = task:has_symbol('R_SPF_ALLOW') - if not spf_record or not spf_allow or not sc then return end + if not spf_record or not spf_allow or not sc then + return + end local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) @@ -628,7 +654,6 @@ local function spf_reputation_idempotent(task, rule) rule.backend.set_token(task, rule, nil, hkey, sc) end - local spf_selector = { config = { symbol = 'SPF_REP', -- symbol to be inserted @@ -639,7 +664,7 @@ local spf_selector = { outbound = true, inbound = true, }, - dependencies = {"R_SPF_ALLOW"}, + dependencies = { "R_SPF_ALLOW" }, filter = spf_reputation_filter, -- used to get scores idempotent = spf_reputation_idempotent, -- used to set scores } @@ -694,12 +719,12 @@ local function generic_reputation_filter(task, rule) if type(selector_res) == 'table' then fun.each(function(e) lua_util.debugm(N, task, 'check generic reputation (%s) %s', - rule['symbol'], e) + rule['symbol'], e) rule.backend.get_token(task, rule, nil, e, tokens_cb, 'string') end, selector_res) else lua_util.debugm(N, task, 'check generic reputation (%s) %s', - rule['symbol'], selector_res) + rule['symbol'], selector_res) rule.backend.get_token(task, rule, nil, selector_res, tokens_cb, 'string') end end @@ -710,7 +735,9 @@ local function generic_reputation_idempotent(task, rule) local cfg = rule.selector.config local selector_res = cfg.selector(task) - if not selector_res then return end + if not selector_res then + return + end if sc then if type(selector_res) == 'table' then @@ -727,9 +754,8 @@ local function generic_reputation_idempotent(task, rule) end end - local generic_selector = { - schema = ts.shape{ + schema = ts.shape { lower_bound = ts.number + ts.string / tonumber, max_score = ts.number:is_optional(), min_score = ts.number:is_optional(), @@ -754,8 +780,6 @@ local generic_selector = { idempotent = generic_reputation_idempotent -- used to set scores } - - local selectors = { ip = ip_selector, sender = ip_selector, -- Better name @@ -768,14 +792,13 @@ local selectors = { local function reputation_dns_init(rule, _, _, _) if not rule.backend.config.list then rspamd_logger.errx(rspamd_config, "rule %s with DNS backend has no `list` parameter defined", - rule.symbol) + rule.symbol) return false end return true end - local function gen_token_key(prefix, token, rule) if prefix then token = prefix .. token @@ -843,7 +866,7 @@ local function reputation_dns_get_token(task, rule, prefix, token, continuation_ if prefix then dns_name = string.format('%s.%s.%s', key, prefix, - rule.backend.config.list) + rule.backend.config.list) else dns_name = string.format('%s.%s', key, rule.backend.config.list) end @@ -858,22 +881,24 @@ local function reputation_dns_get_token(task, rule, prefix, token, continuation_ dns_name, results, err, rule.backend.config.list) -- Now split tokens to list of values - if results and results[1] then + if results and results[1] then -- Format: num_messages;sc1;sc2...scn local dns_tokens = lua_util.rspamd_str_split(results[1], ";") -- Convert all to numbers excluding any possible non-numbers dns_tokens = fun.totable(fun.filter(function(e) return type(e) == 'number' end, - fun.map(function(e) - local n = tonumber(e) - if n then return n end - return "BAD" - end, dns_tokens))) + fun.map(function(e) + local n = tonumber(e) + if n then + return n + end + return "BAD" + end, dns_tokens))) if #dns_tokens < 2 then rspamd_logger.warnx(task, 'cannot parse response for reputation token %s: %s', - dns_name, results[1]) + dns_name, results[1]) continuation_cb(results, dns_name, nil) else local cnt = table.remove(dns_tokens, 1) @@ -881,12 +906,12 @@ local function reputation_dns_get_token(task, rule, prefix, token, continuation_ end else rspamd_logger.messagex(task, 'invalid response for reputation token %s: %s', - dns_name, results[1]) + dns_name, results[1]) continuation_cb(results, dns_name, nil) end end - task:get_resolver():resolve_a({ + task:get_resolver():resolve_a({ task = task, name = dns_name, callback = dns_cb, @@ -929,7 +954,7 @@ local function reputation_redis_init(rule, cfg, ev_base, worker) ]] local get_script = lua_util.jinja_template(redis_get_script_tpl, - {windows = rule.backend.config.buckets}) + { windows = rule.backend.config.buckets }) rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script) rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params) @@ -977,7 +1002,7 @@ local function reputation_redis_init(rule, cfg, ev_base, worker) ]] local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl, - {windows = rule.backend.config.buckets}) + { windows = rule.backend.config.buckets }) rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script) rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params) @@ -998,25 +1023,25 @@ local function reputation_redis_get_token(task, rule, prefix, token, continuatio continuation_cb(nil, key, data) else rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s', - rule['symbol'], key, type(data)) + rule['symbol'], key, type(data)) continuation_cb("invalid type", key, nil) end elseif err then rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s', - rule['symbol'], key, err) + rule['symbol'], key, err) continuation_cb(err, key, nil) else rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s', - rule['symbol'], key, "unknown error") + rule['symbol'], key, "unknown error") continuation_cb("unknown error", key, nil) end end local ret = lua_redis.exec_redis_script(rule.backend.script_get, - {task = task, is_write = false}, + { task = task, is_write = false }, redis_get_cb, - {key}) + { key }) if not ret then rspamd_logger.errx(task, 'cannot make redis request to check results') end @@ -1031,7 +1056,7 @@ local function reputation_redis_set_token(task, rule, prefix, token, sc, continu local function redis_set_cb(err, data) if err then rspamd_logger.errx(task, 'rule %s - got error while setting reputation keys %s: %s', - rule['symbol'], key, err) + rule['symbol'], key, err) if continuation_cb then continuation_cb(err, key) end @@ -1045,11 +1070,11 @@ local function reputation_redis_set_token(task, rule, prefix, token, sc, continu lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s', rule['symbol'], key, sc) local ret = lua_redis.exec_redis_script(rule.backend.script_set, - {task = task, is_write = true}, + { task = task, is_write = true }, redis_set_cb, - {key, tostring(os.time() * 1000), - tostring(sc), - tostring(rule.backend.config.expiry)}) + { key, tostring(os.time() * 1000), + tostring(sc), + tostring(rule.backend.config.expiry) }) if not ret then rspamd_logger.errx(task, 'got error while connecting to redis') end @@ -1067,7 +1092,7 @@ local backends = { schema = lua_redis.generate_schema({ prefix = ts.string, expiry = ts.number + ts.string / lua_util.parse_time_interval, - buckets = ts.array_of(ts.shape{ + buckets = ts.array_of(ts.shape { time = ts.number + ts.string / lua_util.parse_time_interval, name = ts.string, mult = ts.number + ts.string / tonumber @@ -1089,7 +1114,7 @@ local backends = { set_token = reputation_redis_set_token, }, dns = { - schema = ts.shape{ + schema = ts.shape { list = ts.string, }, config = { @@ -1151,22 +1176,22 @@ local function callback_gen(cb, rule) end local function parse_rule(name, tbl) - local sel_type,sel_conf = fun.head(tbl.selector) + local sel_type, sel_conf = fun.head(tbl.selector) local selector = selectors[sel_type] if not selector then rspamd_logger.errx(rspamd_config, "unknown selector defined for rule %s: %s", name, sel_type) - return + return false end - local bk_type,bk_conf = fun.head(tbl.backend) + local bk_type, bk_conf = fun.head(tbl.backend) local backend = backends[bk_type] if not backend then rspamd_logger.errx(rspamd_config, "unknown backend defined for rule %s: %s", name, - tbl.backend.type) - return + tbl.backend.type) + return false end -- Allow config override local rule = { @@ -1178,12 +1203,12 @@ local function parse_rule(name, tbl) -- Override default config params rule.backend.config = lua_util.override_defaults(rule.backend.config, bk_conf) if backend.schema then - local checked,schema_err = backend.schema:transform(rule.backend.config) + local checked, schema_err = backend.schema:transform(rule.backend.config) if not checked then rspamd_logger.errx(rspamd_config, "cannot parse backend config for %s: %s", sel_type, schema_err) - return + return false end rule.backend.config = checked @@ -1191,7 +1216,7 @@ local function parse_rule(name, tbl) rule.selector.config = lua_util.override_defaults(rule.selector.config, sel_conf) if selector.schema then - local checked,schema_err = selector.schema:transform(rule.selector.config) + local checked, schema_err = selector.schema:transform(rule.selector.config) if not checked then rspamd_logger.errx(rspamd_config, "cannot parse selector config for %s: %s (%s)", @@ -1228,7 +1253,9 @@ local function parse_rule(name, tbl) -- Hack: we assume that it is an ip whitelist :( local ip = task:get_from_ip() - if ip and map:get_key(ip) then return true end + if ip and map:get_key(ip) then + return true + end return false end } @@ -1287,20 +1314,20 @@ local function parse_rule(name, tbl) rule_type = 'callback' end - local id = rspamd_config:register_symbol{ + local id = rspamd_config:register_symbol { name = rule.symbol, type = rule_type, callback = callback_gen(reputation_filter_cb, rule), - augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)}, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, } if rule.selector.config.split_symbols then - rspamd_config:register_symbol{ + rspamd_config:register_symbol { name = rule.symbol .. '_HAM', type = 'virtual', parent = id, } - rspamd_config:register_symbol{ + rspamd_config:register_symbol { name = rule.symbol .. '_SPAM', type = 'virtual', parent = id, @@ -1315,23 +1342,23 @@ local function parse_rule(name, tbl) if rule.selector.postfilter then -- Also register a postfilter - rspamd_config:register_symbol{ + rspamd_config:register_symbol { name = rule.symbol .. '_POST', type = 'postfilter', flags = 'nostat,explicit_disable,ignore_passthrough', callback = callback_gen(reputation_postfilter_cb, rule), - augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)}, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, } end if rule.selector.idempotent then -- Has also idempotent component (e.g. saving data to the backend) - rspamd_config:register_symbol{ + rspamd_config:register_symbol { name = rule.symbol .. '_IDEMPOTENT', type = 'idempotent', flags = 'explicit_disable,ignore_passthrough', callback = callback_gen(reputation_idempotent_cb, rule), - augmentations = {string.format("timeout=%f", redis_params.timeout or 0.0)}, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, } end @@ -1342,12 +1369,12 @@ local opts = rspamd_config:get_all_opt("reputation") -- Initialization part if not (opts and type(opts) == 'table') then - rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + rspamd_logger.infox(rspamd_config, 'Module is not configured, disabling it') return end if opts['rules'] then - for k,v in pairs(opts['rules']) do + for k, v in pairs(opts['rules']) do if not ((v or E).selector) then rspamd_logger.errx(rspamd_config, "no selector defined for rule %s", k) else |