]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Rework selectors logic
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 17 Sep 2018 12:27:38 +0000 (13:27 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 17 Sep 2018 12:27:38 +0000 (13:27 +0100)
lualib/lua_selectors.lua

index 40447a470df4d52f024000c2e83faa6608f3820e..b99ec3d9d146395e19d4b94d3af98e37aa5815c0 100644 (file)
@@ -33,61 +33,59 @@ local E = {}
 
 local extractors = {
   ['id'] = {
-    ['type'] = 'string',
     ['get_value'] = function(_, args)
       if args[1] then
-        return args[1]
+        return args[1], 'string'
       end
 
-      return ''
+      return '','string'
     end,
-    ['description'] = 'Return value from function\'s argument or an empty string',
+    ['description'] = [[Return value from function's argument or an empty string,
+For example, `id('Something')` returns a string 'Something']],
   },
   -- Get source IP address
   ['ip'] = {
-    ['type'] = 'ip',
     ['get_value'] = function(task)
       local ip = task:get_ip()
-      if ip and ip:is_valid() then return tostring(ip) end
+      if ip and ip:is_valid() then return ip,'ip' end
       return nil
     end,
-    ['description'] = 'Get source IP address',
+    ['description'] = [[Get source IP address]],
   },
   -- Get MIME from
   ['from'] = {
-    ['type'] = 'email',
     ['get_value'] = function(task, args)
       local from = task:get_from(args[1] or 0)
       if ((from or E)[1] or E).addr then
-        return from[1]
+        return from[1],'email'
       end
       return nil
     end,
-    ['description'] = 'Get MIME or SMTP from (e.g. from(\'smtp\') or from(\'mime\'), uses any type by default)',
+    ['description'] = [[Get MIME or SMTP from (e.g. from('smtp') or from(mime),
+uses any type by default)]],
   },
   ['rcpts'] = {
-    ['type'] = 'email_list',
     ['get_value'] = function(task, args)
       local rcpts = task:get_rcpt(args[1] or 0)
       if ((rcpts or E)[1] or E).addr then
-        return rcpts
+        return rcpts,'email_list'
       end
       return nil
     end,
-    ['description'] = 'Get MIME or SMTP recipients (e.g. rcpts(\'smtp\') or rcpts(\'mime\'), uses any type by default)',
+    ['description'] = [[Get MIME or SMTP rcpts (e.g. rcpts('smtp') or rcpts(mime),
+uses any type by default)]],
   },
   -- Get country (ASN module must be executed first)
   ['country'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task)
-      local asn = task:get_mempool():get_variable('asn')
-      if not asn then
+      local country = task:get_mempool():get_variable('country')
+      if not country then
         return nil
       else
-        return asn
+        return country,'string'
       end
     end,
-    ['description'] = 'Get country (ASN module must be executed first)',
+    ['description'] = [[Get country (ASN module must be executed first)]],
   },
   -- Get ASN number
   ['asn'] = {
@@ -97,43 +95,39 @@ local extractors = {
       if not asn then
         return nil
       else
-        return asn
+        return asn,'string'
       end
     end,
-    ['description'] = 'Get ASN number',
+    ['description'] = [[Get AS number (ASN module must be executed first)]],
   },
   -- Get authenticated username
   ['user'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task)
       local auser = task:get_user()
       if not auser then
         return nil
       else
-        return auser
+        return auser,'string'
       end
     end,
-    ['description'] = 'Get authenticated username',
+    ['description'] = 'Get authenticated user name',
   },
   -- Get principal recipient
   ['to'] = {
-    ['type'] = 'email',
     ['get_value'] = function(task)
-      return task:get_principal_recipient()
+      return task:get_principal_recipient(),'string'
     end,
     ['description'] = 'Get principal recipient',
   },
   -- Get content digest
   ['digest'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task)
-      return task:get_digest()
+      return task:get_digest(),'string'
     end,
     ['description'] = 'Get content digest',
   },
   -- Get list of all attachments digests
   ['attachments'] = {
-    ['type'] = 'string_list',
     ['get_value'] = function(task)
       local parts = task:get_parts() or E
       local digests = {}
@@ -145,7 +139,7 @@ local extractors = {
       end
 
       if #digests > 0 then
-        return digests
+        return digests,'string_list'
       end
 
       return nil
@@ -154,7 +148,6 @@ local extractors = {
   },
   -- Get all attachments files
   ['files'] = {
-    ['type'] = 'string_list',
     ['get_value'] = function(task)
       local parts = task:get_parts() or E
       local files = {}
@@ -167,7 +160,7 @@ local extractors = {
       end
 
       if #files > 0 then
-        return files
+        return files,'string_list'
       end
 
       return nil
@@ -176,71 +169,99 @@ local extractors = {
   },
   -- Get helo value
   ['helo'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task)
-      return task:get_helo()
+      return task:get_helo(),'string'
     end,
     ['description'] = 'Get helo value',
   },
   -- Get header with the name that is expected as an argument. Returns list of
   -- headers with this name
   ['header'] = {
-    ['type'] = 'kv_list',
     ['get_value'] = function(task, args)
-      return task:get_header_full(args[1])
+      local strong = false
+      if args[2] then
+        if args[2]:match('strong') then
+          strong = true
+        end
+
+        if args[2]:match('full') then
+          return task:get_header_full(args[1], strong),'kv_list'
+        end
+
+        return task:get_header(args[1], strong),'string'
+      else
+        return task:get_header(args[1]),'string'
+      end
     end,
-    ['description'] = 'Get header with the name that is expected as an argument. Returns list of headers with this name',
+    ['description'] = [[Get header with the name that is expected as an argument.
+The optional second argument accepts list of flags:
+  - `full`: returns all headers with this name with all data (like task:get_header_full())
+  - `strong`: use case sensitive match when matching header's name]],
   },
   -- Get list of received headers (returns list of tables)
   ['received'] = {
-    ['type'] = 'kv_list',
-    ['get_value'] = function(task)
-      return task:get_received_headers()
+    ['get_value'] = function(task, args)
+      local rh = task:get_received_headers()
+      if args[1] and rh then
+        return fun.map(function(r) return r[args[1]] end, rh), 'string_list'
+      end
+
+      return rh,'kv_list'
     end,
-    ['description'] = 'Get list of received headers (returns list of tables)',
+    ['description'] = [[Get list of received headers.
+If no arguments specified, returns list of tables. Otherwise, selects a specific element,
+e.g. `by_hostname`]],
   },
   -- Get all urls
   ['urls'] = {
-    ['type'] = 'url_list',
-    ['get_value'] = function(task)
-      return task:get_urls()
+    ['get_value'] = function(task, args)
+      local urls = task:get_urls()
+      if args[1] and urls then
+        return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
+      end
+      return urls,'url_list'
     end,
-    ['description'] = 'Get all urls',
+    ['description'] = [[Get list of all urls.
+If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
+e.g. `get_tld`]],
   },
   -- Get all emails
   ['emails'] = {
-    ['type'] = 'url_list',
-    ['get_value'] = function(task)
-      return task:get_emails()
+    ['get_value'] = function(task, args)
+      local urls = task:get_emails()
+      if args[1] and urls then
+        return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
+      end
+      return urls,'url_list'
     end,
-    ['description'] = 'Get all emails',
+    ['description'] = [[Get list of all emails.
+If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
+e.g. `get_user`]],
   },
   -- Get specific pool var. The first argument must be variable name,
   -- the second argument is optional and defines the type (string by default)
   ['pool_var'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task, args)
-      return task:get_mempool():get_variable(args[1], args[2])
+      return task:get_mempool():get_variable(args[1], args[2]),(args[2] or 'string')
     end,
     ['description'] = [[Get specific pool var. The first argument must be variable name,
       the second argument is optional and defines the type (string by default)]],
   },
   -- Get specific HTTP request header. The first argument must be header name.
   ['request_header'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task, args)
       local hdr = task:get_request_header(args[1])
       if hdr then
-        return tostring(hdr)
+        return tostring(hdr),'string'
       end
 
       return nil
     end,
-    ['description'] = 'Get specific HTTP request header. The first argument must be header name.',
+    ['description'] = [[Get specific HTTP request header.
+The first argument must be header name.]],
   },
   -- Get task date, optionally formatted
   ['time'] = {
-    ['type'] = 'string',
     ['get_value'] = function(task, args)
       local what = args[1] or 'message'
       local dt = task:get_date{format = what, gmt = true}
@@ -248,15 +269,18 @@ local extractors = {
       if dt then
         if args[2] then
           -- Should be in format !xxx, as dt is in GMT
-          return os.date(args[2], dt)
+          return os.date(args[2], dt),'string'
         end
 
-        return tostring(dt)
+        return tostring(dt),'string'
       end
 
       return nil
     end,
-    ['description'] = 'Get task date, optionally formatted (see os.date)',
+    ['description'] = [[Get task timestamp. The first argument is type:
+  - `connect`: connection timestamp (default)
+  - `message`: timestamp as defined by `Date` header
+The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]]
   }
 }
 
@@ -493,7 +517,8 @@ local transform_function = {
       for _,a in ipairs(args) do if a == inp then return inp,t end end
       return nil
     end,
-    ['description'] = 'Boolean function in, returns either nil or its input if input is in args list',
+    ['description'] = [[Boolean function in.
+Returns either nil or its input if input is in args list]],
   },
   ['not_in'] = {
     ['types'] = {
@@ -504,12 +529,13 @@ local transform_function = {
       for _,a in ipairs(args) do if a == inp then return nil end end
       return inp,t
     end,
-    ['description'] = 'Boolean function in, returns either nil or its input if input is not in args list',
+    ['description'] = [[Boolean function not in.
+Returns either nil or its input if input is not in args list]],
   },
 }
 
 local function process_selector(task, sel)
-  local input = sel.selector.get_value(task, sel.selector.args)
+  local input,etype = sel.selector.get_value(task, sel.selector.args)
   if not input then return nil end
 
   -- Now we fold elements using left fold
@@ -538,7 +564,7 @@ local function process_selector(task, sel)
   end
 
   local res = fun.foldl(fold_function,
-      {input, sel.selector.type},
+      {input, etype},
       sel.processor_pipe)
 
   if not res or not res[1] then return nil end -- Pipeline failed