]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Rbl: Add resolve_ip based RBLs
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Aug 2019 15:46:08 +0000 (16:46 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Aug 2019 15:46:08 +0000 (16:46 +0100)
src/plugins/lua/rbl.lua

index 35101efb41e04cfacfc0312ba3d3e6ed75a7f42d..7abe163b10c7d5043df667511ebab4bc11dfd1ec 100644 (file)
@@ -240,32 +240,45 @@ end
 
 local function gen_rbl_callback(rule)
 
-  local function add_dns_request(task, req, forced, requests_table)
+  local function add_dns_request(task, req, forced, is_ip, requests_table)
     if requests_table[req] then
       -- Duplicate request
       if forced and not requests_table[req].forced then
         requests_table[req].forced = true
       end
     else
+      local resolve_ip = rule.resolve_ip and not is_ip
       if rule.process_script then
-        local proc = rule.process_script(req, rule.rbl, task)
+        local processed = rule.process_script(req, rule.rbl, task, resolve_ip)
 
-        if proc then
+        if processed then
           local nreq = {
             forced = forced,
-            n = proc,
-            orig = req
+            n = processed,
+            orig = req,
+            resolve_ip = resolve_ip
           }
           requests_table[req] = nreq
         end
       else
-        local orign = maybe_make_hash(req, rule)
+        local to_resolve
+        local orign = req
+
+        if not resolve_ip then
+          orign = maybe_make_hash(req, rule)
+          to_resolve = string.format('%s.%s',
+              orign,
+              rule.rbl)
+        else
+          -- First, resolve origin stuff without hashing or anything
+          to_resolve = orign
+        end
+
         local nreq = {
           forced = forced,
-          n = string.format('%s.%s',
-              orign,
-              rule.rbl),
-          orig = orign
+          n = to_resolve,
+          orig = orign,
+          is_ip = resolve_ip
         }
         requests_table[req] = nreq
       end
@@ -316,7 +329,7 @@ local function gen_rbl_callback(rule)
       return false
     end
 
-    add_dns_request(task, helo, true, requests_table)
+    add_dns_request(task, helo, true, false, requests_table)
   end
 
   local function check_dkim(task, requests_table)
@@ -349,16 +362,16 @@ local function gen_rbl_callback(rule)
             end
 
             if mime_from_domain and mime_from_domain == domain_tld then
-              add_dns_request(task, domain_tld, true, requests_table)
+              add_dns_request(task, domain_tld, true, false, requests_table)
               ret = true
             end
           else
             if rule.dkim_domainonly then
               add_dns_request(task, rspamd_util.get_tld(domain),
-                  false, requests_table)
+                  false, false, requests_table)
               ret = true
             else
-              add_dns_request(task, domain, false, requests_table)
+              add_dns_request(task, domain, false, false, requests_table)
               ret = true
             end
           end
@@ -378,16 +391,16 @@ local function gen_rbl_callback(rule)
 
     for _,email in ipairs(emails) do
       if rule.emails_domainonly then
-        add_dns_request(task, email:get_tld(), false, requests_table)
+        add_dns_request(task, email:get_tld(), false, false, requests_table)
       else
         if rule.hash then
           -- Leave @ as is
           add_dns_request(task, string.format('%s@%s',
-              email:get_user(), email:get_host()), false, requests_table)
+              email:get_user(), email:get_host()), false, false, requests_table)
         else
           -- Replace @ with .
           add_dns_request(task, string.format('%s.%s',
-              email:get_user(), email:get_host()), false, requests_table)
+              email:get_user(), email:get_host()), false, false, requests_table)
         end
       end
     end
@@ -403,7 +416,7 @@ local function gen_rbl_callback(rule)
     end
     if (ip:get_version() == 6 and rule.ipv6) or
         (ip:get_version() == 4 and rule.ipv4) then
-      add_dns_request(task, ip_to_rbl(ip), true, requests_table)
+      add_dns_request(task, ip_to_rbl(ip), true, true, requests_table)
     end
 
     return true
@@ -419,7 +432,7 @@ local function gen_rbl_callback(rule)
 
     for pos,rh in ipairs(received) do
       if check_conditions(rh, pos) then
-        add_dns_request(task, ip_to_rbl(rh.real_ip), false, requests_table)
+        add_dns_request(task, ip_to_rbl(rh.real_ip), false, true, requests_table)
       end
     end
 
@@ -432,7 +445,7 @@ local function gen_rbl_callback(rule)
       return false
     end
 
-    add_dns_request(task, hostname, true, requests_table)
+    add_dns_request(task, hostname, true, false, requests_table)
 
     return true
   end
@@ -442,7 +455,7 @@ local function gen_rbl_callback(rule)
 
     if res then
       for _,r in ipairs(res) do
-        add_dns_request(task, r, false, requests_table)
+        add_dns_request(task, r, false, false, requests_table)
       end
     end
   end
@@ -509,19 +522,83 @@ local function gen_rbl_callback(rule)
 
     -- Now check all DNS requests pending and emit them
     local r = task:get_resolver()
-    for name,p in pairs(dns_req) do
-      if validate_dns(p.n) then
+    -- Used for 2 passes ip resolution
+    local resolved_req = {}
+    local nresolved = 0
+
+    -- This is called when doing resolve_ip phase...
+    local function gen_rbl_ip_dns_callback(orig)
+      return function(_, _, results, err)
+        if not err then
+          for _,dns_res in ipairs(results) do
+            -- Check if we have rspamd{ip} userdata
+            if type(dns_res) == 'userdata' then
+              -- Add result as an actual RBL request
+              add_dns_request(task, ip_to_rbl(dns_res), false, true,
+                  resolved_req)
+            end
+          end
+        end
+
+        nresolved = nresolved - 1
+
+        if nresolved == 0 then
+          -- Emit real RBL requests as there are no ip resolution requests
+          for name, req in pairs(resolved_req) do
+            if validate_dns(req.n) then
+              lua_util.debugm(N, task, "rbl %s; resolve %s -> %s",
+                  rule.symbol, name, req.n)
+              r:resolve_a({
+                task = task,
+                name = req.n,
+                callback = gen_rbl_dns_callback(orig),
+                forced = req.forced
+              })
+            else
+              rspamd_logger.warnx(task, 'cannot send invalid DNS request %s for %s',
+                  req.n, rule.symbol)
+            end
+          end
+        end
+      end
+    end
+
+    for name, req in pairs(dns_req) do
+      if validate_dns(req.n) then
         lua_util.debugm(N, task, "rbl %s; resolve %s -> %s",
-            rule.symbol, name, p.n)
-        r:resolve_a({
-          task = task,
-          name = p.n,
-          callback = gen_rbl_dns_callback(p.orig),
-          forced = p.forced
-        })
+            rule.symbol, name, req.n)
+
+        if req.resolve_ip then
+          -- Deal with both ipv4 and ipv6
+          -- Resolve names first
+          if r:resolve_a({
+            task = task,
+            name = req.n,
+            callback = gen_rbl_ip_dns_callback(req.orig),
+            forced = req.forced
+          }) then
+            nresolved = nresolved + 1
+          end
+          if r:resolve('aaaa', {
+            task = task,
+            name = req.n,
+            callback = gen_rbl_ip_dns_callback(req.orig),
+            forced = req.forced
+          }) then
+            nresolved = nresolved + 1
+          end
+        else
+          r:resolve_a({
+            task = task,
+            name = req.n,
+            callback = gen_rbl_dns_callback(req.orig),
+            forced = req.forced
+          })
+        end
+
       else
         rspamd_logger.warnx(task, 'cannot send invalid DNS request %s for %s',
-            p.n, rule.symbol)
+            req.n, rule.symbol)
       end
     end
   end
@@ -674,6 +751,7 @@ local default_options = {
   ['default_exclude_local'] = true,
   ['default_is_whitelist'] = false,
   ['default_ignore_whitelist'] = false,
+  ['default_resolve_ip'] = false,
 }
 
 opts = lua_util.override_defaults(default_options, opts)
@@ -721,6 +799,7 @@ local rule_schema = ts.shape({
   requests_limit = (ts.integer + ts.string / tonumber):is_optional(),
   process_script = ts.string:is_optional(),
 }, {
+  -- Covers boolean defaults
   extra_fields = ts.map_of(ts.string, ts.boolean)
 })