]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rbl: Major whitelisting logic rework
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 27 Aug 2019 15:53:35 +0000 (16:53 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 27 Aug 2019 15:54:45 +0000 (16:54 +0100)
src/plugins/lua/rbl.lua

index 054b5e45b28438b3f46905215665b5af77ed80c5..f38cf69814157fc47cd4dc2650b5e104bf261d8b 100644 (file)
@@ -40,6 +40,7 @@ local local_exclusions
 local white_symbols = {}
 local black_symbols = {}
 local monitored_addresses = {}
+local known_selectors = {} -- map from selector string to selector id
 
 local function get_monitored(rbl)
   local default_monitored = '1.0.0.127'
@@ -179,12 +180,24 @@ local function gen_check_rcvd_conditions(rbl, received_total)
   end
 end
 
-local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
+local function rbl_dns_process(task, rbl, to_resolve, results, err, resolve_table_elt)
+  local function make_option(ip)
+    if ip then
+      return string.format('%s:%s:%s',
+          resolve_table_elt.orig,
+          resolve_table_elt.what,
+          ip)
+    else
+      return string.format('%s:%s',
+          resolve_table_elt.orig,
+          resolve_table_elt.what)
+    end
+  end
   if err and (err ~= 'requested record is not found' and
       err ~= 'no records with this name') then
     rspamd_logger.infox(task, 'error looking up %s: %s', to_resolve, err)
     task:insert_result(rbl.symbol .. '_FAIL', 1, string.format('%s:%s',
-        orig, err))
+        resolve_table_elt.orig, err))
     return
   end
 
@@ -200,7 +213,7 @@ local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
   end
 
   if rbl.returncodes == nil and rbl.returnbits == nil and rbl.symbol ~= nil then
-    task:insert_result(rbl.symbol, 1, orig)
+    task:insert_result(rbl.symbol, 1, make_option())
     return
   end
 
@@ -215,7 +228,7 @@ local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
         for _,check_bit in ipairs(bits) do
           if bit.band(ipnum, check_bit) == check_bit then
             foundrc = true
-            task:insert_result(s, 1, orig .. ' : ' .. ipstr)
+            task:insert_result(s, 1, make_option())
             -- Here, we continue with other bits
           end
         end
@@ -225,7 +238,7 @@ local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
         for _,v in ipairs(codes) do
           if string.find(ipstr, '^' .. v .. '$') then
             foundrc = true
-            task:insert_result(s, 1, orig .. ' : ' .. ipstr)
+            task:insert_result(s, 1, make_option())
             break
           end
         end
@@ -234,7 +247,7 @@ local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
 
     if not foundrc then
       if rbl.unknown and rbl.symbol then
-        task:insert_result(rbl.symbol, 1, orig)
+        task:insert_result(rbl.symbol, 1, make_option(ipstr))
       else
         rspamd_logger.errx(task, 'RBL %1 returned unknown result: %2',
             rbl.rbl, ipstr)
@@ -245,20 +258,43 @@ local function rbl_dns_process(task, rbl, to_resolve, results, err, orig)
 end
 
 local function gen_rbl_callback(rule)
-
-  local function add_dns_request(task, req, forced, is_ip, requests_table)
+  local function is_whitelisted(task, req, req_str, whitelist, what)
     if rule.whitelist then
       if rule.whitelist:get_key(req) then
-        lua_util.debugm(N, task, 'whitelisted %s on %s', req, rule.symbol)
+        lua_util.debugm(N, task,
+            'whitelisted %s on %s',
+            req_str, rule.symbol)
 
-        return
+        return true
       end
     end
 
+    -- Maybe whitelisted by some other rbl rule
+    if whitelist then
+      local wl_what = whitelist[req_str]
+      if wl_what then
+        lua_util.debugm(N, task,
+            'whitelisted %s on %s by %s rbl rule (%s checked)',
+            req_str, wl_what, what)
+        return wl_what == what
+      end
+    end
+
+    return false
+  end
+
+  local function add_dns_request(task, req, forced, is_ip, requests_table, what, whitelist)
+    local req_str = req
     if is_ip then
-      req = ip_to_rbl(req)
+      req_str = ip_to_rbl(req)
+    end
+
+    if is_whitelisted(task, req, req_str, whitelist, what) then
+      return
     end
 
+    req = req_str
+
     if requests_table[req] then
       -- Duplicate request
       if forced and not requests_table[req].forced then
@@ -274,7 +310,8 @@ local function gen_rbl_callback(rule)
             forced = forced,
             n = processed,
             orig = req,
-            resolve_ip = resolve_ip
+            resolve_ip = resolve_ip,
+            what = what,
           }
           requests_table[req] = nreq
         end
@@ -296,7 +333,8 @@ local function gen_rbl_callback(rule)
           forced = forced,
           n = to_resolve,
           orig = orign,
-          is_ip = resolve_ip
+          is_ip = resolve_ip,
+          what = what,
         }
         requests_table[req] = nreq
       end
@@ -340,17 +378,18 @@ local function gen_rbl_callback(rule)
     return true
   end
 
-  local function check_helo(task, requests_table)
+  local function check_helo(task, requests_table, whitelist)
     local helo = task:get_helo()
 
     if not helo then
       return false
     end
 
-    add_dns_request(task, helo, true, false, requests_table)
+    add_dns_request(task, helo, true, false, requests_table,
+        'helo', whitelist)
   end
 
-  local function check_dkim(task, requests_table)
+  local function check_dkim(task, requests_table, whitelist)
     local das = task:get_symbol('DKIM_TRACE')
     local mime_from_domain
     local ret = false
@@ -380,16 +419,18 @@ local function gen_rbl_callback(rule)
             end
 
             if mime_from_domain and mime_from_domain == domain_tld then
-              add_dns_request(task, domain_tld, true, false, requests_table)
+              add_dns_request(task, domain_tld, true, false, requests_table,
+              'dkim', whitelist)
               ret = true
             end
           else
             if rule.dkim_domainonly then
               add_dns_request(task, rspamd_util.get_tld(domain),
-                  false, false, requests_table)
+                  false, false, requests_table, 'dkim', whitelist)
               ret = true
             else
-              add_dns_request(task, domain, false, false, requests_table)
+              add_dns_request(task, domain, false, false, requests_table,
+                  'dkim', whitelist)
               ret = true
             end
           end
@@ -400,7 +441,7 @@ local function gen_rbl_callback(rule)
     return ret
   end
 
-  local function check_emails(task, requests_table)
+  local function check_emails(task, requests_table, whitelist)
     local ex_params = {
       task = task,
       limit = rule.requests_limit,
@@ -421,16 +462,19 @@ local function gen_rbl_callback(rule)
 
     for _,email in ipairs(emails) do
       if rule.emails_domainonly then
-        add_dns_request(task, email:get_tld(), false, false, requests_table)
+        add_dns_request(task, email:get_tld(), false, false, requests_table,
+            'email', whitelist)
       else
         if rule.hash then
           -- Leave @ as is
           add_dns_request(task, string.format('%s@%s',
-              email:get_user(), email:get_host()), false, false, requests_table)
+              email:get_user(), email:get_host()), false, false,
+              requests_table, 'email', whitelist)
         else
           -- Replace @ with .
           add_dns_request(task, string.format('%s.%s',
-              email:get_user(), email:get_host()), false, false, requests_table)
+              email:get_user(), email:get_host()), false, false,
+              requests_table, 'email', whitelist)
         end
       end
     end
@@ -438,7 +482,7 @@ local function gen_rbl_callback(rule)
     return true
   end
 
-  local function check_urls(task, requests_table)
+  local function check_urls(task, requests_table, whitelist)
     local ex_params = {
       task = task,
       limit = rule.requests_limit,
@@ -455,13 +499,14 @@ local function gen_rbl_callback(rule)
     end
 
     for _,u in ipairs(urls) do
-      add_dns_request(task, u:get_tld(), false, false, requests_table)
+      add_dns_request(task, u:get_tld(), false,
+          false, requests_table, 'url', whitelist)
     end
 
     return true
   end
 
-  local function check_from(task, requests_table)
+  local function check_from(task, requests_table, whitelist)
     local ip = task:get_from_ip()
 
     if not ip or not ip:is_valid() then
@@ -469,13 +514,15 @@ local function gen_rbl_callback(rule)
     end
     if (ip:get_version() == 6 and rule.ipv6) or
         (ip:get_version() == 4 and rule.ipv4) then
-      add_dns_request(task, ip, true, true, requests_table)
+      add_dns_request(task, ip, true, true,
+          requests_table, 'from_ip',
+          whitelist)
     end
 
     return true
   end
 
-  local function check_received(task, requests_table)
+  local function check_received(task, requests_table, whitelist)
     local received = fun.filter(function(h)
       return not h['flags']['artificial']
     end, task:get_received_headers()):totable()
@@ -485,30 +532,34 @@ local function gen_rbl_callback(rule)
 
     for pos,rh in ipairs(received) do
       if check_conditions(rh, pos) then
-        add_dns_request(task, rh.real_ip, false, true, requests_table)
+        add_dns_request(task, rh.real_ip, false, true,
+            requests_table, 'rcvd',
+            whitelist)
       end
     end
 
     return true
   end
 
-  local function check_rdns(task, requests_table)
+  local function check_rdns(task, requests_table, whitelist)
     local hostname = task:get_hostname()
     if hostname == nil or hostname == 'unknown' then
       return false
     end
 
-    add_dns_request(task, hostname, true, false, requests_table)
+    add_dns_request(task, hostname, true, false,
+        requests_table, 'rdns', whitelist)
 
     return true
   end
 
-  local function check_selector(task, requests_table)
+  local function check_selector(task, requests_table, whitelist)
     local res = rule.selector(task)
 
     if res then
       for _,r in ipairs(res) do
-        add_dns_request(task, r, false, false, requests_table)
+        add_dns_request(task, r, false, false, requests_table,
+            'sel' .. rule.selector_id, whitelist)
       end
     end
   end
@@ -561,17 +612,19 @@ local function gen_rbl_callback(rule)
   return function(task)
     -- DNS requests to issue (might be hashed afterwards)
     local dns_req = {}
+    local whitelist = task:cache_get('rbl_whitelisted') or {}
 
-    local function gen_rbl_dns_callback(orig)
+    local function gen_rbl_dns_callback(resolve_table_elt)
       return function(_, to_resolve, results, err)
-        rbl_dns_process(task, rule, to_resolve, results, err, orig)
+        rbl_dns_process(task, rule, to_resolve, results, err, resolve_table_elt)
       end
     end
 
     -- Execute functions pipeline
     for _,f in ipairs(pipeline) do
-      if not f(task, dns_req) then
-        lua_util.debugm(N, task, "skip rbl check: %s; pipeline condition returned false",
+      if not f(task, dns_req, whitelist) then
+        lua_util.debugm(N, task,
+            "skip rbl check: %s; pipeline condition returned false",
             rule.symbol)
         return
       end
@@ -584,7 +637,7 @@ local function gen_rbl_callback(rule)
     local nresolved = 0
 
     -- This is called when doing resolve_ip phase...
-    local function gen_rbl_ip_dns_callback(orig)
+    local function gen_rbl_ip_dns_callback(_)
       return function(_, _, results, err)
         if not err then
           for _,dns_res in ipairs(results) do
@@ -608,7 +661,7 @@ local function gen_rbl_callback(rule)
               r:resolve_a({
                 task = task,
                 name = req.n,
-                callback = gen_rbl_dns_callback(orig),
+                callback = gen_rbl_dns_callback(req),
                 forced = req.forced
               })
             else
@@ -639,7 +692,7 @@ local function gen_rbl_callback(rule)
           if r:resolve('aaaa', {
             task = task,
             name = req.n,
-            callback = gen_rbl_ip_dns_callback(req.orig),
+            callback = gen_rbl_ip_dns_callback(req),
             forced = req.forced
           }) then
             nresolved = nresolved + 1
@@ -648,7 +701,7 @@ local function gen_rbl_callback(rule)
           r:resolve_a({
             task = task,
             name = req.n,
-            callback = gen_rbl_dns_callback(req.orig),
+            callback = gen_rbl_dns_callback(req),
             forced = req.forced
           })
         end
@@ -678,15 +731,28 @@ local function add_rbl(key, rbl)
   end
 
   if rbl.selector then
-    -- Create a flattened closure
-    local sel = selectors.create_selector_closure(rspamd_config, rbl.selector, '', true)
+    if known_selectors[rbl.selector] then
+      lua_util.debugm(N, rspamd_config, 'reuse selector id %s',
+          known_selectors[rbl.selector].id)
+      rbl.selector = known_selectors[rbl.selector].selector
+      rbl.selector_id = known_selectors[rbl.selector].id
+    else
+      -- Create a new flattened closure
+      local sel = selectors.create_selector_closure(rspamd_config, rbl.selector, '', true)
 
-    if not sel then
-      rspamd_logger.errx('invalid selector for rbl rule %s: %s', key, rbl.selector)
-      return false
+      if not sel then
+        rspamd_logger.errx('invalid selector for rbl rule %s: %s', key, rbl.selector)
+        return false
+      end
+
+      rbl.selector = sel
+      known_selectors[rbl.selector] = {
+        selector = sel,
+        id = #lua_util.keys(known_selectors) + 1,
+      }
+      rbl.selector_id = known_selectors[rbl.selector].id
     end
 
-    rbl.selector = sel
   end
 
   if rbl.process_script then
@@ -896,22 +962,23 @@ end
 -- * RBL_CALLBACK_WHITE that depends on all symbols white
 -- * RBL_CALLBACK that depends on all symbols black to participate in depends chains
 local function rbl_callback_white(task)
-  local found_whitelist = false
+  local whitelisted_elements = {}
   for _, w in ipairs(white_symbols) do
-    if task:has_symbol(w) then
+    local ws = task:get_symbol(w)
+    if ws then
       lua_util.debugm(N, task,'found whitelist %s', w)
-      found_whitelist = true
-      break
+      if not ws.options then ws.options = {} end
+      for _,opt in ipairs(ws.options) do
+        local elt,what = opt:match('^([^:]+):([^:]+)')
+        if elt and what then
+          whitelisted_elements[elt] = what
+        end
+      end
     end
   end
 
-  if found_whitelist then
-    -- Disable all symbols black
-    for _, b in ipairs(black_symbols) do
-      lua_util.debugm(N, task,'disable %s, whitelist found', b)
-      task:disable_symbol(b)
-    end
-  end
+  task:cache_set('rbl_whitelisted', whitelisted_elements)
+
   lua_util.debugm(N, task, "finished rbl whitelists processing")
 end