- --[[
- Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ]]--
-
- if confighelp then
- return
- end
-
- local rspamd_logger = require "rspamd_logger"
- local rspamd_util = require "rspamd_util"
- local fun = require "fun"
- local lua_util = require "lua_util"
-
- local N = "whitelist"
-
- local options = {
- dmarc_allow_symbol = 'DMARC_POLICY_ALLOW',
- spf_allow_symbol = 'R_SPF_ALLOW',
- dkim_allow_symbol = 'R_DKIM_ALLOW',
- check_local = false,
- check_authed = false,
- rules = {}
- }
-
- local E = {}
-
- local function whitelist_cb(symbol, rule, task)
-
- local domains = {}
-
- local function find_domain(dom, check)
- local mult
- local how = 'wl'
-
- -- Can be overridden
- if rule.blacklist then how = 'bl' end
-
- local function parse_val(val)
- local how_override
- -- Strict is 'special'
- if rule.strict then how_override = 'both' end
- if val then
- lua_util.debugm(N, task, "found whitelist key: %s=%s", dom, val)
- if val == '' then
- return (how_override or how),1.0
- elseif val:match('^bl:') then
- return (how_override or 'bl'),(tonumber(val:sub(4)) or 1.0)
- elseif val:match('^wl:') then
- return (how_override or 'wl'),(tonumber(val:sub(4)) or 1.0)
- elseif val:match('^both:') then
- return (how_override or 'both'),(tonumber(val:sub(6)) or 1.0)
- else
- return (how_override or how),(tonumber(val) or 1.0)
- end
- end
-
- return (how_override or how),1.0
- end
-
- if rule['map'] then
- local val = rule['map']:get_key(dom)
- if val then
- how,mult = parse_val(val)
-
- if not domains[check] then
- domains[check] = {}
- end
-
- domains[check] = {
- [dom] = {how, mult}
- }
-
- lua_util.debugm(N, task, "final result: %s: %s->%s",
- dom, how, mult)
- return true,mult,how
- end
- elseif rule['maps'] then
- for _,v in pairs(rule['maps']) do
- local map = v.map
- if map then
- local val = map:get_key(dom)
- if val then
- how,mult = parse_val(val)
-
- if not domains[check] then
- domains[check] = {}
- end
-
- domains[check] = {
- [dom] = {how, mult}
- }
-
- lua_util.debugm(N, task, "final result: %s: %s->%s",
- dom, how, mult)
- return true,mult,how
- end
- end
- end
- else
- mult = rule['domains'][dom]
- if mult then
- if not domains[check] then
- domains[check] = {}
- end
-
- domains[check] = {
- [dom] = {how, mult}
- }
-
- return true, mult,how
- end
- end
-
- return false,0.0,how
- end
-
- local spf_violated = false
- local dmarc_violated = false
- local dkim_violated = false
- local ip_addr = task:get_ip()
-
- if rule.valid_spf then
- if not task:has_symbol(options['spf_allow_symbol']) then
- -- Not whitelisted
- spf_violated = true
- end
- -- Now we can check from domain or helo
- local from = task:get_from(1)
-
- if ((from or E)[1] or E).domain then
- local tld = rspamd_util.get_tld(from[1]['domain'])
-
- if tld then
- find_domain(tld, 'spf')
- end
- else
- local helo = task:get_helo()
-
- if helo then
- local tld = rspamd_util.get_tld(helo)
-
- if tld then
- find_domain(tld, 'spf')
- end
- end
- end
- end
-
- if rule.valid_dkim then
- if task:has_symbol('DKIM_TRACE') then
- local sym = task:get_symbol('DKIM_TRACE')
- local dkim_opts = sym[1]['options']
- if dkim_opts then
- fun.each(function(val)
- if val[2] == '+' then
- local tld = rspamd_util.get_tld(val[1])
- find_domain(tld, 'dkim_success')
- elseif val[2] == '-' then
- local tld = rspamd_util.get_tld(val[1])
- find_domain(tld, 'dkim_fail')
- end
- end,
- fun.map(function(s)
- return lua_util.rspamd_str_split(s, ':')
- end, dkim_opts))
- end
- end
- end
-
- if rule.valid_dmarc then
- if not task:has_symbol(options.dmarc_allow_symbol) then
- dmarc_violated = true
- end
-
- local from = task:get_from(2)
-
- if ((from or E)[1] or E).domain then
- local tld = rspamd_util.get_tld(from[1]['domain'])
-
- if tld then
- local found = find_domain(tld, 'dmarc')
- if not found then
- find_domain(from[1]['domain'], 'dmarc')
- end
- end
- end
- end
-
-
- local final_mult = 1.0
- local found_wl, found_bl = false, false
- local opts = {}
-
- if rule.valid_dkim then
- dkim_violated = true
-
- for dom,val in pairs(domains.dkim_success or E) do
- if val[1] == 'wl' or val[1] == 'both' then
- -- We have valid and whitelisted signature
- table.insert(opts, dom .. ':d:+')
- found_wl = true
- dkim_violated = false
-
- if not found_bl then
- final_mult = val[2]
- end
- end
- end
-
- -- Blacklist counterpart
- for dom,val in pairs(domains.dkim_fail or E) do
- if val[1] == 'bl' or val[1] == 'both' then
- -- We have valid and whitelisted signature
- table.insert(opts, dom .. ':d:-')
- found_bl = true
- final_mult = val[2]
- else
- -- Even in the case of whitelisting we need to indicate dkim failure
- dkim_violated = true
- end
- end
- end
-
- local function check_domain_violation(what, dom, val, violated)
- if violated then
- if val[1] == 'both' or val[1] == 'bl' then
- found_bl = true
- final_mult = val[2]
- table.insert(opts, string.format("%s:%s:-", dom, what))
- end
- else
- if val[1] == 'both' or val[1] == 'wl' then
- found_wl = true
- table.insert(opts, string.format("%s:%s:+", dom, what))
- if not found_bl then
- final_mult = val[2]
- end
- end
- end
- end
-
- if rule.valid_dmarc then
-
- found_wl = false
-
- for dom,val in pairs(domains.dmarc or E) do
- check_domain_violation('D', dom, val,
- (dmarc_violated or dkim_violated))
- end
- end
-
- if rule.valid_spf then
- found_wl = false
-
- for dom,val in pairs(domains.spf or E) do
- check_domain_violation('s', dom, val,
- (spf_violated or dkim_violated))
- end
- end
-
- lua_util.debugm(N, task, "final mult: %s", final_mult)
-
- local function add_symbol(violated, mult)
- local sym = symbol
-
- if violated then
- if rule.inverse_symbol then
- sym = rule.inverse_symbol
- elseif not rule.blacklist then
- mult = -mult
- end
-
- if rule.inverse_multiplier then
- mult = mult * rule.inverse_multiplier
- end
-
- task:insert_result(sym, mult, opts)
- else
- task:insert_result(sym, mult, opts)
- end
- end
-
- if found_bl then
- if not ((not options.check_authed and task:get_user()) or
- (not options.check_local and ip_addr and ip_addr:is_local())) then
- add_symbol(true, final_mult)
- else
- if rule.valid_spf or rule.valid_dmarc then
- rspamd_logger.infox(task, "skip DMARC/SPF blacklists for local networks and/or authorized users")
- else
- add_symbol(true, final_mult)
- end
- end
- elseif found_wl then
- add_symbol(false, final_mult)
- end
-
- end
-
- local function gen_whitelist_cb(symbol, rule)
- return function(task)
- whitelist_cb(symbol, rule, task)
- end
- end
-
- local configure_whitelist_module = function()
- local opts = rspamd_config:get_all_opt('whitelist')
- if opts then
- for k,v in pairs(opts) do
- options[k] = v
- end
-
- local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N,
- false, false)
- options.check_local = auth_and_local_conf[1]
- options.check_authed = auth_and_local_conf[2]
- else
- rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
- return
- end
-
- if options['rules'] then
- fun.each(function(symbol, rule)
- if rule['domains'] then
- if type(rule['domains']) == 'string' then
- rule['map'] = rspamd_config:add_map{
- url = rule['domains'],
- description = "Whitelist map for " .. symbol,
- type = 'map'
- }
- elseif type(rule['domains']) == 'table' then
- -- Transform ['domain1', 'domain2' ...] to indexes:
- -- {'domain1' = 1, 'domain2' = 1 ...]
- local is_domains_list = fun.all(function(v)
- if type(v) == 'table' then
- return true
- elseif type(v) == 'string' and not (string.match(v, '^https?://') or
- string.match(v, '^ftp://') or string.match(v, '^[./]')) then
- return true
- end
-
- return false
- end, rule.domains)
-
- if is_domains_list then
- rule['domains'] = fun.tomap(fun.map(function(d)
- if type(d) == 'table' then
- return d[1],d[2]
- end
-
- return d,1.0
- end, rule['domains']))
- else
- rule['map'] = rspamd_config:add_map{
- url = rule['domains'],
- description = "Whitelist map for " .. symbol,
- type = 'map'
- }
- end
- else
- rspamd_logger.errx(rspamd_config, 'whitelist %s has bad "domains" value',
- symbol)
- return
- end
-
- local flags = 'nice,empty'
- if rule['blacklist'] then
- flags = 'empty'
- end
-
- local id = rspamd_config:register_symbol({
- name = symbol,
- flags = flags,
- callback = gen_whitelist_cb(symbol, rule),
- score = rule.score or 0,
- })
-
- if rule.inverse_symbol then
- rspamd_config:register_symbol({
- name = rule.inverse_symbol,
- type = 'virtual',
- parent = id,
- score = rule.score and -(rule.score) or 0,
- })
- end
-
- local spf_dep = false
- local dkim_dep = false
- if rule['valid_spf'] then
- rspamd_config:register_dependency(symbol, options['spf_allow_symbol'])
- spf_dep = true
- end
- if rule['valid_dkim'] then
- rspamd_config:register_dependency(symbol, options['dkim_allow_symbol'])
- dkim_dep = true
- end
- if rule['valid_dmarc'] then
- if not spf_dep then
- rspamd_config:register_dependency(symbol, options['spf_allow_symbol'])
- end
- if not dkim_dep then
- rspamd_config:register_dependency(symbol, options['dkim_allow_symbol'])
- end
- rspamd_config:register_dependency(symbol, 'DMARC_CALLBACK')
- end
-
- if rule['score'] then
- if not rule['group'] then
- rule['group'] = 'whitelist'
- end
- rule['name'] = symbol
- rspamd_config:set_metric_symbol(rule)
-
- if rule.inverse_symbol then
- local inv_rule = lua_util.shallowcopy(rule)
- inv_rule.name = rule.inverse_symbol
- inv_rule.score = -rule.score
- rspamd_config:set_metric_symbol(inv_rule)
- end
- end
- end
- end, options['rules'])
- else
- lua_util.disable_module(N, "config")
- end
- end
-
- configure_whitelist_module()
|