]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Settings: Rework settings check
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 11 Feb 2019 13:18:33 +0000 (13:18 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 11 Feb 2019 13:18:33 +0000 (13:18 +0000)
src/plugins/lua/settings.lua

index 26c4ac02b2b1d2086f10e3d060b9c283e6b57eed..6cea011923d5be4f6e8fde47c4f31d46fc082033 100644 (file)
@@ -199,126 +199,112 @@ local function check_settings(task)
     return false
   end
 
-  local function check_specific_setting(rule_name, rule, ip, client_ip, from, rcpt,
-      user, auth_user, hostname, matched)
+  local function check_specific_setting(rule_name, rule, data, matched)
     local res = false
 
-    if rule.authenticated then
-      if auth_user then
-        res = true
-        matched[#matched + 1] = 'authenticated'
-      end
-      if not res then
-        return nil
-      end
+    local function ip_valid(ip)
+      return ip:is_valid()
     end
 
-    if rule['local'] then
-      if not ip or not ip:is_valid() then
-        return nil
+    local function not_empty(s)
+      return #s > 0
+    end
+
+    local function generic_check(value, to_check, check_func, what, valid_func)
+      if not to_check then return true end
+
+      if type(value) == 'function' then
+        value = value()
       end
 
-      if ip:is_local() then
-        matched[#matched + 1] = 'local'
-        res = true
+      if value then
+        if valid_func then
+          if not valid_func(value) then
+            return false
+          end
+        end
+
+        if not check_func then
+          check_func = function(a, b) return a == b end
+        end
+
+        local ret = fun.any(function(d)
+          return check_func(value, d)
+        end, to_check)
+        if ret then
+          res = true
+          matched[#matched + 1] = what
+        else
+          return false
+        end
       else
-        return nil
+        return false
       end
+
+      return true
     end
 
-    if rule.ip then
-      if not ip or not ip:is_valid() then
-        return nil
-      end
-      for _, ip_check in ipairs(rule.ip) do
-        res = check_ip_setting(ip_check, ip)
-        if res then
-          matched[#matched + 1] = 'ip'
-          break
-        end
-      end
-      if not res then
-        return nil
-      end
+    if not generic_check(data.ip, rule.ip,
+        check_ip_setting, 'ip', ip_valid) then
+      return nil
     end
 
-    if rule.client_ip then
-      if not client_ip or not client_ip:is_valid() then
-        return nil
-      end
-      for _, ip_check in ipairs(rule.client_ip) do
-        res = check_ip_setting(ip_check, client_ip)
-        if res then
-          matched[#matched + 1] = 'client_ip'
-          break
-        end
-      end
-      if not res then
-        return nil
-      end
+    if not generic_check(data.client_ip, rule.client_ip,
+        check_ip_setting, 'client_ip', ip_valid) then
+      return nil
     end
 
-    if rule.from then
-      if not from then
-        return nil
-      end
-      for _, from_check in ipairs(rule.from) do
-        res = check_addr_setting(from_check, from)
-        if res then
-          matched[#matched + 1] = 'from'
-          break
-        end
-      end
-      if not res then
-        return nil
-      end
+    if not generic_check(data.from, rule.from,
+        check_addr_setting, 'from') then
+      return nil
     end
 
-    if rule.rcpt then
-      if not rcpt then
-        return nil
-      end
-      for _, rcpt_check in ipairs(rule.rcpt) do
-        res = check_addr_setting(rcpt_check, rcpt)
+    if not generic_check(data.from_mime, rule.from_mime,
+        check_addr_setting, 'from_mime') then
+      return nil
+    end
 
-        if res then
-          matched[#matched + 1] = 'rcpt'
-          break
-        end
-      end
-      if not res then
-        return nil
-      end
+    if not generic_check(data.rcpt, rule.rcpt,
+        check_addr_setting, 'rcpt') then
+      return nil
     end
 
-    if rule.user then
-      if not user then
-        return nil
-      end
-      for _, user_check in ipairs(rule.user) do
-        res = check_addr_setting(user_check, user)
-        if res then
-          matched[#matched + 1] = 'user'
-          break
-        end
+    if not generic_check(data.rcpt_mime, rule.rcpt_mime,
+        check_addr_setting, 'rcpt_mime') then
+      return nil
+    end
+
+    if not generic_check(data.user, rule.user,
+        check_addr_setting, 'user') then
+      return nil
+    end
+
+    if not generic_check(data.hostname, rule.hostname,
+        check_addr_setting, 'hostname', not_empty) then
+      return nil
+    end
+
+    -- Non generic checks
+
+    if rule.authenticated then
+      if data.user[1] then
+        res = true
+        matched[#matched + 1] = 'authenticated'
       end
       if not res then
         return nil
       end
     end
 
-    if rule.hostname then
-      if #hostname == 0 then
+    if rule['local'] then
+      if not data.ip or not data.ip:is_valid() then
         return nil
       end
-      for _, hname_check in ipairs(rule.hostname) do
-        res = check_addr_setting(hname_check, hostname)
-        if res then
-          matched[#matched + 1] = 'hostname'
-          break
-        end
-      end
-      if not res then
+
+      if data.ip:is_local() then
+        matched[#matched + 1] = 'local'
+        res = true
+      else
         return nil
       end
     end
@@ -391,23 +377,28 @@ local function check_settings(task)
   end
 
   lua_util.debugm(N, task, "check for settings")
-  local ip = task:get_from_ip()
-  local client_ip = task:get_client_ip()
-  local from = task:get_from()
-  local rcpt = task:get_recipients()
+  local data = {
+    ip = task:get_from_ip(),
+    client_ip = task:get_client_ip(),
+    from = task:get_from(1),
+    from_mime = task:get_from(2),
+    rcpt = task:get_recipients(1),
+    rcpt_mime = task:get_recipients(2),
+    hostname = task:get_hostname() or '',
+    user = {}
+  }
+
   local uname = task:get_user()
-  local hostname = task:get_hostname() or ''
-  local user = {}
   if uname then
-    user[1] = {}
+    data.user[1] = {}
     local localpart, domainpart = string.gmatch(uname, "(.+)@(.+)")()
     if localpart then
-      user[1]["user"] = localpart
-      user[1]["domain"] = domainpart
-      user[1]["addr"] = uname
+      data.user[1]["user"] = localpart
+      data.user[1]["domain"] = domainpart
+      data.user[1]["addr"] = uname
     else
-      user[1]["user"] = uname
-      user[1]["addr"] = uname
+      data.user[1]["user"] = uname
+      data.user[1]["addr"] = uname
     end
   end
   -- Match rules according their order
@@ -417,11 +408,10 @@ local function check_settings(task)
     if not applied and settings[pri] then
       for _,s in ipairs(settings[pri]) do
         local matched = {}
-        local rule = check_specific_setting(s.name, s.rule,
-            ip, client_ip, from, rcpt, user, uname, hostname, matched)
+        local result = check_specific_setting(s.name, s.rule, data, matched)
 
         -- Can use xor here but more complicated for reading
-        if (rule and not s.rule.inverse) or (not rule and s.rule.inverse) then
+        if (result and not s.rule.inverse) or (not result and s.rule.inverse) then
           rspamd_logger.infox(task, "<%s> apply settings according to rule %s (%s matched)",
             task:get_message_id(), s.name, table.concat(matched, ','))
           if s.rule['apply'] then