]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add functional selectors library
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 17 Aug 2018 12:10:07 +0000 (13:10 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 17 Aug 2018 12:10:53 +0000 (13:10 +0100)
lualib/lua_selectors.lua [new file with mode: 0644]

diff --git a/lualib/lua_selectors.lua b/lualib/lua_selectors.lua
new file mode 100644 (file)
index 0000000..37b85e4
--- /dev/null
@@ -0,0 +1,500 @@
+--[[
+Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- This module contains 'selectors' implementation: code to extract data
+-- from Rspamd tasks and compose those together
+
+--[[[
+-- @module lua_selectors
+-- This module contains 'selectors' implementation: code to extract data
+-- from Rspamd tasks and compose those together.
+-- Typical selector looks like this: header(User).lower.substring(1, 2):ip
+--]]
+
+local exports = {}
+local logger = require 'rspamd_logger'
+local fun = require 'fun'
+local lua_util = require "lua_util"
+local M = "lua_selectors"
+local E = {}
+
+-- Built-in selectors, keyed by the name used in a selector expression.
+-- Each entry has two fields:
+--   type: type tag of the extracted value; it is matched against the `types`
+--     table of the transforms applied afterwards
+--   get_value(task, args): extracts the value from the task and returns nil
+--     when there is nothing to extract; `args` is the list of arguments
+--     given in parentheses in the selector expression
+local selectors = {
+  -- String form of the task IP address, only when that address is valid
+  ['ip'] = {
+    ['type'] = 'addr',
+    ['get_value'] = function(task)
+      local ip = task:get_ip()
+      if ip and ip:is_valid() then return tostring(ip) end
+      return nil
+    end,
+  },
+  -- Lowercased `addr` field of the first element returned by get_from(0)
+  ['from'] = {
+    ['type'] = 'email',
+    ['get_value'] = function(task)
+      local from = task:get_from(0)
+      if ((from or E)[1] or E).addr then
+        return string.lower(from[1]['addr'])
+      end
+      return nil
+    end,
+  },
+  -- Value of the 'country' mempool variable
+  ['country'] = {
+    ['type'] = 'string',
+    ['get_value'] = function(task)
+      return task:get_mempool():get_variable('country')
+    end,
+  },
+  -- Value of the 'asn' mempool variable
+  ['asn'] = {
+    ['type'] = 'string',
+    ['get_value'] = function(task)
+      return task:get_mempool():get_variable('asn')
+    end,
+  },
+  -- Result of task:get_user()
+  ['user'] = {
+    ['type'] = 'string',
+    ['get_value'] = function(task)
+      return task:get_user()
+    end,
+  },
+  -- Result of task:get_principal_recipient()
+  ['to'] = {
+    ['type'] = 'email',
+    ['get_value'] = function(task)
+      return task:get_principal_recipient()
+    end,
+  },
+  -- Result of task:get_digest()
+  ['digest'] = {
+    ['type'] = 'string',
+    ['get_value'] = function(task)
+      return task:get_digest()
+    end,
+  },
+  -- Digests of all parts that have a filename, nil when there are none
+  ['attachments'] = {
+    ['type'] = 'string_list',
+    ['get_value'] = function(task)
+      local parts = task:get_parts() or E
+      local digests = {}
+
+      for _,p in ipairs(parts) do
+        if p:get_filename() then
+          table.insert(digests, p:get_digest())
+        end
+      end
+
+      if #digests > 0 then
+        return digests
+      end
+
+      return nil
+    end,
+  },
+  -- Filenames of all parts that have one, nil when there are none
+  ['files'] = {
+    ['type'] = 'string_list',
+    ['get_value'] = function(task)
+      local parts = task:get_parts() or E
+      local files = {}
+
+      for _,p in ipairs(parts) do
+        local fname = p:get_filename()
+        if fname then
+          table.insert(files, fname)
+        end
+      end
+
+      if #files > 0 then
+        return files
+      end
+
+      return nil
+    end,
+  },
+  -- Result of task:get_helo()
+  ['helo'] = {
+    ['type'] = 'string',
+    ['get_value'] = function(task)
+      return task:get_helo()
+    end,
+  },
+  -- Full header data for the header named by the first argument,
+  -- e.g. header(User)
+  ['header'] = {
+    ['type'] = 'header_list',
+    ['get_value'] = function(task, extra)
+      -- `extra` is the whole argument list; the header name is its first
+      -- element. Without a name there is nothing to look up.
+      local name = (extra or E)[1]
+      if not name then return nil end
+      return task:get_header_full(name)
+    end,
+  },
+  -- Result of task:get_urls()
+  ['url'] = {
+    ['type'] = 'url_list',
+    ['get_value'] = function(task)
+      return task:get_urls()
+    end,
+  },
+  -- Result of task:get_emails()
+  ['email'] = {
+    ['type'] = 'url_list',
+    ['get_value'] = function(task)
+      return task:get_emails()
+    end,
+  }
+}
+
+-- Type of a single element for every list type accepted by `first`/`last`
+local list_element_types = {
+  ['url_list'] = 'url',
+  ['header_list'] = 'header',
+  ['string_list'] = 'string',
+}
+
+-- Encodings accepted by `encode`; each one is a method name of a hash object
+local hash_encodings = {
+  ['hex'] = true,
+  ['base32'] = true,
+  ['base64'] = true,
+}
+
+-- Picks the first (or, when `from_end` is set, the last) element of a list
+-- and returns it together with the element type; returns nil for an empty
+-- list
+local function pick_element(inp, t, from_end)
+  -- A list produced by an earlier transform may be a lazy luafun iterator
+  -- that cannot be indexed, so materialise it into a plain table first
+  local list = fun.totable(inp)
+  local elt = list[from_end and #list or 1]
+  if elt == nil then return nil end
+  return elt, list_element_types[t]
+end
+
+-- Built-in transforms, keyed by the name used after a dot in a selector
+-- expression. Each entry has two fields:
+--   types: set of input type tags the transform accepts
+--   process(inp, t, args): receives the input value, its type tag and the
+--     list of arguments from the expression; returns the new value followed
+--     by its type tag, or nil on failure
+local transform_function = {
+  -- Get hostname from url or a list of urls
+  ['get_host'] = {
+    ['types'] = {
+      ['url_list'] = true,
+      ['url'] = true
+    },
+    ['process'] = function(inp, t)
+      if t == 'url_list' then
+        return fun.map(function(u) return u:get_host() end, inp),'string_list'
+      end
+
+      return inp:get_host(),'string'
+    end
+  },
+  -- Get tld from url or a list of urls
+  ['get_tld'] = {
+    ['types'] = {
+      ['url_list'] = true,
+      ['url'] = true
+    },
+    ['process'] = function(inp, t)
+      if t == 'url_list' then
+        return fun.map(function(u) return u:get_tld() end, inp),'string_list'
+      end
+
+      return inp:get_tld(),'string'
+    end
+  },
+  -- Get address
+  ['get_addr'] = {
+    ['types'] = {
+      ['email'] = true
+    },
+    ['process'] = function(inp, _)
+      -- NOTE(review): the result is tagged as 'string' so that it can be
+      -- passed further down the pipeline; confirm that get_addr() yields a
+      -- string for every value of type 'email'
+      return inp:get_addr(),'string'
+    end
+  },
+  -- Returns the lowercased string
+  ['lower'] = {
+    ['types'] = {
+      ['string'] = true,
+      ['string_list'] = true
+    },
+    ['process'] = function(inp, t)
+      if t == 'string' then
+        return inp:lower(),'string'
+      end
+      return fun.map(string.lower, inp),'string_list'
+    end
+  },
+  -- Returns the first element
+  ['first'] = {
+    ['types'] = {
+      ['url_list'] = true,
+      ['header_list'] = true,
+      ['string_list'] = true
+    },
+    ['process'] = function(inp, t)
+      return pick_element(inp, t, false)
+    end
+  },
+  -- Returns the last element
+  ['last'] = {
+    ['types'] = {
+      ['url_list'] = true,
+      ['header_list'] = true,
+      ['string_list'] = true
+    },
+    ['process'] = function(inp, t)
+      return pick_element(inp, t, true)
+    end
+  },
+  -- Joins strings into a single string using separator in the argument
+  ['join'] = {
+    ['types'] = {
+      ['string_list'] = true
+    },
+    ['process'] = function(inp, _, args)
+      -- table.concat needs a real array, the input may be a lazy iterator
+      return table.concat(fun.totable(inp), (args or E)[1] or ''), 'string'
+    end
+  },
+  -- Create a digest from string or a list of strings
+  ['digest'] = {
+    ['types'] = {
+      ['string_list'] = true,
+      ['string'] = true
+    },
+    ['process'] = function(inp, t, args)
+      local hash = require 'rspamd_cryptobox_hash'
+      -- Hash type is the first argument, 'blake2' when omitted
+      local ht = (args or E)[1] or 'blake2'
+
+      if t == 'string_list' then
+        return fun.map(function(s)
+          return hash:create_specific(ht):update(s)
+        end, inp),'hash_list'
+      end
+
+      return hash:create_specific(ht):update(inp), 'hash'
+    end
+  },
+  -- Encode hash to string (using hex encoding by default)
+  ['encode'] = {
+    ['types'] = {
+      ['hash_list'] = true,
+      ['hash'] = true
+    },
+    ['process'] = function(inp, t, args)
+      local how = (args or E)[1] or 'hex'
+      -- Fail the transform instead of producing nil elements when the
+      -- requested encoding is not supported
+      if not hash_encodings[how] then return nil end
+
+      if t == 'hash_list' then
+        return fun.map(function(h) return h[how](h) end, inp),'string_list'
+      end
+
+      return inp[how](inp),'string'
+    end
+  },
+  -- Extracts substring
+  ['substring'] = {
+    ['types'] = {
+      ['string'] = true
+    },
+    ['process'] = function(inp, _, args)
+      -- Arguments arrive from the parser as strings; positions default to
+      -- the whole string
+      local start_pos = tonumber((args or E)[1]) or 1
+      local end_pos = tonumber((args or E)[2]) or -1
+
+      return inp:sub(start_pos, end_pos), 'string'
+    end
+  },
+}
+
+-- Evaluates a single parsed selector against a task.
+-- `sel` is one element of the list produced by parse_selector: `sel.selector`
+-- provides get_value/type/args and `sel.processor_pipe` is the ordered list
+-- of transforms to apply.
+-- Returns a string, or a plain table of strings when the final type is
+-- 'string_list'; returns nil when nothing was extracted, a transform does
+-- not accept its input type, a transform yields no value, or the final type
+-- is neither 'string' nor 'string_list'.
+local function process_selector(task, sel)
+  local value = sel.selector.get_value(task, sel.selector.args)
+  if not value then return nil end
+
+  local t = sel.selector.type
+
+  -- Apply transforms from left to right. Each transform returns the new
+  -- value followed by its type tag, so both results have to be kept.
+  for _,proc in ipairs(sel.processor_pipe) do
+    if not proc.types[t] then
+      logger.errx(task, 'cannot apply transform %s for type %s', proc.name, t)
+      return nil
+    end
+
+    value, t = proc.process(value, t, proc.args)
+    if value == nil then return nil end -- Error in pipeline
+  end
+
+  if t == 'string_list' then
+    -- Convert to table as it might have a functional form
+    return fun.totable(value)
+  end
+
+  if t == 'string' then
+    return value
+  end
+
+  logger.errx(task, 'transform pipeline has returned bad type: %s, string expected',
+      t)
+  return nil
+end
+
+-- Builds the LPeg grammar for selector expressions.
+-- An expression is a function call followed by zero or more dot-separated
+-- processor calls; several expressions may be chained with ':'. Each
+-- expression is captured as its own table whose elements are
+-- {name, optional argument table} tables, the function first.
+local function make_grammar()
+  local lpeg = require "lpeg"
+  local P, R, S, C, Ct, V = lpeg.P, lpeg.R, lpeg.S, lpeg.C, lpeg.Ct, lpeg.V
+
+  -- Optional whitespace
+  local ws = S(" \t\n")^0
+  -- Captured run of letters, digits, '_' and '-'
+  local ident = C((R("az", "AZ", "09") + S("_-"))^1)
+  local open_paren = P("(") * ws
+  local close_paren = ws * P(")")
+  local separator = ws * P(",") * ws
+  -- A name with an optional parenthesised argument list; functions and
+  -- processors share this shape
+  local call = Ct(ident * ws * (open_paren * V("ARG_LIST") * close_paren)^0)
+
+  return P{
+    "LIST";
+    LIST = Ct(V("EXPR")) * (P(":") * Ct(V("EXPR")))^0,
+    EXPR = V("FUNCTION") * (P(".") * V("PROCESSOR"))^0,
+    PROCESSOR = call,
+    FUNCTION = call,
+    ARG_LIST = Ct((ident * separator^0)^0)
+  }
+end
+
+local parser = make_grammar()
+
+--[[[
+-- @function lua_selectors.parse_selector(cfg, str)
+-- Parses a selector expression such as `header(User).lower.substring(1, 2):ip`
+-- @param cfg object passed to the logging functions
+-- @param {string} str selector expression
+-- @return {table|nil} list with one element per ':'-separated selector, each
+-- of the form {selector = ..., processor_pipe = {...}}, or nil when the
+-- expression cannot be parsed or names an unknown selector or processor
+--]]
+exports.parse_selector = function(cfg, str)
+  -- The grammar yields one capture table per ':'-separated selector, so all
+  -- values returned by match have to be collected; a failed match gives an
+  -- empty list
+  local parsed = {parser:match(str)}
+  if #parsed == 0 then return nil end
+
+  -- Copies the top level of a table so that per-expression fields (name,
+  -- args) are not written into the shared definition tables
+  local function shallowcopy(orig)
+    if type(orig) ~= 'table' then return orig end
+
+    local copy = {}
+    for orig_key, orig_value in pairs(orig) do
+      copy[orig_key] = orig_value
+    end
+    return copy
+  end
+
+  local output = {}
+
+  -- Parsed AST format is the following:
+  -- table of individual selectors
+  -- each selector: list of functions
+  -- each function: function name + optional list of arguments
+  for _,sel in ipairs(parsed) do
+    local selector_tbl = sel[1]
+    if not selector_tbl then
+      logger.errx(cfg, 'no selector represented')
+      return nil
+    end
+    if not selectors[selector_tbl[1]] then
+      logger.errx(cfg, 'selector %s is unknown', selector_tbl[1])
+      return nil
+    end
+
+    local res = {
+      selector = shallowcopy(selectors[selector_tbl[1]]),
+      processor_pipe = {},
+    }
+    res.selector.name = selector_tbl[1]
+    res.selector.args = selector_tbl[2] or {}
+
+    lua_util.debugm(M, cfg, 'processed selector %s, args: %s',
+        res.selector.name, res.selector.args)
+
+    -- Now process processors pipe: everything after the first element
+    for i = 2, #sel do
+      local proc_tbl = sel[i]
+      local proc_name = proc_tbl[1]
+
+      if not transform_function[proc_name] then
+        -- An unknown processor invalidates the whole expression
+        logger.errx(cfg, 'processor %s is unknown', proc_name)
+        return nil
+      end
+
+      local processor = shallowcopy(transform_function[proc_name])
+      processor.name = proc_name
+      -- Processors written without parentheses get an empty argument list
+      processor.args = proc_tbl[2] or {}
+      lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s',
+          proc_name, res.selector.name, processor.args)
+      table.insert(res.processor_pipe, processor)
+    end
+
+    table.insert(output, res)
+  end
+
+  return output
+end
+
+--[[[
+-- @function lua_selectors.register_selector(cfg, name, selector)
+-- Adds a selector under the given name, replacing (with a warning) any
+-- selector already registered under it
+-- @param cfg object passed to the logging functions
+-- @param name name of the selector
+-- @param selector table that must have `get_value` and `type` fields
+-- @return true when the selector has been stored, false when it lacks a
+-- required field
+--]]
+exports.register_selector = function(cfg, name, selector)
+  local well_formed = selector.get_value and selector.type
+
+  if not well_formed then
+    logger.errx(cfg, 'bad selector %s', name)
+    return false
+  end
+
+  if selectors[name] then
+    logger.warnx(cfg, 'redefining selector %s', name)
+  end
+
+  selectors[name] = selector
+  return true
+end
+
+--[[[
+-- @function lua_selectors.register_transform(cfg, name, transform)
+-- Adds a transform function under the given name, replacing (with a
+-- warning) any transform already registered under it
+-- @param cfg object passed to the logging functions
+-- @param name name of the transform
+-- @param transform table that must have `process` and `types` fields
+-- @return true when the transform has been stored, false when it lacks a
+-- required field
+--]]
+exports.register_transform = function(cfg, name, transform)
+  local well_formed = transform.process and transform.types
+
+  if not well_formed then
+    logger.errx(cfg, 'bad transform function %s', name)
+    return false
+  end
+
+  if transform_function[name] then
+    logger.warnx(cfg, 'redefining transform function %s', name)
+  end
+
+  transform_function[name] = transform
+  return true
+end
+
+--[[[
+-- @function lua_selectors.process_selectors(task, selectors_pipe)
+-- Evaluates every selector of a parsed expression against a task
+-- @param task task passed to each selector
+-- @param {table} selectors_pipe list returned by parse_selector
+-- @return {table|nil} list with one result per selector, in order, or nil
+-- when any selector yields no value
+--]]
+exports.process_selectors = function(task, selectors_pipe)
+  local ret = {}
+
+  -- An explicit loop is used because nil results cannot be stored in a list
+  -- and would otherwise be dropped without being noticed
+  for _,sel in ipairs(selectors_pipe) do
+    local r = process_selector(task, sel)
+
+    -- If any element is nil, then the whole selector is nil
+    if r == nil then return nil end
+
+    ret[#ret + 1] = r
+  end
+
+  return ret
+end
+
+return exports
\ No newline at end of file