]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Implement implicit conversions to userdata
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 18 Sep 2018 14:38:09 +0000 (15:38 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 18 Sep 2018 14:38:30 +0000 (15:38 +0100)
lualib/lua_selectors.lua

index f57d069daed401ad5b99e9f038dba4a2fed99b00..1b97f64c5fbaab1169ef12403cf8ce889e495cc1 100644 (file)
@@ -47,7 +47,7 @@ For example, `id('Something')` returns a string 'Something']],
   ['ip'] = {
     ['get_value'] = function(task)
       local ip = task:get_ip()
-      if ip and ip:is_valid() then return ip,'ip' end
+      if ip and ip:is_valid() then return ip,'userdata' end
       return nil
     end,
     ['description'] = [[Get source IP address]],
@@ -57,7 +57,7 @@ For example, `id('Something')` returns a string 'Something']],
     ['get_value'] = function(task, args)
       local from = task:get_from(args[1] or 0)
       if ((from or E)[1] or E).addr then
-        return from[1],'email'
+        return from[1],'table'
       end
       return nil
     end,
@@ -68,7 +68,7 @@ uses any type by default)]],
     ['get_value'] = function(task, args)
       local rcpts = task:get_recipients(args[1] or 0)
       if ((rcpts or E)[1] or E).addr then
-        return rcpts,'email_list'
+        return rcpts,'table_list'
       end
       return nil
     end,
@@ -185,7 +185,7 @@ uses any type by default)]],
         end
 
         if args[2]:match('full') then
-          return task:get_header_full(args[1], strong),'kv_list'
+          return task:get_header_full(args[1], strong),'table_list'
         end
 
         return task:get_header(args[1], strong),'string'
@@ -206,7 +206,7 @@ The optional second argument accepts list of flags:
         return fun.map(function(r) return r[args[1]] end, rh), 'string_list'
       end
 
-      return rh,'kv_list'
+      return rh,'table_list'
     end,
     ['description'] = [[Get list of received headers.
 If no arguments specified, returns list of tables. Otherwise, selects a specific element,
@@ -219,7 +219,7 @@ e.g. `by_hostname`]],
       if args[1] and urls then
         return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
       end
-      return urls,'url_list'
+      return urls,'userdata_list'
     end,
     ['description'] = [[Get list of all urls.
 If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
@@ -232,7 +232,7 @@ e.g. `get_tld`]],
       if args[1] and urls then
         return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list'
       end
-      return urls,'url_list'
+      return urls,'userdata_list'
     end,
     ['description'] = [[Get list of all emails.
 If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
@@ -546,17 +546,6 @@ Returns either nil or its input if input is not in args list]],
   },
 }
 
-local implicit_types_map = {
-  ip = {'string', tostring},
-  email = {'string', function(e)
-    if e.name then
-      return string.format("%s <%s>", e.name, e.addr)
-    end
-    return string.format("<%s>", e.addr)
-  end},
-  url = {'string', tostring}
-}
-
 local function process_selector(task, sel)
   local function allowed_type(t)
     if t == 'string' or t == 'text' or t == 'string_list' or t == 'text_list' then
@@ -570,6 +559,21 @@ local function process_selector(task, sel)
     return pure_type(t)
   end
 
+  local function implicit_tostring(t, ud_or_table)
+    if t == 'userdata' then
+      return tostring(ud_or_table),'string'
+    else
+      -- Table (very special)
+      if ud_or_table.value then
+        return ud_or_table.value,'string'
+      elseif ud_or_table.addr then
+        return ud_or_table.addr,'string'
+      end
+
+      return table.concat(ud_or_table, " ")
+    end
+  end
+
   local input,etype = sel.selector.get_value(task, sel.selector.args)
 
   if not input then
@@ -577,7 +581,8 @@ local function process_selector(task, sel)
     return nil
   end
 
-  lua_util.debugm(M, task, 'extracted %s, type %s', sel.selector.name, etype)
+  lua_util.debugm(M, task, 'extracted %s, type %s',
+      sel.selector.name, etype)
 
   -- Now we fold elements using left fold
   local function fold_function(acc, x)
@@ -585,24 +590,45 @@ local function process_selector(task, sel)
       lua_util.debugm(M, task, 'do not apply %s, accumulator is nil', x.name)
       return nil
     end
+
     local value = acc[1]
     local t = acc[2]
 
     if not x.types[t] then
-      -- Additional case for map
-      local pt = pure_type(t)
-      if x.types['list'] then
-        -- Generic list
-        lua_util.debugm(M, task, 'apply list function %s to %s', x.name, t)
-        return {x.process(value, t, x.args)}
-      elseif pt and x.map_type and x.types[pt] then
-        lua_util.debugm(M, task, 'map %s to list of %s resulting %s',
-            x.name, pt, x.map_type)
-
-        return {fun.map(function(list_elt)
-          local ret, _ = x.process(list_elt, pt, x.args)
-          return ret
-        end, value), x.map_type}
+      -- Additional case for maps, tables and userdata
+      if t == 'userdata' or t == 'table' then
+        if not x.method then
+          -- Implicit conversion
+          lua_util.debugm(M, task, 'apply implicit conversion %s->string', t)
+          return fold_function({implicit_tostring(t, value)}, x)
+        end
+      else
+        local pt = pure_type(t)
+
+        if pt and x.types['list'] then
+          -- Generic list
+          lua_util.debugm(M, task, 'apply list function `%s` to %s', x.name, t)
+          return {x.process(value, t, x.args)}
+        elseif pt and x.map_type and x.types[pt] then
+          local map_type = x.map_type .. '_list'
+          lua_util.debugm(M, task, 'map `%s` to list of %s resulting %s',
+              x.name, pt, map_type)
+
+          return {fun.map(function(list_elt)
+            if not list_elt then return nil end
+            local ret, _ = x.process(list_elt, pt, x.args)
+            return ret
+          end, value), map_type}
+        elseif pt and pt == 'userdata' or pt == 'table' then
+          if not x.method then
+            -- Implicit conversion
+            lua_util.debugm(M, task, 'apply implicit map %s->string', pt)
+            return fold_function({fun.map(function(list_elt)
+              local ret, _ = implicit_tostring(pt, list_elt)
+              return ret
+            end, value), 'string_list'}, x)
+          end
+        end
       end
       logger.errx(task, 'cannot apply transform %s for type %s', x.name, t)
       return nil
@@ -624,38 +650,22 @@ local function process_selector(task, sel)
     local pt = pure_type(res[2])
 
     if pt then
-      local it = implicit_types_map[pt]
-      if it then
-        lua_util.debugm(M, task, 'apply implicit map %s->%s',
-            pt, it[1])
-        res[1] = fun.map(it[2], res[1])
-        res[2] = string.format('%s_list', it[1])
-      end
+      lua_util.debugm(M, task, 'apply implicit map %s->string_list', pt)
+      res[1] = fun.map(function(e) return implicit_tostring(pt, e) end, res[1])
+      res[2] = 'string_list'
     else
-      local it = implicit_types_map[res[2]]
-
-      if it then
-        lua_util.debugm(M, task, 'apply implicit conversion %s->%s',
-            res[2], it[1])
-        res[1] = it[2](res[1])
-        res[2] = it[1]
-      end
-    end
-
-    if not not allowed_type(res[2]) then
-      logger.errx(task, 'transform pipeline has returned bad type: %s, string expected: res = %s, sel: %s',
-          res[2], res, sel)
-      return nil
+      res[1] = implicit_tostring(res[2], res[1])
+      res[2] = 'string'
     end
   end
 
-  lua_util.debugm(M, task, 'final selector type: %s', res[2])
-
   if list_type(res[2]) then
     -- Convert to table as it might have a functional form
-    return fun.totable(res[1])
+    res[1] = fun.totable(res[1])
   end
 
+  lua_util.debugm(M, task, 'final selector type: %s, value: %s', res[2], res[1])
+
   return res[1]
 end
 
@@ -726,16 +736,43 @@ exports.parse_selector = function(cfg, str)
     fun.each(function(proc_tbl)
       local proc_name = proc_tbl[1]
 
-      if not transform_function[proc_name] then
-        logger.errx(cfg, 'processor %s is unknown', proc_name)
-        return nil
+      if proc_name:match('^__') then
+        -- Special case - method
+        local method_name = proc_name:match('^__(.*)$')
+        local processor = {
+          name = method_name,
+          method = true,
+          args = proc_tbl[2],
+          types = {
+            userdata = true,
+            table = true,
+          },
+          map_type = 'string',
+          process = function(inp, t, args)
+            if t == 'userdata' then
+              return inp[method_name](inp, args),string
+            else
+              -- Table
+              return inp[method_name],string
+            end
+          end,
+        }
+        lua_util.debugm(M, cfg, 'attached method %s to selector %s, args: %s',
+            proc_name, res.selector.name, processor.args)
+        table.insert(res.processor_pipe, processor)
+      else
+
+        if not transform_function[proc_name] then
+          logger.errx(cfg, 'processor %s is unknown', proc_name)
+          return nil
+        end
+        local processor = lua_util.shallowcopy(transform_function[proc_name])
+        processor.name = proc_name
+        processor.args = proc_tbl[2]
+        lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s',
+            proc_name, res.selector.name, processor.args)
+        table.insert(res.processor_pipe, processor)
       end
-      local processor = lua_util.shallowcopy(transform_function[proc_name])
-      processor.name = proc_name
-      processor.args = proc_tbl[2]
-      lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s',
-          proc_name, res.selector.name, processor.args)
-      table.insert(res.processor_pipe, processor)
     end, fun.tail(sel))
 
     table.insert(output, res)