]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rework and fix whitelist plugin
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 1 Oct 2018 17:16:02 +0000 (18:16 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 1 Oct 2018 17:16:02 +0000 (18:16 +0100)
src/plugins/lua/whitelist.lua

index f810433c435d549fc2da6f93cd5fd9330fbf20ab..7e6052d2416bb0340052ee385e081bf9561935e2 100644 (file)
@@ -39,34 +39,50 @@ local function whitelist_cb(symbol, rule, task)
 
   local domains = {}
 
-  local function find_domain(dom)
+  local function find_domain(dom, check)
     local mult
     local how = 'wl'
+
+    -- Can be overriden
     if rule.blacklist then how = 'bl' end
 
     local function parse_val(val)
+      local how_override
+      -- Strict is 'special'
+      if rule.strict then how_override = 'both' end
       if val then
+        lua_util.debugm(N, task, "found whitelist key: %s=%s", dom, val)
         if val == '' then
-          return how,1.0
+          return (how_override or how),1.0
         elseif val:match('^bl:') then
-          return 'bl',(tonumber(val:sub(4)) or 1.0)
+          return (how_override or 'bl'),(tonumber(val:sub(4)) or 1.0)
         elseif val:match('^wl:') then
-          return 'wl',(tonumber(val:sub(4)) or 1.0)
+          return (how_override or 'wl'),(tonumber(val:sub(4)) or 1.0)
         elseif val:match('^both:') then
-          return 'both',(tonumber(val:sub(6)) or 1.0)
+          return (how_override or 'both'),(tonumber(val:sub(6)) or 1.0)
         else
-          return how,(tonumber(val) or 1.0)
+          return (how_override or how),(tonumber(val) or 1.0)
         end
       end
 
-      return how,1.0
+      return (how_override or how),1.0
     end
 
     if rule['map'] then
       local val = rule['map']:get_key(dom)
       if val then
         how,mult = parse_val(val)
-        table.insert(domains, dom)
+
+        if not domains[check] then
+          domains[check] = {}
+        end
+
+        domains[check] = {
+          [dom] = {how, mult}
+        }
+
+        lua_util.debugm(N, task, "final result: %s: %s->%s",
+            dom, how, mult)
         return true,mult,how
       end
     elseif rule['maps'] then
@@ -76,7 +92,17 @@ local function whitelist_cb(symbol, rule, task)
           local val = map:get_key(dom)
           if val then
             how,mult = parse_val(val)
-            table.insert(domains, dom)
+
+            if not domains[check] then
+              domains[check] = {}
+            end
+
+            domains[check] = {
+              [dom] = {how, mult}
+            }
+
+            lua_util.debugm(N, task, "final result: %s: %s->%s",
+                dom, how, mult)
             return true,mult,how
           end
         end
@@ -84,7 +110,14 @@ local function whitelist_cb(symbol, rule, task)
     else
       mult = rule['domains'][dom]
       if mult then
-        table.insert(domains, dom)
+        if not domains[check] then
+          domains[check] = {}
+        end
+
+        domains[check] = {
+          [dom] = {how, mult}
+        }
+
         return true, mult,how
       end
     end
@@ -92,23 +125,14 @@ local function whitelist_cb(symbol, rule, task)
     return false,0.0,how
   end
 
-  local found = false
-  local mult
-  local how = 'wl' -- whitelist only
   local spf_violated = false
-  local dkim_violated = false
   local dmarc_violated = false
 
   if rule['valid_spf'] then
     if not task:has_symbol(options['spf_allow_symbol']) then
       -- Not whitelisted
-      if not rule['blacklist'] and not rule['strict'] then
-        return
-      end
-
       spf_violated = true
     end
-
     -- Now we can check from domain or helo
     local from = task:get_from(1)
 
@@ -116,7 +140,7 @@ local function whitelist_cb(symbol, rule, task)
       local tld = rspamd_util.get_tld(from[1]['domain'])
 
       if tld then
-        found, mult, how = find_domain(tld)
+        find_domain(tld, 'spf')
       end
     else
       local helo = task:get_helo()
@@ -125,46 +149,33 @@ local function whitelist_cb(symbol, rule, task)
         local tld = rspamd_util.get_tld(helo)
 
         if tld then
-          found, mult, how = find_domain(tld)
+          find_domain(tld)
         end
       end
     end
   end
 
   if rule['valid_dkim'] then
-    local sym = task:get_symbol(options['dkim_allow_symbol'])
-    if not sym then
-      if not rule['blacklist'] and not rule['strict'] then
-        return
-      end
-
-      dkim_violated = true
-    else
-      found = false
+    if task:has_symbol('DKIM_TRACE') then
+      local sym = task:get_symbol('DKIM_TRACE')
       local dkim_opts = sym[1]['options']
       if dkim_opts then
         fun.each(function(val)
-          if not found then
-            local tld = rspamd_util.get_tld(val)
-
-            if tld then
-              found, mult, how = find_domain(tld)
-              if not found then
-                found, mult, how = find_domain(val)
-              end
+            if val[2] == '+' then
+              find_domain(val[1], 'dkim_success')
+            elseif val[2] == '-' then
+              find_domain(val[1], 'dkim_fail')
             end
-          end
-        end, dkim_opts)
+          end,
+            fun.map(function(s)
+              return lua_util.rspamd_str_split(s, ':')
+            end, dkim_opts))
       end
     end
   end
 
   if rule['valid_dmarc'] then
     if not task:has_symbol(options['dmarc_allow_symbol']) then
-      if not rule['blacklist'] and not rule['strict'] then
-        return
-      end
-
       dmarc_violated = true
     end
 
@@ -174,40 +185,107 @@ local function whitelist_cb(symbol, rule, task)
       local tld = rspamd_util.get_tld(from[1]['domain'])
 
       if tld then
-        found, mult, how = find_domain(tld)
+        local found = find_domain(tld, 'dmarc')
         if not found then
-          found, mult, how = find_domain(from[1]['domain'])
+          find_domain(from[1]['domain'], 'dmarc')
         end
       end
     end
   end
 
-  if found then
-    local function add_symbol(violated)
-      local sym = symbol
 
-      if violated then
-        if rule.inverse_symbol then
-          sym = rule.inverse_symbol
-        else
-          -- Inverse multiplier
-          if not rule.blacklist then
-            mult = -mult
-          end
-        end
+  local final_mult = 1.0
+  local found_wl, found_bl = false, false
+  local opts = {}
+
+  if rule.valid_dkim then
+    for dom,val in pairs(domains.dkim_success or E) do
+      if val[1] == 'wl' or val[1] == 'both' then
+        -- We have valid and whitelisted signature
+        table.insert(opts, dom .. ':d:+')
+        found_wl = true
 
-        if rule.strict or how == 'bl' or how == 'both' then
-          -- Insert violation rule
-          task:insert_result(sym, mult, domains)
+        if not found_bl then
+          final_mult = val[2]
+          lua_util.debugm(N, task, "hui4 final mult: %s", final_mult)
         end
-      else
-        if how == 'wl' or how == 'both' then
-          task:insert_result(sym, mult, domains)
+      end
+    end
+
+    -- Blacklist counterpart
+    for dom,val in pairs(domains.dkim_fail or E) do
+      if val[1] == 'bl' or val[1] == 'both' then
+        -- We have valid and whitelisted signature
+        table.insert(opts, dom .. ':d:-')
+        found_bl = true
+        final_mult = val[2]
+        lua_util.debugm(N, task, "hui2 final mult: %s", final_mult)
+      end
+    end
+  end
+
+  local function check_domain_violation(what, dom, val, violated)
+    if violated then
+      if val[1] == 'both' or val[1] == 'bl' then
+        found_bl = true
+        final_mult = val[2]
+        lua_util.debugm(N, task, "hui3 final mult: %s", final_mult)
+        table.insert(opts, string.format("%s:%s:-", dom, what))
+      end
+    else
+      if val[1] == 'both' or val[1] == 'wl' then
+        found_wl = true
+        table.insert(opts, string.format("%s:%s:+", dom, what))
+        if not found_bl then
+          final_mult = val[2]
+          lua_util.debugm(N, task, "hui1 final mult: %s", final_mult)
         end
       end
     end
+  end
+
+  if rule.valid_dmarc then
+    found_wl = false
 
-    add_symbol(dmarc_violated or dkim_violated or spf_violated)
+    for dom,val in pairs(domains.dmarc or E) do
+      check_domain_violation('D', dom, val, dmarc_violated)
+    end
+  end
+
+  if rule.valid_spf then
+    found_wl = false
+
+    for dom,val in pairs(domains.spf or E) do
+      check_domain_violation('s', dom, val, spf_violated)
+    end
+  end
+
+  lua_util.debugm(N, task, "final mult: %s", final_mult)
+
+  local function add_symbol(violated, mult)
+    local sym = symbol
+
+    if violated then
+      if rule.inverse_symbol then
+        sym = rule.inverse_symbol
+      elseif not rule.blacklist then
+        mult = -mult
+      end
+
+      if rule.inverse_multiplier then
+        mult = mult * rule.inverse_multiplier
+      end
+
+      task:insert_result(sym, mult, opts)
+    else
+      task:insert_result(sym, mult, opts)
+    end
+  end
+
+  if found_bl then
+    add_symbol(true, final_mult)
+  elseif found_wl then
+    add_symbol(false, final_mult)
   end
 
 end
@@ -318,6 +396,13 @@ local configure_whitelist_module = function()
           end
           rule['name'] = symbol
           rspamd_config:set_metric_symbol(rule)
+
+          if rule.inverse_symbol then
+            local inv_rule = lua_util.shallowcopy(rule)
+            inv_rule.name = rule.inverse_symbol
+            inv_rule.score = -rule.score
+            rspamd_config:set_metric_symbol(inv_rule)
+          end
         end
       end
     end, options['rules'])