]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Settings: add ip_map check and rework structure slightly
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 12 Jan 2021 13:54:58 +0000 (13:54 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 12 Jan 2021 13:54:58 +0000 (13:54 +0000)
src/plugins/lua/settings.lua

index 56a9f5079f9ed4bf93ca0b468638fb434f1ea522..bb0ec6f76d1f3d029cc355f7fb8a456caa20c55a 100644 (file)
@@ -23,7 +23,7 @@ end
 -- https://rspamd.com/doc/configuration/settings.html
 
 local rspamd_logger = require "rspamd_logger"
-local rspamd_maps = require "lua_maps"
+local lua_maps = require "lua_maps"
 local lua_util = require "lua_util"
 local rspamd_ip = require "rspamd_ip"
 local rspamd_regexp = require "rspamd_regexp"
@@ -189,17 +189,17 @@ end
 local function check_addr_setting(expected, addr)
   local function check_specific_addr(elt)
     if expected.name then
-      if rspamd_maps.rspamd_maybe_check_map(expected.name, elt.addr) then
+      if lua_maps.rspamd_maybe_check_map(expected.name, elt.addr) then
         return true
       end
     end
     if expected.user then
-      if rspamd_maps.rspamd_maybe_check_map(expected.user, elt.user) then
+      if lua_maps.rspamd_maybe_check_map(expected.user, elt.user) then
         return true
       end
     end
     if expected.domain and elt.domain then
-      if rspamd_maps.rspamd_maybe_check_map(expected.domain, elt.domain) then
+      if lua_maps.rspamd_maybe_check_map(expected.domain, elt.domain) then
         return true
       end
     end
@@ -226,7 +226,7 @@ local function check_string_setting(expected, str)
       return true
     end
   elseif expected.check then
-    if rspamd_maps.rspamd_maybe_check_map(expected.check, str) then
+    if lua_maps.rspamd_maybe_check_map(expected.check, str) then
       return true
     end
   end
@@ -235,7 +235,7 @@ end
 
 local function check_ip_setting(expected, ip)
   if not expected[2] then
-    if rspamd_maps.rspamd_maybe_check_map(expected[1], ip:to_string()) then
+    if lua_maps.rspamd_maybe_check_map(expected[1], ip:to_string()) then
       return true
     end
   else
@@ -252,6 +252,10 @@ local function check_ip_setting(expected, ip)
   return false
 end
 
+local function check_map_setting(map, input)
+  return map:get_key(input)
+end
+
 local function priority_to_string(pri)
   if pri then
     if pri >= 3 then
@@ -393,199 +397,206 @@ local function check_settings(task)
 
 end
 
--- Process settings based on their priority
-local function process_settings_table(tbl, allow_ids, is_static)
-  local get_priority = function(elt)
-    local pri_tonum = function(p)
-      if p then
-        if type(p) == "number" then
-          return tonumber(p)
-        elseif type(p) == "string" then
-          if p == "high" then
-            return 3
-          elseif p == "medium" then
-            return 2
-          end
-
-        end
-
-      end
-
-      return 1
-    end
-
-    return pri_tonum(elt['priority'])
+local function convert_to_table(chk_elt, out)
+  if type(chk_elt) == 'string' then
+    return {out}
   end
 
-  -- Check the setting element internal data
-  local process_setting_elt = function(name, elt)
+  return out
+end
 
-    lua_util.debugm(N, rspamd_config, 'process settings "%s"', name)
-    -- Process IP address: converted to a table {ip, mask}
-    local function process_ip_condition(ip)
-      local out = {}
+-- Process IP address: converted to a table {ip, mask}
+local function process_ip_condition(ip)
+  local out = {}
 
-      if type(ip) == "table" then
-        for _,v in ipairs(ip) do
-          table.insert(out, process_ip_condition(v))
-        end
-      elseif type(ip) == "string" then
-        local slash = string.find(ip, '/')
+  if type(ip) == "table" then
+    for _,v in ipairs(ip) do
+      table.insert(out, process_ip_condition(v))
+    end
+  elseif type(ip) == "string" then
+    local slash = string.find(ip, '/')
 
-        if not slash then
-          -- Just a plain IP address
-          local res = rspamd_ip.from_string(ip)
+    if not slash then
+      -- Just a plain IP address
+      local res = rspamd_ip.from_string(ip)
 
-          if res:is_valid() then
-            out[1] = res
-            out[2] = 0
-          else
-            -- It can still be a map
-            out[1] = res
-          end
-        else
-          local res = rspamd_ip.from_string(string.sub(ip, 1, slash - 1))
-          local mask = tonumber(string.sub(ip, slash + 1))
+      if res:is_valid() then
+        out[1] = res
+        out[2] = 0
+      else
+        -- It can still be a map
+        out[1] = res
+      end
+    else
+      local res = rspamd_ip.from_string(string.sub(ip, 1, slash - 1))
+      local mask = tonumber(string.sub(ip, slash + 1))
 
-          if res:is_valid() then
-            out[1] = res
-            out[2] = mask
-          else
-            rspamd_logger.errx(rspamd_config, "bad IP address: " .. ip)
-            return nil
-          end
-        end
+      if res:is_valid() then
+        out[1] = res
+        out[2] = mask
       else
+        rspamd_logger.errx(rspamd_config, "bad IP address: " .. ip)
         return nil
       end
-
-      return out
     end
+  else
+    return nil
+  end
 
-    -- Process email like condition, converted to a table with fields:
-    -- name - full email (surprise!)
-    -- user - user part
-    -- domain - domain part
-    -- regexp - full email regexp (yes, it sucks)
-    local function process_email_condition(addr)
-      local out = {}
-      if type(addr) == "table" then
-        for _,v in ipairs(addr) do
-          table.insert(out, process_email_condition(v))
+  return out
+end
+
+-- Process email like condition, converted to a table with fields:
+-- name - full email (surprise!)
+-- user - user part
+-- domain - domain part
+-- regexp - full email regexp (yes, it sucks)
+local function process_email_condition(addr)
+  local out = {}
+  if type(addr) == "table" then
+    for _,v in ipairs(addr) do
+      table.insert(out, process_email_condition(v))
+    end
+  elseif type(addr) == "string" then
+    if string.sub(addr, 1, 4) == "map:" then
+      -- It is map, don't apply any extra logic
+      out['name'] = addr
+    else
+      local start = string.sub(addr, 1, 1)
+      if start == '/' then
+        -- It is a regexp
+        local re = rspamd_regexp.create(addr)
+        if re then
+          out['regexp'] = re
+        else
+          rspamd_logger.errx(rspamd_config, "bad regexp: " .. addr)
+          return nil
         end
-      elseif type(addr) == "string" then
-        if string.sub(addr, 1, 4) == "map:" then
-          -- It is map, don't apply any extra logic
+
+      elseif start == '@' then
+        -- It is a domain if form @domain
+        out['domain'] = string.sub(addr, 2)
+      else
+        -- Check user@domain parts
+        local at = string.find(addr, '@')
+        if at then
+          -- It is full address
           out['name'] = addr
         else
-          local start = string.sub(addr, 1, 1)
-          if start == '/' then
-            -- It is a regexp
-            local re = rspamd_regexp.create(addr)
-            if re then
-              out['regexp'] = re
-            else
-              rspamd_logger.errx(rspamd_config, "bad regexp: " .. addr)
-              return nil
-            end
-
-          elseif start == '@' then
-            -- It is a domain if form @domain
-            out['domain'] = string.sub(addr, 2)
-          else
-            -- Check user@domain parts
-            local at = string.find(addr, '@')
-            if at then
-              -- It is full address
-              out['name'] = addr
-            else
-              -- It is a user
-              out['user'] = addr
-            end
-          end
+          -- It is a user
+          out['user'] = addr
         end
-      else
-        return nil
       end
-
-      return out
     end
+  else
+    return nil
+  end
 
-    -- Convert a plain string condition to a table:
-    -- check - string to match
-    -- regexp - regexp to match
-    local function process_string_condition(addr)
-      local out = {}
-      if type(addr) == "table" then
-        for _,v in ipairs(addr) do
-          table.insert(out, process_string_condition(v))
-        end
-      elseif type(addr) == "string" then
-        if string.sub(addr, 1, 4) == "map:" then
-          -- It is map, don't apply any extra logic
-          out['check'] = addr
-        else
-          local start = string.sub(addr, 1, 1)
-          if start == '/' then
-            -- It is a regexp
-            local re = rspamd_regexp.create(addr)
-            if re then
-              out['regexp'] = re
-            else
-              rspamd_logger.errx(rspamd_config, "bad regexp: " .. addr)
-              return nil
-            end
+  return out
+end
 
-          else
-            out['check'] = addr
-          end
+-- Convert a plain string condition to a table:
+-- check - string to match
+-- regexp - regexp to match
+local function process_string_condition(addr)
+  local out = {}
+  if type(addr) == "table" then
+    for _,v in ipairs(addr) do
+      table.insert(out, process_string_condition(v))
+    end
+  elseif type(addr) == "string" then
+    if string.sub(addr, 1, 4) == "map:" then
+      -- It is map, don't apply any extra logic
+      out['check'] = addr
+    else
+      local start = string.sub(addr, 1, 1)
+      if start == '/' then
+        -- It is a regexp
+        local re = rspamd_regexp.create(addr)
+        if re then
+          out['regexp'] = re
+        else
+          rspamd_logger.errx(rspamd_config, "bad regexp: " .. addr)
+          return nil
         end
+
       else
-        return nil
+        out['check'] = addr
       end
-
-      return out
     end
+  else
+    return nil
+  end
+
+  return out
+end
+
+local function get_priority (elt)
+  local pri_tonum = function(p)
+    if p then
+      if type(p) == "number" then
+        return tonumber(p)
+      elseif type(p) == "string" then
+        if p == "high" then
+          return 3
+        elseif p == "medium" then
+          return 2
+        end
 
-    local convert_to_table = function(chk_elt, out)
-      if type(chk_elt) == 'string' then
-        return {out}
       end
 
-      return out
     end
 
-    -- Used to create a checking closure: if value matches expected somehow, return true
-    local function gen_check_closure(expected, check_func)
-      return function(value)
-        if not value then return false end
+    return 1
+  end
 
-        if type(value) == 'function' then
-          value = value()
-        end
+  return pri_tonum(elt['priority'])
+end
 
-        if value then
+-- Used to create a checking closure: if value matches expected somehow, return true
+local function gen_check_closure(expected, check_func)
+  return function(value)
+    if not value then return false end
 
-          if not check_func then
-            check_func = function(a, b) return a == b end
-          end
+    if type(value) == 'function' then
+      value = value()
+    end
 
-          local ret = fun.any(function(d)
-            return check_func(d, value)
-          end, expected)
-          if ret then
-            return true
-          end
-        end
+    if value then
+
+      if not check_func then
+        check_func = function(a, b) return a == b end
+      end
 
-        return false
+      local ret
+      if type(expected) == 'table' then
+        ret = fun.any(function(d)
+          return check_func(d, value)
+        end, expected)
+      else
+        ret = check_func(expected, value)
+      end
+      if ret then
+        return true
       end
     end
 
+    return false
+  end
+end
+
+-- Process settings based on their priority
+local function process_settings_table(tbl, allow_ids, is_static)
+
+  -- Check the setting element internal data
+  local process_setting_elt = function(name, elt)
+
+    lua_util.debugm(N, rspamd_config, 'process settings "%s"', name)
+
     local out = {}
 
     local checks = {}
-    if elt['ip'] then
+    if elt.ip then
       local ips_table = process_ip_condition(elt['ip'])
 
       if ips_table then
@@ -601,8 +612,26 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['client_ip'] then
-      local client_ips_table = process_ip_condition(elt['client_ip'])
+    if elt.ip_map then
+      local ips_map = lua_maps.map_add_from_ucl(elt.ip_map, 'radix',
+          'settings ip map for ' .. name)
+
+      if ips_map then
+        lua_util.debugm(N, rspamd_config, 'added ip_map condition to "%s"',
+            name)
+        checks.ip_map = {
+          check = gen_check_closure(ips_map, check_map_setting),
+          extract = function(task)
+            local ip = task:get_from_ip()
+            if ip and ip:is_valid() then return ip end
+            return nil
+          end,
+        }
+      end
+    end
+
+    if elt.client_ip then
+      local client_ips_table = process_ip_condition(elt.client_ip)
 
       if client_ips_table then
         lua_util.debugm(N, rspamd_config, 'added client_ip condition to "%s": %s',
@@ -618,8 +647,26 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['from'] then
-      local from_condition = process_email_condition(elt['from'])
+    if elt.client_ip_map then
+      local ips_map = lua_maps.map_add_from_ucl(elt.ip_map, 'radix',
+          'settings client ip map for ' .. name)
+
+      if ips_map then
+        lua_util.debugm(N, rspamd_config, 'added client ip_map condition to "%s"',
+            name)
+        checks.client_ip_map = {
+          check = gen_check_closure(ips_map, check_map_setting),
+          extract = function(task)
+            local ip = task:get_client_ip()
+            if ip and ip:is_valid() then return ip end
+            return nil
+          end,
+        }
+      end
+    end
+
+    if elt.from then
+      local from_condition = process_email_condition(elt.from)
 
       if from_condition then
         lua_util.debugm(N, rspamd_config, 'added from condition to "%s": %s',
@@ -633,8 +680,9 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['rcpt'] then
-      local rcpt_condition = process_email_condition(elt['rcpt'])
+
+    if elt.rcpt then
+      local rcpt_condition = process_email_condition(elt.rcpt)
       if rcpt_condition then
         lua_util.debugm(N, rspamd_config, 'added rcpt condition to "%s": %s',
             name, rcpt_condition)
@@ -647,8 +695,9 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['from_mime'] then
-      local from_mime_condition = process_email_condition(elt['from_mime'])
+
+    if elt.from_mime then
+      local from_mime_condition = process_email_condition(elt.from_mime)
 
       if from_mime_condition then
         lua_util.debugm(N, rspamd_config, 'added from_mime condition to "%s": %s',
@@ -662,8 +711,9 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['rcpt_mime'] then
-      local rcpt_mime_condition = process_email_condition(elt['rcpt'])
+
+    if elt.rcpt_mime then
+      local rcpt_mime_condition = process_email_condition(elt.rcpt_mime)
       if rcpt_mime_condition then
         lua_util.debugm(N, rspamd_config, 'added rcpt mime condition to "%s": %s',
             name, rcpt_mime_condition)
@@ -676,8 +726,9 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['user'] then
-      local user_condition = process_email_condition(elt['user'])
+
+    if elt.user then
+      local user_condition = process_email_condition(elt.user)
       if user_condition then
         lua_util.debugm(N, rspamd_config, 'added user condition to "%s": %s',
             name, user_condition)
@@ -707,8 +758,9 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['hostname'] then
-      local hostname_condition = process_string_condition(elt['hostname'])
+
+    if elt.hostname then
+      local hostname_condition = process_string_condition(elt.hostname)
       if hostname_condition then
         lua_util.debugm(N, rspamd_config, 'added hostname condition to "%s": %s',
             name, hostname_condition)
@@ -721,7 +773,8 @@ local function process_settings_table(tbl, allow_ids, is_static)
         }
       end
     end
-    if elt['authenticated'] then
+
+    if elt.authenticated then
       lua_util.debugm(N, rspamd_config, 'added authenticated condition to "%s"',
           name)
       checks.authenticated = {
@@ -731,6 +784,7 @@ local function process_settings_table(tbl, allow_ids, is_static)
         end
       }
     end
+
     if elt['local'] then
       lua_util.debugm(N, rspamd_config, 'added local condition to "%s"',
           name)
@@ -785,7 +839,7 @@ local function process_settings_table(tbl, allow_ids, is_static)
       end
 
       return safe_key
-  end
+    end
     -- Headers are tricky:
     -- We create an closure with extraction function depending on header name
     -- We also inserts it into `checks` table as an atom in form header:<hname>
@@ -846,7 +900,7 @@ local function process_settings_table(tbl, allow_ids, is_static)
       end
     end)
 
-    if elt['selector'] then
+    if elt.selector then
       local sel = lua_selectors.create_selector_closure(rspamd_config, elt.selector,
           elt.delimiter or "")
 
@@ -868,13 +922,13 @@ local function process_settings_table(tbl, allow_ids, is_static)
 
     -- Special, special case!
     local inverse = false
-    if elt['inverse'] then
+    if elt.inverse then
       lua_util.debugm(N, rspamd_config, 'added inverse condition to "%s"',
           name)
       inverse = true
     end
 
-    -- Killmeplease
+    -- Count checks and create Rspamd expression from a set of rules
     local nchecks = 0
     for _,_ in pairs(checks) do nchecks = nchecks + 1 end