diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-08-17 13:10:07 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-08-17 13:10:53 +0100 |
commit | 49229ac5796888ccfdfa12f188384da761a0c52c (patch) | |
tree | 897337d5bf8f7f05c5eee6ba2b5867841a6c430c /lualib | |
parent | 8958fa66850281cd04670b9d586d6ee071cb9131 (diff) | |
download | rspamd-49229ac5796888ccfdfa12f188384da761a0c52c.tar.gz rspamd-49229ac5796888ccfdfa12f188384da761a0c52c.zip |
[Feature] Add functional selectors library
Diffstat (limited to 'lualib')
-rw-r--r-- | lualib/lua_selectors.lua | 500 |
1 files changed, 500 insertions, 0 deletions
--[[
Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

-- This module contains 'selectors' implementation: code to extract data
-- from Rspamd tasks and compose those together

--[[[
-- @module lua_selectors
-- This module contains 'selectors' implementation: code to extract data
-- from Rspamd tasks and compose those together.
-- A selector expression is a `:`-separated list of selectors; each selector
-- is an extraction function optionally followed by `.`-chained transforms.
-- Typical selector looks like this: files.lower.first.substring(1, 2):ip
--]]

local exports = {}
local logger = require 'rspamd_logger'
local fun = require 'fun'
local lua_util = require "lua_util"
local M = "lua_selectors"
-- Shared empty table used as a nil-safe default; it must never be mutated
local E = {}

-- Applies `fn` to every element of `list` (a plain array or a luafun
-- iterator) and returns a plain array, so that later transforms can index
-- the result or pass it to table.concat.
local function map_list(fn, list)
  return fun.totable(fun.map(fn, list))
end

-- Collects `extract(part)` for every MIME part of the task, skipping parts
-- for which `extract` returns nothing. Returns nil when nothing is collected.
local function collect_from_parts(task, extract)
  local collected = {}

  for _,p in ipairs(task:get_parts() or E) do
    local v = extract(p)
    if v then
      collected[#collected + 1] = v
    end
  end

  if #collected > 0 then
    return collected
  end

  return nil
end

-- Renders a hash object as a string using the requested encoding.
-- Returns nil for an unknown encoding name.
local function encode_hash(h, how)
  if how == 'hex' then
    return h:hex()
  elseif how == 'base32' then
    return h:base32()
  elseif how == 'base64' then
    return h:base64()
  end

  return nil
end

-- Copies the top level of a table; non-table values are returned as is.
local function shallowcopy(orig)
  if type(orig) ~= 'table' then
    return orig
  end

  local copy = {}
  for k, v in pairs(orig) do
    copy[k] = v
  end

  return copy
end

-- Extraction functions. Each entry declares the `type` tag of the value
-- produced by `get_value(task, args)`; `args` is the list of string
-- arguments given in parentheses in the selector expression.
local selectors = {
  -- Sender IP rendered as a string (hence the 'string' type tag)
  ['ip'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      local ip = task:get_ip()
      if ip and ip:is_valid() then return tostring(ip) end
      return nil
    end,
  },
  -- Lowercased address of the first sender returned by get_from(0)
  ['from'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      local from = task:get_from(0)
      local addr = ((from or E)[1] or E).addr
      if addr then
        return string.lower(addr)
      end
      return nil
    end,
  },
  -- Value of the `country` mempool variable
  ['country'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_mempool():get_variable('country') or nil
    end,
  },
  -- Value of the `asn` mempool variable
  ['asn'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_mempool():get_variable('asn') or nil
    end,
  },
  -- Authenticated user name
  ['user'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_user() or nil
    end,
  },
  -- Principal recipient
  -- NOTE(review): assumes get_principal_recipient() returns a plain string;
  -- confirm against the task API
  ['to'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_principal_recipient()
    end,
  },
  -- Message digest
  ['digest'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_digest()
    end,
  },
  -- Digests of all parts that have a filename
  ['attachments'] = {
    ['type'] = 'string_list',
    ['get_value'] = function(task)
      return collect_from_parts(task, function(p)
        if p:get_filename() then
          return p:get_digest()
        end
        return nil
      end)
    end,
  },
  -- Filenames of all parts that have one
  ['files'] = {
    ['type'] = 'string_list',
    ['get_value'] = function(task)
      return collect_from_parts(task, function(p)
        return p:get_filename()
      end)
    end,
  },
  -- HELO value
  ['helo'] = {
    ['type'] = 'string',
    ['get_value'] = function(task)
      return task:get_helo()
    end,
  },
  -- Full header data for the header named by the first argument
  ['header'] = {
    ['type'] = 'header_list',
    ['get_value'] = function(task, extra)
      -- The parser passes arguments as a list; a bare string is tolerated too
      local name = extra
      if type(extra) == 'table' then
        name = extra[1]
      end
      if not name then return nil end

      return task:get_header_full(name)
    end,
  },
  -- URLs found in the task
  ['url'] = {
    ['type'] = 'url_list',
    ['get_value'] = function(task)
      return task:get_urls()
    end,
  },
  -- Email addresses found in the task
  ['email'] = {
    ['type'] = 'url_list',
    ['get_value'] = function(task)
      return task:get_emails()
    end,
  }
}

-- Type of a single element for each list type understood by first/last
local element_type = {
  ['url_list'] = 'url',
  ['header_list'] = 'header',
  ['string_list'] = 'string',
}

-- Transform functions. `types` lists the accepted input type tags;
-- `process(inp, t, args)` must return the new value AND its type tag.
local transform_function = {
  -- Get hostname from url or a list of urls
  ['get_host'] = {
    ['types'] = {
      ['url_list'] = true,
      ['url'] = true
    },
    ['process'] = function(inp, t)
      if t == 'url_list' then
        return map_list(function(u) return u:get_host() end, inp),'string_list'
      end

      return inp:get_host(),'string'
    end
  },
  -- Get tld from url or a list of urls
  ['get_tld'] = {
    ['types'] = {
      ['url_list'] = true,
      ['url'] = true
    },
    ['process'] = function(inp, t)
      if t == 'url_list' then
        return map_list(function(u) return u:get_tld() end, inp),'string_list'
      end

      return inp:get_tld(),'string'
    end
  },
  -- Get address
  ['get_addr'] = {
    ['types'] = {
      ['email'] = true,
      ['string'] = true
    },
    ['process'] = function(inp, _)
      -- Built-in selectors already yield address strings; pass them through
      if type(inp) == 'string' then
        return inp,'string'
      end

      return inp:get_addr(),'string'
    end
  },
  -- Returns the lowercased string
  ['lower'] = {
    ['types'] = {
      ['string'] = true,
      ['string_list'] = true
    },
    ['process'] = function(inp, t)
      if t == 'string' then
        return inp:lower(),'string'
      end
      return map_list(string.lower, inp),'string_list'
    end
  },
  -- Returns the first element
  ['first'] = {
    ['types'] = {
      ['url_list'] = true,
      ['header_list'] = true,
      ['string_list'] = true
    },
    ['process'] = function(inp, t)
      return inp[1],element_type[t]
    end
  },
  -- Returns the last element
  ['last'] = {
    ['types'] = {
      ['url_list'] = true,
      ['header_list'] = true,
      ['string_list'] = true
    },
    ['process'] = function(inp, t)
      return inp[#inp],element_type[t]
    end
  },
  -- Joins strings into a single string using separator in the argument
  ['join'] = {
    ['types'] = {
      ['string_list'] = true
    },
    ['process'] = function(inp, _, args)
      -- totable() so that a lazy luafun iterator can be concatenated as well
      return table.concat(fun.totable(inp), (args or E)[1] or ''), 'string'
    end
  },
  -- Create a digest from string or a list of strings
  ['digest'] = {
    ['types'] = {
      ['string_list'] = true,
      ['string'] = true
    },
    ['process'] = function(inp, t, args)
      local hash = require 'rspamd_cryptobox_hash'
      local ht = (args or E)[1] or 'blake2'

      -- create_specific is a module-level constructor, so it is called with
      -- a dot: a colon call would pass the module table as the hash type
      if t == 'string_list' then
        return map_list(function(s)
          return hash.create_specific(ht):update(s)
        end, inp),'hash_list'
      end

      return hash.create_specific(ht):update(inp), 'hash'
    end
  },
  -- Encode hash to string (using hex encoding by default)
  ['encode'] = {
    ['types'] = {
      ['hash_list'] = true,
      ['hash'] = true
    },
    ['process'] = function(inp, t, args)
      local how = (args or E)[1] or 'hex'

      if t == 'hash_list' then
        return map_list(function(h)
          return encode_hash(h, how)
        end, inp),'string_list'
      end

      -- An unknown encoding yields nil, which aborts the pipeline
      return encode_hash(inp, how),'string'
    end
  },
  -- Extracts substring
  ['substring'] = {
    ['types'] = {
      ['string'] = true
    },
    ['process'] = function(inp, _, args)
      -- Parsed arguments are strings, so convert them explicitly
      local start_pos = tonumber((args or E)[1]) or 1
      local end_pos = tonumber((args or E)[2]) or -1

      return inp:sub(start_pos, end_pos), 'string'
    end
  },
}

-- Runs a single parsed selector against a task: extracts the initial value
-- and threads it through the transform pipeline, tracking the type tag.
-- Returns a string, a table of strings, or nil on any failure.
local function process_selector(task, sel)
  local value = sel.selector.get_value(task, sel.selector.args)
  if not value then return nil end

  local t = sel.selector.type

  for _,proc in ipairs(sel.processor_pipe) do
    if not proc.types[t] then
      logger.errx(task, 'cannot apply transform %s for type %s', proc.name, t)
      return nil
    end

    value, t = proc.process(value, t, proc.args or E)

    -- A transform that produced nothing terminates the pipeline
    if value == nil then return nil end
  end

  if t == 'string_list' then
    -- Convert to table as it might have a functional form
    return fun.totable(value)
  end

  if t ~= 'string' then
    logger.errx(task, 'transform pipeline has returned bad type: %s, string expected',
        tostring(t))
    return nil
  end

  return value
end

-- Builds the LPeg grammar for selector expressions. A successful match
-- yields ONE capture per `:`-separated selector; each capture is a list
-- whose first element is {name, args} for the extraction function and whose
-- remaining elements are {name, args} for every `.`-chained transform
-- (args is absent when no parentheses were given).
local function make_grammar()
  local l = require "lpeg"
  local spc = l.S(" \t\n")^0
  local atom = l.C((l.R("az") + l.R("AZ") + l.R("09") + l.S("_-"))^1)
  local dot = l.P(".")
  local obrace = "(" * spc
  local ebrace = spc * ")"
  local comma = spc * "," * spc
  local colon = ":"

  return l.P{
    "LIST";
    LIST = l.Ct(l.V("EXPR")) * (colon * l.Ct(l.V("EXPR")))^0,
    EXPR = l.V("FUNCTION") * (dot * l.V("PROCESSOR"))^0,
    PROCESSOR = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0),
    FUNCTION = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0),
    ARG_LIST = l.Ct((atom * comma^0)^0)
  }
end

local parser = make_grammar()

--[[[
-- @function lua_selectors.parse_selector(cfg, str)
-- Parses a selector expression into a list of selector descriptions, each
-- of the form {selector = {...}, processor_pipe = {...}}, suitable for
-- `process_selectors`. Returns nil if the expression cannot be parsed or
-- refers to an unknown selector or transform.
--]]
exports.parse_selector = function(cfg, str)
  if type(str) ~= 'string' then return nil end

  -- The grammar produces one capture per selector, so collect all of them
  local parsed = {parser:match(str)}
  if #parsed == 0 then return nil end

  local output = {}

  -- Output AST format is the following:
  -- table of individual selectors
  -- each selector: list of functions
  -- each function: function name + optional list of arguments
  for _,sel in ipairs(parsed) do
    local selector_tbl = sel[1]
    if not selector_tbl then
      logger.errx(cfg, 'no selector represented')
      return nil
    end

    local sel_name = selector_tbl[1]
    local prototype = selectors[sel_name]
    if not prototype then
      logger.errx(cfg, 'selector %s is unknown', sel_name)
      return nil
    end

    -- Copy the prototype so per-expression fields do not leak into the
    -- shared registry entry
    local res = {
      selector = shallowcopy(prototype),
      processor_pipe = {},
    }
    res.selector.name = sel_name
    res.selector.args = selector_tbl[2] or {}

    lua_util.debugm(M, cfg, 'processed selector %s, args: %s',
        res.selector.name, res.selector.args)

    -- Now process processors pipe (everything after the first element)
    for i = 2, #sel do
      local proc_tbl = sel[i]
      local proc_name = proc_tbl[1]

      if not transform_function[proc_name] then
        logger.errx(cfg, 'processor %s is unknown', proc_name)
        return nil
      end

      local processor = shallowcopy(transform_function[proc_name])
      processor.name = proc_name
      -- Transforms written without parentheses get an empty argument list
      processor.args = proc_tbl[2] or {}
      lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s',
          proc_name, res.selector.name, processor.args)
      table.insert(res.processor_pipe, processor)
    end

    table.insert(output, res)
  end

  return output
end

--[[[
-- @function lua_selectors.register_selector(cfg, name, selector)
-- Registers (or redefines, with a warning) an extraction function.
-- `selector` must be a table with `get_value` and `type` fields.
-- Returns true on success and false otherwise.
--]]
exports.register_selector = function(cfg, name, selector)
  if type(selector) == 'table' and selector.get_value and selector.type then
    if selectors[name] then
      logger.warnx(cfg, 'redefining selector %s', name)
    end
    selectors[name] = selector

    return true
  end

  logger.errx(cfg, 'bad selector %s', name)
  return false
end

--[[[
-- @function lua_selectors.register_transform(cfg, name, transform)
-- Registers (or redefines, with a warning) a transform function.
-- `transform` must be a table with `process` and `types` fields.
-- Returns true on success and false otherwise.
--]]
exports.register_transform = function(cfg, name, transform)
  if type(transform) == 'table' and transform.process and transform.types then
    if transform_function[name] then
      logger.warnx(cfg, 'redefining transform function %s', name)
    end
    transform_function[name] = transform

    return true
  end

  logger.errx(cfg, 'bad transform function %s', name)
  return false
end

--[[[
-- @function lua_selectors.process_selectors(task, selectors_pipe)
-- Evaluates every selector of a parsed expression against the task.
-- Returns the list of results in order, or nil if any selector yields nil.
--]]
exports.process_selectors = function(task, selectors_pipe)
  local ret = {}

  for _,sel in fun.iter(selectors_pipe) do
    local r = process_selector(task, sel)

    if r == nil then
      -- If any element is nil, then the whole selector is nil
      return nil
    end

    ret[#ret + 1] = r
  end

  return ret
end

return exports
\ No newline at end of file |