From: Vsevolod Stakhov Date: Mon, 19 Aug 2019 08:45:28 +0000 (+0100) Subject: [Rework] Reorganise selectors implementation X-Git-Tag: 2.0~395 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=58a5e0c7039af50b0607c5f4aee8648d1d560cdc;p=rspamd.git [Rework] Reorganise selectors implementation --- diff --git a/lualib/lua_selectors.lua b/lualib/lua_selectors.lua deleted file mode 100644 index c123ae637..000000000 --- a/lualib/lua_selectors.lua +++ /dev/null @@ -1,1212 +0,0 @@ ---[[ -Copyright (c) 2018, Vsevolod Stakhov - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -]]-- - --- This module contains 'selectors' implementation: code to extract data --- from Rspamd tasks and compose those together --- --- Read more at https://rspamd.com/doc/configuration/selectors.html - ---[[[ --- @module lua_selectors --- This module contains 'selectors' implementation: code to extract data --- from Rspamd tasks and compose those together. --- Typical selector looks like this: header(User).lower.substring(1, 2):ip ---]] - -local exports = { - maps = {} -- Defined for selectors maps, must be indexed by name -} - -local logger = require 'rspamd_logger' -local fun = require 'fun' -local lua_util = require "lua_util" -local ts = require("tableshape").types -local M = "selectors" -local E = {} - -local extractors = { - -- Plain id function - ['id'] = { - ['get_value'] = function(_, args) - if args[1] then - return args[1], 'string' - end - - return '','string' - end, - ['description'] = [[Return value from function's argument or an empty string, -For example, `id('Something')` returns a string 'Something']], - ['args_schema'] = {ts.string:is_optional()} - }, - -- Similar but for making lists - ['list'] = { - ['get_value'] = function(_, args) - if args[1] then - return fun.map(tostring, args), 'string_list' - end - - return {},'string_list' - end, - ['description'] = [[Return a list from function's arguments or an empty list, -For example, `list('foo', 'bar')` returns a list {'foo', 'bar'}]], - }, - -- Get source IP address - ['ip'] = { - ['get_value'] = function(task) - local ip = task:get_ip() - if ip and ip:is_valid() then return ip,'userdata' end - return nil - end, - ['description'] = [[Get source IP address]], - }, - -- Get MIME from - ['from'] = { - ['get_value'] = function(task, args) - local from = task:get_from(args[1] or 0) - if ((from or E)[1] or E).addr then - return from[1],'table' - end - return nil - end, - ['description'] = [[Get MIME or SMTP from (e.g. `from('smtp')` or `from('mime')`, -uses any type by default)]], - }, - ['rcpts'] = { - ['get_value'] = function(task, args) - local rcpts = task:get_recipients(args[1] or 0) - if ((rcpts or E)[1] or E).addr then - return rcpts,'table_list' - end - return nil - end, - ['description'] = [[Get MIME or SMTP rcpts (e.g. `rcpts('smtp')` or `rcpts('mime')`, -uses any type by default)]], - }, - -- Get country (ASN module must be executed first) - ['country'] = { - ['get_value'] = function(task) - local country = task:get_mempool():get_variable('country') - if not country then - return nil - else - return country,'string' - end - end, - ['description'] = [[Get country (ASN module must be executed first)]], - }, - -- Get ASN number - ['asn'] = { - ['type'] = 'string', - ['get_value'] = function(task) - local asn = task:get_mempool():get_variable('asn') - if not asn then - return nil - else - return asn,'string' - end - end, - ['description'] = [[Get AS number (ASN module must be executed first)]], - }, - -- Get authenticated username - ['user'] = { - ['get_value'] = function(task) - local auser = task:get_user() - if not auser then - return nil - else - return auser,'string' - end - end, - ['description'] = 'Get authenticated user name', - }, - -- Get principal recipient - ['to'] = { - ['get_value'] = function(task) - return task:get_principal_recipient(),'string' - end, - ['description'] = 'Get principal recipient', - }, - -- Get content digest - ['digest'] = { - ['get_value'] = function(task) - return task:get_digest(),'string' - end, - ['description'] = 'Get content digest', - }, - -- Get list of all attachments digests - ['attachments'] = { - ['get_value'] = function(task, args) - - local s - local parts = task:get_parts() or E - local digests = {} - - if #args > 0 then - local rspamd_cryptobox = require "rspamd_cryptobox_hash" - local encoding = args[1] or 'hex' - local ht = args[2] or 'blake2' - - for _,p in ipairs(parts) do - if p:get_filename() then - local h = rspamd_cryptobox.create_specific(ht, p:get_content('raw_parsed')) - if encoding == 'hex' then - s = h:hex() - elseif encoding == 'base32' then - s = h:base32() - elseif encoding == 'base64' then - s = h:base64() - end - table.insert(digests, s) - end - end - else - for _,p in ipairs(parts) do - if p:get_filename() then - table.insert(digests, p:get_digest()) - end - end - end - - if #digests > 0 then - return digests,'string_list' - end - - return nil - end, - ['description'] = [[Get list of all attachments digests. -The first optional argument is encoding (`hex`, `base32`, `base64`), -the second optional argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], - - ['args_schema'] = {ts.one_of{'hex', 'base32', 'base64'}:is_optional(), - ts.one_of{'blake2', 'sha256', 'sha1', 'sha512', 'md5'}:is_optional()} - - }, - -- Get all attachments files - ['files'] = { - ['get_value'] = function(task) - local parts = task:get_parts() or E - local files = {} - - for _,p in ipairs(parts) do - local fname = p:get_filename() - if fname then - table.insert(files, fname) - end - end - - if #files > 0 then - return files,'string_list' - end - - return nil - end, - ['description'] = 'Get all attachments files', - }, - -- Get languages for text parts - ['languages'] = { - ['get_value'] = function(task) - local text_parts = task:get_text_parts() or E - local languages = {} - - for _,p in ipairs(text_parts) do - local lang = p:get_language() - if lang then - table.insert(languages, lang) - end - end - - if #languages > 0 then - return languages,'string_list' - end - - return nil - end, - ['description'] = 'Get languages for text parts', - }, - -- Get helo value - ['helo'] = { - ['get_value'] = function(task) - return task:get_helo(),'string' - end, - ['description'] = 'Get helo value', - }, - -- Get header with the name that is expected as an argument. Returns list of - -- headers with this name - ['header'] = { - ['get_value'] = function(task, args) - local strong = false - if args[2] then - if args[2]:match('strong') then - strong = true - end - - if args[2]:match('full') then - return task:get_header_full(args[1], strong),'table_list' - end - - return task:get_header(args[1], strong),'string' - else - return task:get_header(args[1]),'string' - end - end, - ['description'] = [[Get header with the name that is expected as an argument. -The optional second argument accepts list of flags: - - `full`: returns all headers with this name with all data (like task:get_header_full()) - - `strong`: use case sensitive match when matching header's name]], - ['args_schema'] = {ts.string, - (ts.pattern("strong") + ts.pattern("full")):is_optional()} - }, - -- Get list of received headers (returns list of tables) - ['received'] = { - ['get_value'] = function(task, args) - local rh = task:get_received_headers() - if args[1] and rh then - return fun.map(function(r) return r[args[1]] end, rh), 'string_list' - end - - return rh,'table_list' - end, - ['description'] = [[Get list of received headers. -If no arguments specified, returns list of tables. Otherwise, selects a specific element, -e.g. `by_hostname`]], - }, - -- Get all urls - ['urls'] = { - ['get_value'] = function(task, args) - local urls = task:get_urls() - if args[1] and urls then - return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list' - end - return urls,'userdata_list' - end, - ['description'] = [[Get list of all urls. -If no arguments specified, returns list of url objects. Otherwise, calls a specific method, -e.g. `get_tld`]], - }, - -- Get all emails - ['emails'] = { - ['get_value'] = function(task, args) - local urls = task:get_emails() - if args[1] and urls then - return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list' - end - return urls,'userdata_list' - end, - ['description'] = [[Get list of all emails. -If no arguments specified, returns list of url objects. Otherwise, calls a specific method, -e.g. `get_user`]], - }, - -- Get specific pool var. The first argument must be variable name, - -- the second argument is optional and defines the type (string by default) - ['pool_var'] = { - ['get_value'] = function(task, args) - local type = args[2] or 'string' - return task:get_mempool():get_variable(args[1], type),(type) - end, - ['description'] = [[Get specific pool var. The first argument must be variable name, -the second argument is optional and defines the type (string by default)]], - ['args_schema'] = {ts.string, ts.string:is_optional()} - }, - -- Get specific HTTP request header. The first argument must be header name. - ['request_header'] = { - ['get_value'] = function(task, args) - local hdr = task:get_request_header(args[1]) - if hdr then - return tostring(hdr),'string' - end - - return nil - end, - ['description'] = [[Get specific HTTP request header. -The first argument must be header name.]], - ['args_schema'] = {ts.string} - }, - -- Get task date, optionally formatted - ['time'] = { - ['get_value'] = function(task, args) - local what = args[1] or 'message' - local dt = task:get_date{format = what, gmt = true} - - if dt then - if args[2] then - -- Should be in format !xxx, as dt is in GMT - return os.date(args[2], dt),'string' - end - - return tostring(dt),'string' - end - - return nil - end, - ['description'] = [[Get task timestamp. The first argument is type: - - `connect`: connection timestamp (default) - - `message`: timestamp as defined by `Date` header - - The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]], - ['args_schema'] = {ts.one_of{'connect', 'message'}:is_optional(), - ts.string:is_optional()} - }, - -- Get text words from a message - ['words'] = { - ['get_value'] = function(task, args) - local how = args[1] or 'stem' - local tp = task:get_text_parts() - - if tp then - local rtype = 'string_list' - if how == 'full' then - rtype = 'table_list' - end - - return lua_util.flatten( - fun.map(function(p) - return p:get_words(how) - end, tp)), rtype - end - - return nil - end, - ['description'] = [[Get words from text parts - - `stem`: stemmed words (default) - - `raw`: raw words - - `norm`: normalised words (lowercased) - - `full`: list of tables - ]], - ['args_schema'] = { ts.one_of { 'stem', 'raw', 'norm', 'full' }:is_optional()}, - }, -} - -local function pure_type(ltype) - return ltype:match('^(.*)_list$') -end - -local transform_function = { - -- Returns the lowercased string - ['lower'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, _) - return inp:lower(),'string' - end, - ['description'] = 'Returns the lowercased string', - }, - -- Returns the first element - ['first'] = { - ['types'] = { - ['list'] = true, - }, - ['process'] = function(inp, t) - return fun.head(inp),pure_type(t) - end, - ['description'] = 'Returns the first element', - }, - -- Returns the last element - ['last'] = { - ['types'] = { - ['list'] = true, - }, - ['process'] = function(inp, t) - return fun.nth(#inp, inp),pure_type(t) - end, - ['description'] = 'Returns the last element', - }, - -- Returns the nth element - ['nth'] = { - ['types'] = { - ['list'] = true, - }, - ['process'] = function(inp, t, args) - return fun.nth(args[1] or 1, inp),pure_type(t) - end, - ['description'] = 'Returns the nth element', - ['args_schema'] = {ts.number + ts.string / tonumber} - }, - ['take_n'] = { - ['types'] = { - ['list'] = true, - }, - ['process'] = function(inp, t, args) - return fun.take_n(args[1] or 1, inp),t - end, - ['description'] = 'Returns the n first elements', - ['args_schema'] = {ts.number + ts.string / tonumber} - }, - ['drop_n'] = { - ['types'] = { - ['list'] = true, - }, - ['process'] = function(inp, t, args) - return fun.drop_n(args[1] or 1, inp),t - end, - ['description'] = 'Returns list without the first n elements', - ['args_schema'] = {ts.number + ts.string / tonumber} - }, - -- Joins strings into a single string using separator in the argument - ['join'] = { - ['types'] = { - ['string_list'] = true - }, - ['process'] = function(inp, _, args) - return table.concat(fun.totable(inp), args[1] or ''), 'string' - end, - ['description'] = 'Joins strings into a single string using separator in the argument', - ['args_schema'] = {ts.string:is_optional()} - }, - -- Sort strings - ['sort'] = { - ['types'] = { - ['list'] = true - }, - ['process'] = function(inp, t, _) - table.sort(inp) - return inp, t - end, - ['description'] = 'Sort strings lexicographically', - }, - -- Return unique elements based on hashing (can work without sorting) - ['uniq'] = { - ['types'] = { - ['list'] = true - }, - ['process'] = function(inp, t, _) - local tmp = {} - fun.each(function(val) - tmp[val] = true - end, inp) - - return fun.map(function(k, _) return k end, tmp), t - end, - ['description'] = 'Returns a list of unique elements (using a hash table)', - }, - -- Create a digest from string or a list of strings - ['digest'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'hash', - ['process'] = function(inp, _, args) - local hash = require 'rspamd_cryptobox_hash' - local encoding = args[1] or 'hex' - local ht = args[2] or 'blake2' - local h = hash:create_specific(ht):update(inp) - local s - - if encoding == 'hex' then - s = h:hex() - elseif encoding == 'base32' then - s = h:base32() - elseif encoding == 'base64' then - s = h:base64() - end - - return s,'string' - end, - ['description'] = [[Create a digest from a string. -The first argument is encoding (`hex`, `base32`, `base64`), -the second argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], - ['args_schema'] = {ts.one_of{'hex', 'base32', 'base64'}:is_optional(), - ts.one_of{'blake2', 'sha256', 'sha1', 'sha512', 'md5'}:is_optional()} - }, - -- Extracts substring - ['substring'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - local start_pos = args[1] or 1 - local end_pos = args[2] or -1 - - return inp:sub(start_pos, end_pos), 'string' - end, - ['description'] = 'Extracts substring; the first argument is start, the second is the last (like in Lua)', - ['args_schema'] = {(ts.number + ts.string / tonumber):is_optional(), - (ts.number + ts.string / tonumber):is_optional()} - }, - -- Prepends a string or a strings list - ['prepend'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - local prepend = table.concat(args, '') - - return prepend .. inp, 'string' - end, - ['description'] = 'Prepends a string or a strings list', - }, - -- Appends a string or a strings list - ['append'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - local append = table.concat(args, '') - - return inp .. append, 'string' - end, - ['description'] = 'Appends a string or a strings list', - }, - -- Regexp matching - ['regexp'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - local rspamd_regexp = require "rspamd_regexp" - - local re = rspamd_regexp.create_cached(args[1]) - - if not re then - logger.errx('invalid regexp: %s', args[1]) - return nil - end - - local res = re:search(inp, false, true) - - if res then - if #res == 1 then - return res[1],'string' - end - - return res,'string_list' - end - - return nil - end, - ['description'] = 'Regexp matching', - ['args_schema'] = {ts.string} - }, - -- Returns a value if it exists in some map (or acts like a `filter` function) - ['filter_map'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, t, args) - local map = exports.maps[args[1]] - - if not map then - logger.errx('invalid map name: %s', args[1]) - return nil - end - - local res = map:get_key(inp) - - if res then - return inp,t - end - - return nil - end, - ['description'] = 'Returns a value if it exists in some map (or acts like a `filter` function)', - ['args_schema'] = {ts.string} - }, - -- Returns a value from some map corresponding to some key (or acts like a `map` function) - ['apply_map'] = { - ['types'] = { - ['string'] = true - }, - ['map_type'] = 'string', - ['process'] = function(inp, t, args) - local map = exports.maps[args[1]] - - if not map then - logger.errx('invalid map name: %s', args[1]) - return nil - end - - local res = map:get_key(inp) - - if res then - return res,t - end - - return nil - end, - ['description'] = 'Returns a value from some map corresponding to some key (or acts like a `map` function)', - ['args_schema'] = {ts.string} - }, - -- Drops input value and return values from function's arguments or an empty string - ['id'] = { - ['types'] = { - ['string'] = true, - ['list'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(_, _, args) - if args[1] and args[2] then - return fun.map(tostring, args),'string_list' - elseif args[1] then - return args[1],'string' - end - - return '','string' - end, - ['description'] = 'Drops input value and return values from function\'s arguments or an empty string', - ['args_schema'] = (ts.string + ts.array_of(ts.string)):is_optional() - }, - ['equal'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - if inp == args[1] then - return inp,'string' - end - - return nil - end, - ['description'] = [[Boolean function equal. -Returns either nil or its argument if input is equal to argument]], - ['args_schema'] = {ts.string} - }, - -- Boolean function in, returns either nil or its input if input is in args list - ['in'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, t, args) - for _,a in ipairs(args) do if a == inp then return inp,t end end - return nil - end, - ['description'] = [[Boolean function in. -Returns either nil or its input if input is in args list]], - ['args_schema'] = ts.array_of(ts.string) - }, - ['not_in'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, t, args) - for _,a in ipairs(args) do if a == inp then return nil end end - return inp,t - end, - ['description'] = [[Boolean function not in. -Returns either nil or its input if input is not in args list]], - ['args_schema'] = ts.array_of(ts.string) - }, - ['inverse'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - if inp then - return nil - else - return (args[1] or 'true'),'string' - end - end, - ['description'] = [[Inverses input. -Empty string comes the first argument or 'true', non-empty string comes nil]], - ['args_schema'] = {ts.string:is_optional()} - }, - ['ipmask'] = { - ['types'] = { - ['string'] = true, - }, - ['map_type'] = 'string', - ['process'] = function(inp, _, args) - local rspamd_ip = require "rspamd_ip" - -- Non optimal: convert string to an IP address - local ip = rspamd_ip.from_string(inp) - - if not ip or not ip:is_valid() then - lua_util.debugm(M, "cannot convert %s to IP", inp) - return nil - end - - if ip:get_version() == 4 then - local mask = tonumber(args[1]) - - return ip:apply_mask(mask):to_string(),'string' - else - -- IPv6 takes the second argument or the first one... - local mask_str = args[2] or args[1] - local mask = tonumber(mask_str) - - return ip:apply_mask(mask):to_string(),'string' - end - end, - ['description'] = 'Applies mask to IP address.' .. - ' The first argument is the mask for IPv4 addresses, the second is the mask for IPv6 addresses.', - ['args_schema'] = {(ts.number + ts.string / tonumber), - (ts.number + ts.string / tonumber):is_optional()} - }, -} - -transform_function.match = transform_function.regexp - -local function process_selector(task, sel) - local function allowed_type(t) - if t == 'string' or t == 'text' or t == 'string_list' or t == 'text_list' then - return true - end - - return false - end - - local function list_type(t) - return pure_type(t) - end - - local function implicit_tostring(t, ud_or_table) - if t == 'table' then - -- Table (very special) - if ud_or_table.value then - return ud_or_table.value,'string' - elseif ud_or_table.addr then - return ud_or_table.addr,'string' - end - - return logger.slog("%s", ud_or_table),'string' - else - return tostring(ud_or_table),'string' - end - end - - local input,etype = sel.selector.get_value(task, sel.selector.args) - - if not input then - lua_util.debugm(M, task, 'no value extracted for %s', sel.selector.name) - return nil - end - - lua_util.debugm(M, task, 'extracted %s, type %s', - sel.selector.name, etype) - - local pipe = sel.processor_pipe or E - - if etype:match('^userdata') or etype:match('^table') then - -- Apply userdata conversion first - local first_elt = pipe[1] - - if first_elt and first_elt.method then - -- Explicit conversion - local meth = first_elt - - if meth.types[etype] then - lua_util.debugm(M, task, 'apply method `%s` to %s', - meth.name, etype) - input,etype = meth.process(input, etype) - else - local pt = pure_type(etype) - - if meth.types[pt] then - lua_util.debugm(M, task, 'map method `%s` to list of %s', - meth.name, pt) - -- Map method to a list of inputs, excluding empty elements - input = fun.filter(function(map_elt) return map_elt end, - fun.map(function(list_elt) - local ret, _ = meth.process(list_elt, pt) - return ret - end, input)) - etype = 'string_list' - end - end - -- Remove method from the pipeline - pipe = fun.drop_n(1, pipe) - else - -- Implicit conversion - - local pt = pure_type(etype) - - if not pt then - lua_util.debugm(M, task, 'apply implicit conversion %s->string', etype) - input = implicit_tostring(etype, input) - etype = 'string' - else - lua_util.debugm(M, task, 'apply implicit map %s->string', pt) - input = fun.filter(function(map_elt) return map_elt end, - fun.map(function(list_elt) - local ret = implicit_tostring(pt, list_elt) - return ret - end, input)) - etype = 'string_list' - end - end - end - - -- Now we fold elements using left fold - local function fold_function(acc, x) - if acc == nil or acc[1] == nil then - lua_util.debugm(M, task, 'do not apply %s, accumulator is nil', x.name) - return nil - end - - local value = acc[1] - local t = acc[2] - - if not x.types[t] then - local pt = pure_type(t) - - if pt and x.types['list'] then - -- Generic list processor - lua_util.debugm(M, task, 'apply list function `%s` to %s', x.name, t) - return {x.process(value, t, x.args)} - elseif pt and x.map_type and x.types[pt] then - local map_type = x.map_type .. '_list' - lua_util.debugm(M, task, 'map `%s` to list of %s resulting %s', - x.name, pt, map_type) - -- Apply map, filtering empty values - return { - fun.filter(function(map_elt) return map_elt end, - fun.map(function(list_elt) - if not list_elt then return nil end - local ret, _ = x.process(list_elt, pt, x.args) - return ret - end, value)), - map_type -- Returned type - } - end - logger.errx(task, 'cannot apply transform %s for type %s', x.name, t) - return nil - end - - lua_util.debugm(M, task, 'apply %s to %s', x.name, t) - return {x.process(value, t, x.args)} - end - - local res = fun.foldl(fold_function, - {input, etype}, - pipe) - - if not res or not res[1] then return nil end -- Pipeline failed - - if not allowed_type(res[2]) then - - -- Search for implicit conversion - local pt = pure_type(res[2]) - - if pt then - lua_util.debugm(M, task, 'apply implicit map %s->string_list', pt) - res[1] = fun.map(function(e) return implicit_tostring(pt, e) end, res[1]) - res[2] = 'string_list' - else - res[1] = implicit_tostring(res[2], res[1]) - res[2] = 'string' - end - end - - if list_type(res[2]) then - -- Convert to table as it might have a functional form - res[1] = fun.totable(res[1]) - end - - lua_util.debugm(M, task, 'final selector type: %s, value: %s', res[2], res[1]) - - return res[1] -end - -local function make_grammar() - local l = require "lpeg" - local spc = l.S(" \t\n")^0 - local atom = l.C((l.R("az") + l.R("AZ") + l.R("09") + l.S("_-"))^1) - local singlequoted_string = l.P "'" * l.C(((1 - l.S "'\r\n\f\\") + (l.P'\\' * 1))^0) * "'" - local doublequoted_string = l.P '"' * l.C(((1 - l.S'"\r\n\f\\') + (l.P'\\' * 1))^0) * '"' - local argument = atom + singlequoted_string + doublequoted_string - local dot = l.P(".") - local semicolon = l.P(":") - local obrace = "(" * spc - local ebrace = spc * ")" - local comma = spc * "," * spc - local sel_separator = spc * l.S";*" * spc - - return l.P{ - "LIST"; - LIST = l.Ct(l.V("EXPR")) * (sel_separator * l.Ct(l.V("EXPR")))^0, - EXPR = l.V("FUNCTION") * (semicolon * l.V("METHOD"))^-1 * (dot * l.V("PROCESSOR"))^0, - PROCESSOR = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), - FUNCTION = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), - METHOD = l.Ct(atom / function(e) return '__' .. e end * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), - ARG_LIST = l.Ct((argument * comma^0)^0) - } -end - -local parser = make_grammar() - ---[[[ --- @function lua_selectors.parse_selector(cfg, str) ---]] -exports.parse_selector = function(cfg, str) - local parsed = {parser:match(str)} - local output = {} - - if not parsed or not parsed[1] then return nil end - - local function check_args(name, schema, args) - if schema then - if getmetatable(schema) then - -- Schema covers all arguments - local res,err = schema:transform(args) - if not res then - logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err) - return false - else - for i,elt in ipairs(res) do - args[i] = elt - end - end - else - for i,selt in ipairs(schema) do - local res,err = selt:transform(args[i]) - - if err then - logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err) - return false - else - args[i] = res - end - end - end - end - - return true - end - - -- Output AST format is the following: - -- table of individual selectors - -- each selector: list of functions - -- each function: function name + optional list of arguments - for _,sel in ipairs(parsed) do - local res = { - selector = {}, - processor_pipe = {}, - } - - local selector_tbl = sel[1] - if not selector_tbl then - logger.errx(cfg, 'no selector represented') - return nil - end - if not extractors[selector_tbl[1]] then - logger.errx(cfg, 'selector %s is unknown', selector_tbl[1]) - return nil - end - - res.selector = lua_util.shallowcopy(extractors[selector_tbl[1]]) - res.selector.name = selector_tbl[1] - res.selector.args = selector_tbl[2] or E - - if not check_args(res.selector.name, - res.selector.args_schema, - res.selector.args) then - return nil - end - - lua_util.debugm(M, cfg, 'processed selector %s, args: %s', - res.selector.name, res.selector.args) - - local pipeline_error = false - -- Now process processors pipe - fun.each(function(proc_tbl) - local proc_name = proc_tbl[1] - - if proc_name:match('^__') then - -- Special case - method - local method_name = proc_name:match('^__(.*)$') - -- Check array indexing... - if tonumber(method_name) then - method_name = tonumber(method_name) - end - local processor = { - name = tostring(method_name), - method = true, - args = proc_tbl[2] or E, - types = { - userdata = true, - table = true, - }, - map_type = 'string', - process = function(inp, t, args) - if t == 'userdata' then - return inp[method_name](inp, args),'string' - else - -- Table - return inp[method_name],'string' - end - end, - } - lua_util.debugm(M, cfg, 'attached method %s to selector %s, args: %s', - proc_name, res.selector.name, processor.args) - table.insert(res.processor_pipe, processor) - else - - if not transform_function[proc_name] then - logger.errx(cfg, 'processor %s is unknown', proc_name) - pipeline_error = proc_name - return nil - end - local processor = lua_util.shallowcopy(transform_function[proc_name]) - processor.name = proc_name - processor.args = proc_tbl[2] or E - - if not check_args(processor.name, processor.args_schema, processor.args) then - pipeline_error = 'args schema for ' .. proc_name - return nil - end - - lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s', - proc_name, res.selector.name, processor.args) - table.insert(res.processor_pipe, processor) - end - end, fun.tail(sel)) - - if pipeline_error then - logger.errx(cfg, 'unknown or invalid processor used: "%s", exiting', pipeline_error) - return nil - end - - table.insert(output, res) - end - - return output -end - ---[[[ --- @function lua_selectors.register_extractor(cfg, name, selector) ---]] -exports.register_extractor = function(cfg, name, selector) - if selector.get_value then - if extractors[name] then - logger.warnx(cfg, 'redefining selector %s', name) - end - extractors[name] = selector - - return true - end - - logger.errx(cfg, 'bad selector %s', name) - return false -end - ---[[[ --- @function lua_selectors.register_transform(cfg, name, transform) ---]] -exports.register_transform = function(cfg, name, transform) - if transform.process and transform.types then - if transform_function[name] then - logger.warnx(cfg, 'redefining transform function %s', name) - end - transform_function[name] = transform - - return true - end - - logger.errx(cfg, 'bad transform function %s', name) - return false -end - ---[[[ --- @function lua_selectors.process_selectors(task, selectors_pipe) ---]] -exports.process_selectors = function(task, selectors_pipe) - local ret = {} - - for _,sel in ipairs(selectors_pipe) do - local r = process_selector(task, sel) - - -- If any element is nil, then the whole selector is nil - if not r then return nil end - table.insert(ret, r) - end - - return ret -end - ---[[[ --- @function lua_selectors.combine_selectors(task, selectors, delimiter) ---]] -exports.combine_selectors = function(_, selectors, delimiter) - if not delimiter then delimiter = '' end - - if not selectors then return nil end - - local all_strings = fun.all(function(s) return type(s) == 'string' end, selectors) - - if all_strings then - return table.concat(selectors, delimiter) - else - -- We need to do a spill on each table selector - -- e.g. s:tbl:s -> s:telt1:s + s:telt2:s ... - local tbl = {} - local res = {} - - for i,s in ipairs(selectors) do - if type(s) == 'string' then - rawset(tbl, i, fun.duplicate(s)) - elseif type(s) == 'userdata' then - rawset(tbl, i, fun.duplicate(tostring(s))) - else - rawset(tbl, i, s) - end - end - - fun.each(function(...) - table.insert(res, table.concat({...}, delimiter)) - end, fun.zip(lua_util.unpack(tbl))) - - return res - end -end - ---[[[ --- @function lua_selectors.create_closure(cfg, selector_str, delimiter='') ---]] -exports.create_selector_closure = function(cfg, selector_str, delimiter) - local selector = exports.parse_selector(cfg, selector_str) - - if not selector then - return nil - end - - return function(task) - local res = exports.process_selectors(task, selector) - - if res then - return exports.combine_selectors(nil, res, delimiter) - end - - return nil - end -end - -local function display_selectors(tbl) - return fun.tomap(fun.map(function(k,v) - return k, fun.tomap(fun.filter(function(kk, vv) - return type(vv) ~= 'function' - end, v)) - end, tbl)) -end - -exports.list_extractors = function() - return display_selectors(extractors) -end - -exports.list_transforms = function() - return display_selectors(transform_function) -end - -return exports diff --git a/lualib/lua_selectors/extractors.lua b/lualib/lua_selectors/extractors.lua new file mode 100644 index 000000000..a3fda891d --- /dev/null +++ b/lualib/lua_selectors/extractors.lua @@ -0,0 +1,374 @@ +--[[ +Copyright (c) 2019, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require 'fun' +local lua_util = require "lua_util" +local ts = require("tableshape").types +local E = {} + +local extractors = { + -- Plain id function + ['id'] = { + ['get_value'] = function(_, args) + if args[1] then + return args[1], 'string' + end + + return '','string' + end, + ['description'] = [[Return value from function's argument or an empty string, +For example, `id('Something')` returns a string 'Something']], + ['args_schema'] = {ts.string:is_optional()} + }, + -- Similar but for making lists + ['list'] = { + ['get_value'] = function(_, args) + if args[1] then + return fun.map(tostring, args), 'string_list' + end + + return {},'string_list' + end, + ['description'] = [[Return a list from function's arguments or an empty list, +For example, `list('foo', 'bar')` returns a list {'foo', 'bar'}]], + }, + -- Get source IP address + ['ip'] = { + ['get_value'] = function(task) + local ip = task:get_ip() + if ip and ip:is_valid() then return ip,'userdata' end + return nil + end, + ['description'] = [[Get source IP address]], + }, + -- Get MIME from + ['from'] = { + ['get_value'] = function(task, args) + local from = task:get_from(args[1] or 0) + if ((from or E)[1] or E).addr then + return from[1],'table' + end + return nil + end, + ['description'] = [[Get MIME or SMTP from (e.g. `from('smtp')` or `from('mime')`, +uses any type by default)]], + }, + ['rcpts'] = { + ['get_value'] = function(task, args) + local rcpts = task:get_recipients(args[1] or 0) + if ((rcpts or E)[1] or E).addr then + return rcpts,'table_list' + end + return nil + end, + ['description'] = [[Get MIME or SMTP rcpts (e.g. `rcpts('smtp')` or `rcpts('mime')`, +uses any type by default)]], + }, + -- Get country (ASN module must be executed first) + ['country'] = { + ['get_value'] = function(task) + local country = task:get_mempool():get_variable('country') + if not country then + return nil + else + return country,'string' + end + end, + ['description'] = [[Get country (ASN module must be executed first)]], + }, + -- Get ASN number + ['asn'] = { + ['type'] = 'string', + ['get_value'] = function(task) + local asn = task:get_mempool():get_variable('asn') + if not asn then + return nil + else + return asn,'string' + end + end, + ['description'] = [[Get AS number (ASN module must be executed first)]], + }, + -- Get authenticated username + ['user'] = { + ['get_value'] = function(task) + local auser = task:get_user() + if not auser then + return nil + else + return auser,'string' + end + end, + ['description'] = 'Get authenticated user name', + }, + -- Get principal recipient + ['to'] = { + ['get_value'] = function(task) + return task:get_principal_recipient(),'string' + end, + ['description'] = 'Get principal recipient', + }, + -- Get content digest + ['digest'] = { + ['get_value'] = function(task) + return task:get_digest(),'string' + end, + ['description'] = 'Get content digest', + }, + -- Get list of all attachments digests + ['attachments'] = { + ['get_value'] = function(task, args) + + local s + local parts = task:get_parts() or E + local digests = {} + + if #args > 0 then + local rspamd_cryptobox = require "rspamd_cryptobox_hash" + local encoding = args[1] or 'hex' + local ht = args[2] or 'blake2' + + for _,p in ipairs(parts) do + if p:get_filename() then + local h = rspamd_cryptobox.create_specific(ht, p:get_content('raw_parsed')) + if encoding == 'hex' then + s = h:hex() + elseif encoding == 'base32' then + s = h:base32() + elseif encoding == 'base64' then + s = h:base64() + end + table.insert(digests, s) + end + end + else + for _,p in ipairs(parts) do + if p:get_filename() then + table.insert(digests, p:get_digest()) + end + end + end + + if #digests > 0 then + return digests,'string_list' + end + + return nil + end, + ['description'] = [[Get list of all attachments digests. +The first optional argument is encoding (`hex`, `base32`, `base64`), +the second optional argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], + + ['args_schema'] = {ts.one_of{'hex', 'base32', 'base64'}:is_optional(), + ts.one_of{'blake2', 'sha256', 'sha1', 'sha512', 'md5'}:is_optional()} + + }, + -- Get all attachments files + ['files'] = { + ['get_value'] = function(task) + local parts = task:get_parts() or E + local files = {} + + for _,p in ipairs(parts) do + local fname = p:get_filename() + if fname then + table.insert(files, fname) + end + end + + if #files > 0 then + return files,'string_list' + end + + return nil + end, + ['description'] = 'Get all attachments files', + }, + -- Get languages for text parts + ['languages'] = { + ['get_value'] = function(task) + local text_parts = task:get_text_parts() or E + local languages = {} + + for _,p in ipairs(text_parts) do + local lang = p:get_language() + if lang then + table.insert(languages, lang) + end + end + + if #languages > 0 then + return languages,'string_list' + end + + return nil + end, + ['description'] = 'Get languages for text parts', + }, + -- Get helo value + ['helo'] = { + ['get_value'] = function(task) + return task:get_helo(),'string' + end, + ['description'] = 'Get helo value', + }, + -- Get header with the name that is expected as an argument. Returns list of + -- headers with this name + ['header'] = { + ['get_value'] = function(task, args) + local strong = false + if args[2] then + if args[2]:match('strong') then + strong = true + end + + if args[2]:match('full') then + return task:get_header_full(args[1], strong),'table_list' + end + + return task:get_header(args[1], strong),'string' + else + return task:get_header(args[1]),'string' + end + end, + ['description'] = [[Get header with the name that is expected as an argument. +The optional second argument accepts list of flags: + - `full`: returns all headers with this name with all data (like task:get_header_full()) + - `strong`: use case sensitive match when matching header's name]], + ['args_schema'] = {ts.string, + (ts.pattern("strong") + ts.pattern("full")):is_optional()} + }, + -- Get list of received headers (returns list of tables) + ['received'] = { + ['get_value'] = function(task, args) + local rh = task:get_received_headers() + if args[1] and rh then + return fun.map(function(r) return r[args[1]] end, rh), 'string_list' + end + + return rh,'table_list' + end, + ['description'] = [[Get list of received headers. +If no arguments specified, returns list of tables. Otherwise, selects a specific element, +e.g. `by_hostname`]], + }, + -- Get all urls + ['urls'] = { + ['get_value'] = function(task, args) + local urls = task:get_urls() + if args[1] and urls then + return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list' + end + return urls,'userdata_list' + end, + ['description'] = [[Get list of all urls. +If no arguments specified, returns list of url objects. Otherwise, calls a specific method, +e.g. `get_tld`]], + }, + -- Get all emails + ['emails'] = { + ['get_value'] = function(task, args) + local urls = task:get_emails() + if args[1] and urls then + return fun.map(function(r) return r[args[1]](r) end, urls), 'string_list' + end + return urls,'userdata_list' + end, + ['description'] = [[Get list of all emails. +If no arguments specified, returns list of url objects. Otherwise, calls a specific method, +e.g. `get_user`]], + }, + -- Get specific pool var. The first argument must be variable name, + -- the second argument is optional and defines the type (string by default) + ['pool_var'] = { + ['get_value'] = function(task, args) + local type = args[2] or 'string' + return task:get_mempool():get_variable(args[1], type),(type) + end, + ['description'] = [[Get specific pool var. The first argument must be variable name, +the second argument is optional and defines the type (string by default)]], + ['args_schema'] = {ts.string, ts.string:is_optional()} + }, + -- Get specific HTTP request header. The first argument must be header name. + ['request_header'] = { + ['get_value'] = function(task, args) + local hdr = task:get_request_header(args[1]) + if hdr then + return tostring(hdr),'string' + end + + return nil + end, + ['description'] = [[Get specific HTTP request header. +The first argument must be header name.]], + ['args_schema'] = {ts.string} + }, + -- Get task date, optionally formatted + ['time'] = { + ['get_value'] = function(task, args) + local what = args[1] or 'message' + local dt = task:get_date{format = what, gmt = true} + + if dt then + if args[2] then + -- Should be in format !xxx, as dt is in GMT + return os.date(args[2], dt),'string' + end + + return tostring(dt),'string' + end + + return nil + end, + ['description'] = [[Get task timestamp. The first argument is type: + - `connect`: connection timestamp (default) + - `message`: timestamp as defined by `Date` header + + The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]], + ['args_schema'] = {ts.one_of{'connect', 'message'}:is_optional(), + ts.string:is_optional()} + }, + -- Get text words from a message + ['words'] = { + ['get_value'] = function(task, args) + local how = args[1] or 'stem' + local tp = task:get_text_parts() + + if tp then + local rtype = 'string_list' + if how == 'full' then + rtype = 'table_list' + end + + return lua_util.flatten( + fun.map(function(p) + return p:get_words(how) + end, tp)), rtype + end + + return nil + end, + ['description'] = [[Get words from text parts + - `stem`: stemmed words (default) + - `raw`: raw words + - `norm`: normalised words (lowercased) + - `full`: list of tables + ]], + ['args_schema'] = { ts.one_of { 'stem', 'raw', 'norm', 'full' }:is_optional()}, + }, +} + +return extractors \ No newline at end of file diff --git a/lualib/lua_selectors/init.lua b/lualib/lua_selectors/init.lua new file mode 100644 index 000000000..356a05509 --- /dev/null +++ b/lualib/lua_selectors/init.lua @@ -0,0 +1,496 @@ +--[[ +Copyright (c) 2018, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- This module contains 'selectors' implementation: code to extract data +-- from Rspamd tasks and compose those together +-- +-- Read more at https://rspamd.com/doc/configuration/selectors.html + +--[[[ +-- @module lua_selectors +-- This module contains 'selectors' implementation: code to extract data +-- from Rspamd tasks and compose those together. +-- Typical selector looks like this: header(User).lower.substring(1, 2):ip +--]] + +local exports = { + maps = require "lua_selectors/maps" +} + +local logger = require 'rspamd_logger' +local fun = require 'fun' +local lua_util = require "lua_util" +local M = "selectors" +local E = {} + +local extractors = require "lua_scanners/extractors" +local transform_function = require "lua_scanners/transforms" + +local function pure_type(ltype) + return ltype:match('^(.*)_list$') +end + +local function process_selector(task, sel) + local function allowed_type(t) + if t == 'string' or t == 'text' or t == 'string_list' or t == 'text_list' then + return true + end + + return false + end + + local function list_type(t) + return pure_type(t) + end + + local function implicit_tostring(t, ud_or_table) + if t == 'table' then + -- Table (very special) + if ud_or_table.value then + return ud_or_table.value,'string' + elseif ud_or_table.addr then + return ud_or_table.addr,'string' + end + + return logger.slog("%s", ud_or_table),'string' + else + return tostring(ud_or_table),'string' + end + end + + local input,etype = sel.selector.get_value(task, sel.selector.args) + + if not input then + lua_util.debugm(M, task, 'no value extracted for %s', sel.selector.name) + return nil + end + + lua_util.debugm(M, task, 'extracted %s, type %s', + sel.selector.name, etype) + + local pipe = sel.processor_pipe or E + + if etype:match('^userdata') or etype:match('^table') then + -- Apply userdata conversion first + local first_elt = pipe[1] + + if first_elt and first_elt.method then + -- Explicit conversion + local meth = first_elt + + if meth.types[etype] then + lua_util.debugm(M, task, 'apply method `%s` to %s', + meth.name, etype) + input,etype = meth.process(input, etype) + else + local pt = pure_type(etype) + + if meth.types[pt] then + lua_util.debugm(M, task, 'map method `%s` to list of %s', + meth.name, pt) + -- Map method to a list of inputs, excluding empty elements + input = fun.filter(function(map_elt) return map_elt end, + fun.map(function(list_elt) + local ret, _ = meth.process(list_elt, pt) + return ret + end, input)) + etype = 'string_list' + end + end + -- Remove method from the pipeline + pipe = fun.drop_n(1, pipe) + else + -- Implicit conversion + + local pt = pure_type(etype) + + if not pt then + lua_util.debugm(M, task, 'apply implicit conversion %s->string', etype) + input = implicit_tostring(etype, input) + etype = 'string' + else + lua_util.debugm(M, task, 'apply implicit map %s->string', pt) + input = fun.filter(function(map_elt) return map_elt end, + fun.map(function(list_elt) + local ret = implicit_tostring(pt, list_elt) + return ret + end, input)) + etype = 'string_list' + end + end + end + + -- Now we fold elements using left fold + local function fold_function(acc, x) + if acc == nil or acc[1] == nil then + lua_util.debugm(M, task, 'do not apply %s, accumulator is nil', x.name) + return nil + end + + local value = acc[1] + local t = acc[2] + + if not x.types[t] then + local pt = pure_type(t) + + if pt and x.types['list'] then + -- Generic list processor + lua_util.debugm(M, task, 'apply list function `%s` to %s', x.name, t) + return {x.process(value, t, x.args)} + elseif pt and x.map_type and x.types[pt] then + local map_type = x.map_type .. '_list' + lua_util.debugm(M, task, 'map `%s` to list of %s resulting %s', + x.name, pt, map_type) + -- Apply map, filtering empty values + return { + fun.filter(function(map_elt) return map_elt end, + fun.map(function(list_elt) + if not list_elt then return nil end + local ret, _ = x.process(list_elt, pt, x.args) + return ret + end, value)), + map_type -- Returned type + } + end + logger.errx(task, 'cannot apply transform %s for type %s', x.name, t) + return nil + end + + lua_util.debugm(M, task, 'apply %s to %s', x.name, t) + return {x.process(value, t, x.args)} + end + + local res = fun.foldl(fold_function, + {input, etype}, + pipe) + + if not res or not res[1] then return nil end -- Pipeline failed + + if not allowed_type(res[2]) then + + -- Search for implicit conversion + local pt = pure_type(res[2]) + + if pt then + lua_util.debugm(M, task, 'apply implicit map %s->string_list', pt) + res[1] = fun.map(function(e) return implicit_tostring(pt, e) end, res[1]) + res[2] = 'string_list' + else + res[1] = implicit_tostring(res[2], res[1]) + res[2] = 'string' + end + end + + if list_type(res[2]) then + -- Convert to table as it might have a functional form + res[1] = fun.totable(res[1]) + end + + lua_util.debugm(M, task, 'final selector type: %s, value: %s', res[2], res[1]) + + return res[1] +end + +local function make_grammar() + local l = require "lpeg" + local spc = l.S(" \t\n")^0 + local atom = l.C((l.R("az") + l.R("AZ") + l.R("09") + l.S("_-"))^1) + local singlequoted_string = l.P "'" * l.C(((1 - l.S "'\r\n\f\\") + (l.P'\\' * 1))^0) * "'" + local doublequoted_string = l.P '"' * l.C(((1 - l.S'"\r\n\f\\') + (l.P'\\' * 1))^0) * '"' + local argument = atom + singlequoted_string + doublequoted_string + local dot = l.P(".") + local semicolon = l.P(":") + local obrace = "(" * spc + local ebrace = spc * ")" + local comma = spc * "," * spc + local sel_separator = spc * l.S";*" * spc + + return l.P{ + "LIST"; + LIST = l.Ct(l.V("EXPR")) * (sel_separator * l.Ct(l.V("EXPR")))^0, + EXPR = l.V("FUNCTION") * (semicolon * l.V("METHOD"))^-1 * (dot * l.V("PROCESSOR"))^0, + PROCESSOR = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), + FUNCTION = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), + METHOD = l.Ct(atom / function(e) return '__' .. e end * spc * (obrace * l.V("ARG_LIST") * ebrace)^0), + ARG_LIST = l.Ct((argument * comma^0)^0) + } +end + +local parser = make_grammar() + +--[[[ +-- @function lua_selectors.parse_selector(cfg, str) +--]] +exports.parse_selector = function(cfg, str) + local parsed = {parser:match(str)} + local output = {} + + if not parsed or not parsed[1] then return nil end + + local function check_args(name, schema, args) + if schema then + if getmetatable(schema) then + -- Schema covers all arguments + local res,err = schema:transform(args) + if not res then + logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err) + return false + else + for i,elt in ipairs(res) do + args[i] = elt + end + end + else + for i,selt in ipairs(schema) do + local res,err = selt:transform(args[i]) + + if err then + logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err) + return false + else + args[i] = res + end + end + end + end + + return true + end + + -- Output AST format is the following: + -- table of individual selectors + -- each selector: list of functions + -- each function: function name + optional list of arguments + for _,sel in ipairs(parsed) do + local res = { + selector = {}, + processor_pipe = {}, + } + + local selector_tbl = sel[1] + if not selector_tbl then + logger.errx(cfg, 'no selector represented') + return nil + end + if not extractors[selector_tbl[1]] then + logger.errx(cfg, 'selector %s is unknown', selector_tbl[1]) + return nil + end + + res.selector = lua_util.shallowcopy(extractors[selector_tbl[1]]) + res.selector.name = selector_tbl[1] + res.selector.args = selector_tbl[2] or E + + if not check_args(res.selector.name, + res.selector.args_schema, + res.selector.args) then + return nil + end + + lua_util.debugm(M, cfg, 'processed selector %s, args: %s', + res.selector.name, res.selector.args) + + local pipeline_error = false + -- Now process processors pipe + fun.each(function(proc_tbl) + local proc_name = proc_tbl[1] + + if proc_name:match('^__') then + -- Special case - method + local method_name = proc_name:match('^__(.*)$') + -- Check array indexing... + if tonumber(method_name) then + method_name = tonumber(method_name) + end + local processor = { + name = tostring(method_name), + method = true, + args = proc_tbl[2] or E, + types = { + userdata = true, + table = true, + }, + map_type = 'string', + process = function(inp, t, args) + if t == 'userdata' then + return inp[method_name](inp, args),'string' + else + -- Table + return inp[method_name],'string' + end + end, + } + lua_util.debugm(M, cfg, 'attached method %s to selector %s, args: %s', + proc_name, res.selector.name, processor.args) + table.insert(res.processor_pipe, processor) + else + + if not transform_function[proc_name] then + logger.errx(cfg, 'processor %s is unknown', proc_name) + pipeline_error = proc_name + return nil + end + local processor = lua_util.shallowcopy(transform_function[proc_name]) + processor.name = proc_name + processor.args = proc_tbl[2] or E + + if not check_args(processor.name, processor.args_schema, processor.args) then + pipeline_error = 'args schema for ' .. proc_name + return nil + end + + lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s', + proc_name, res.selector.name, processor.args) + table.insert(res.processor_pipe, processor) + end + end, fun.tail(sel)) + + if pipeline_error then + logger.errx(cfg, 'unknown or invalid processor used: "%s", exiting', pipeline_error) + return nil + end + + table.insert(output, res) + end + + return output +end + +--[[[ +-- @function lua_selectors.register_extractor(cfg, name, selector) +--]] +exports.register_extractor = function(cfg, name, selector) + if selector.get_value then + if extractors[name] then + logger.warnx(cfg, 'redefining selector %s', name) + end + extractors[name] = selector + + return true + end + + logger.errx(cfg, 'bad selector %s', name) + return false +end + +--[[[ +-- @function lua_selectors.register_transform(cfg, name, transform) +--]] +exports.register_transform = function(cfg, name, transform) + if transform.process and transform.types then + if transform_function[name] then + logger.warnx(cfg, 'redefining transform function %s', name) + end + transform_function[name] = transform + + return true + end + + logger.errx(cfg, 'bad transform function %s', name) + return false +end + +--[[[ +-- @function lua_selectors.process_selectors(task, selectors_pipe) +--]] +exports.process_selectors = function(task, selectors_pipe) + local ret = {} + + for _,sel in ipairs(selectors_pipe) do + local r = process_selector(task, sel) + + -- If any element is nil, then the whole selector is nil + if not r then return nil end + table.insert(ret, r) + end + + return ret +end + +--[[[ +-- @function lua_selectors.combine_selectors(task, selectors, delimiter) +--]] +exports.combine_selectors = function(_, selectors, delimiter) + if not delimiter then delimiter = '' end + + if not selectors then return nil end + + local all_strings = fun.all(function(s) return type(s) == 'string' end, selectors) + + if all_strings then + return table.concat(selectors, delimiter) + else + -- We need to do a spill on each table selector + -- e.g. s:tbl:s -> s:telt1:s + s:telt2:s ... + local tbl = {} + local res = {} + + for i,s in ipairs(selectors) do + if type(s) == 'string' then + rawset(tbl, i, fun.duplicate(s)) + elseif type(s) == 'userdata' then + rawset(tbl, i, fun.duplicate(tostring(s))) + else + rawset(tbl, i, s) + end + end + + fun.each(function(...) + table.insert(res, table.concat({...}, delimiter)) + end, fun.zip(lua_util.unpack(tbl))) + + return res + end +end + +--[[[ +-- @function lua_selectors.create_closure(cfg, selector_str, delimiter='') +--]] +exports.create_selector_closure = function(cfg, selector_str, delimiter) + local selector = exports.parse_selector(cfg, selector_str) + + if not selector then + return nil + end + + return function(task) + local res = exports.process_selectors(task, selector) + + if res then + return exports.combine_selectors(nil, res, delimiter) + end + + return nil + end +end + +local function display_selectors(tbl) + return fun.tomap(fun.map(function(k,v) + return k, fun.tomap(fun.filter(function(kk, vv) + return type(vv) ~= 'function' + end, v)) + end, tbl)) +end + +exports.list_extractors = function() + return display_selectors(extractors) +end + +exports.list_transforms = function() + return display_selectors(transform_function) +end + +return exports diff --git a/lualib/lua_selectors/maps.lua b/lualib/lua_selectors/maps.lua new file mode 100644 index 000000000..175e23960 --- /dev/null +++ b/lualib/lua_selectors/maps.lua @@ -0,0 +1,19 @@ +--[[ +Copyright (c) 2019, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local maps = {} -- Shared within selectors, indexed by name + +return maps \ No newline at end of file diff --git a/lualib/lua_selectors/transforms.lua b/lualib/lua_selectors/transforms.lua new file mode 100644 index 000000000..99b4201ef --- /dev/null +++ b/lualib/lua_selectors/transforms.lua @@ -0,0 +1,395 @@ +--[[ +Copyright (c) 2019, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require 'fun' +local lua_util = require "lua_util" +local ts = require("tableshape").types +local logger = require 'rspamd_logger' +local M = "selectors" + +local maps = require "lua_selectors/maps" + +local function pure_type(ltype) + return ltype:match('^(.*)_list$') +end + +local transform_function = { + -- Returns the lowercased string + ['lower'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _) + return inp:lower(),'string' + end, + ['description'] = 'Returns the lowercased string', + }, + -- Returns the first element + ['first'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t) + return fun.head(inp),pure_type(t) + end, + ['description'] = 'Returns the first element', + }, + -- Returns the last element + ['last'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t) + return fun.nth(#inp, inp),pure_type(t) + end, + ['description'] = 'Returns the last element', + }, + -- Returns the nth element + ['nth'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.nth(args[1] or 1, inp),pure_type(t) + end, + ['description'] = 'Returns the nth element', + ['args_schema'] = {ts.number + ts.string / tonumber} + }, + ['take_n'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.take_n(args[1] or 1, inp),t + end, + ['description'] = 'Returns the n first elements', + ['args_schema'] = {ts.number + ts.string / tonumber} + }, + ['drop_n'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.drop_n(args[1] or 1, inp),t + end, + ['description'] = 'Returns list without the first n elements', + ['args_schema'] = {ts.number + ts.string / tonumber} + }, + -- Joins strings into a single string using separator in the argument + ['join'] = { + ['types'] = { + ['string_list'] = true + }, + ['process'] = function(inp, _, args) + return table.concat(fun.totable(inp), args[1] or ''), 'string' + end, + ['description'] = 'Joins strings into a single string using separator in the argument', + ['args_schema'] = {ts.string:is_optional()} + }, + -- Sort strings + ['sort'] = { + ['types'] = { + ['list'] = true + }, + ['process'] = function(inp, t, _) + table.sort(inp) + return inp, t + end, + ['description'] = 'Sort strings lexicographically', + }, + -- Return unique elements based on hashing (can work without sorting) + ['uniq'] = { + ['types'] = { + ['list'] = true + }, + ['process'] = function(inp, t, _) + local tmp = {} + fun.each(function(val) + tmp[val] = true + end, inp) + + return fun.map(function(k, _) return k end, tmp), t + end, + ['description'] = 'Returns a list of unique elements (using a hash table)', + }, + -- Create a digest from string or a list of strings + ['digest'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'hash', + ['process'] = function(inp, _, args) + local hash = require 'rspamd_cryptobox_hash' + local encoding = args[1] or 'hex' + local ht = args[2] or 'blake2' + local h = hash:create_specific(ht):update(inp) + local s + + if encoding == 'hex' then + s = h:hex() + elseif encoding == 'base32' then + s = h:base32() + elseif encoding == 'base64' then + s = h:base64() + end + + return s,'string' + end, + ['description'] = [[Create a digest from a string. +The first argument is encoding (`hex`, `base32`, `base64`), +the second argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], + ['args_schema'] = {ts.one_of{'hex', 'base32', 'base64'}:is_optional(), + ts.one_of{'blake2', 'sha256', 'sha1', 'sha512', 'md5'}:is_optional()} + }, + -- Extracts substring + ['substring'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local start_pos = args[1] or 1 + local end_pos = args[2] or -1 + + return inp:sub(start_pos, end_pos), 'string' + end, + ['description'] = 'Extracts substring; the first argument is start, the second is the last (like in Lua)', + ['args_schema'] = {(ts.number + ts.string / tonumber):is_optional(), + (ts.number + ts.string / tonumber):is_optional()} + }, + -- Prepends a string or a strings list + ['prepend'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local prepend = table.concat(args, '') + + return prepend .. inp, 'string' + end, + ['description'] = 'Prepends a string or a strings list', + }, + -- Appends a string or a strings list + ['append'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local append = table.concat(args, '') + + return inp .. append, 'string' + end, + ['description'] = 'Appends a string or a strings list', + }, + -- Regexp matching + ['regexp'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local rspamd_regexp = require "rspamd_regexp" + + local re = rspamd_regexp.create_cached(args[1]) + + if not re then + logger.errx('invalid regexp: %s', args[1]) + return nil + end + + local res = re:search(inp, false, true) + + if res then + if #res == 1 then + return res[1],'string' + end + + return res,'string_list' + end + + return nil + end, + ['description'] = 'Regexp matching', + ['args_schema'] = {ts.string} + }, + -- Returns a value if it exists in some map (or acts like a `filter` function) + ['filter_map'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + local map = maps[args[1]] + + if not map then + logger.errx('invalid map name: %s', args[1]) + return nil + end + + local res = map:get_key(inp) + + if res then + return inp,t + end + + return nil + end, + ['description'] = 'Returns a value if it exists in some map (or acts like a `filter` function)', + ['args_schema'] = {ts.string} + }, + -- Returns a value from some map corresponding to some key (or acts like a `map` function) + ['apply_map'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + local map = maps[args[1]] + + if not map then + logger.errx('invalid map name: %s', args[1]) + return nil + end + + local res = map:get_key(inp) + + if res then + return res,t + end + + return nil + end, + ['description'] = 'Returns a value from some map corresponding to some key (or acts like a `map` function)', + ['args_schema'] = {ts.string} + }, + -- Drops input value and return values from function's arguments or an empty string + ['id'] = { + ['types'] = { + ['string'] = true, + ['list'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(_, _, args) + if args[1] and args[2] then + return fun.map(tostring, args),'string_list' + elseif args[1] then + return args[1],'string' + end + + return '','string' + end, + ['description'] = 'Drops input value and return values from function\'s arguments or an empty string', + ['args_schema'] = (ts.string + ts.array_of(ts.string)):is_optional() + }, + ['equal'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + if inp == args[1] then + return inp,'string' + end + + return nil + end, + ['description'] = [[Boolean function equal. +Returns either nil or its argument if input is equal to argument]], + ['args_schema'] = {ts.string} + }, + -- Boolean function in, returns either nil or its input if input is in args list + ['in'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + for _,a in ipairs(args) do if a == inp then return inp,t end end + return nil + end, + ['description'] = [[Boolean function in. +Returns either nil or its input if input is in args list]], + ['args_schema'] = ts.array_of(ts.string) + }, + ['not_in'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + for _,a in ipairs(args) do if a == inp then return nil end end + return inp,t + end, + ['description'] = [[Boolean function not in. +Returns either nil or its input if input is not in args list]], + ['args_schema'] = ts.array_of(ts.string) + }, + ['inverse'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + if inp then + return nil + else + return (args[1] or 'true'),'string' + end + end, + ['description'] = [[Inverses input. +Empty string comes the first argument or 'true', non-empty string comes nil]], + ['args_schema'] = {ts.string:is_optional()} + }, + ['ipmask'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local rspamd_ip = require "rspamd_ip" + -- Non optimal: convert string to an IP address + local ip = rspamd_ip.from_string(inp) + + if not ip or not ip:is_valid() then + lua_util.debugm(M, "cannot convert %s to IP", inp) + return nil + end + + if ip:get_version() == 4 then + local mask = tonumber(args[1]) + + return ip:apply_mask(mask):to_string(),'string' + else + -- IPv6 takes the second argument or the first one... + local mask_str = args[2] or args[1] + local mask = tonumber(mask_str) + + return ip:apply_mask(mask):to_string(),'string' + end + end, + ['description'] = 'Applies mask to IP address.' .. + ' The first argument is the mask for IPv4 addresses, the second is the mask for IPv6 addresses.', + ['args_schema'] = {(ts.number + ts.string / tonumber), + (ts.number + ts.string / tonumber):is_optional()} + }, +} + +transform_function.match = transform_function.regexp + +return transform_function \ No newline at end of file diff --git a/test/lua/unit/selectors.combined.lua b/test/lua/unit/selectors.combined.lua index d70caeb10..06c5c0fca 100644 --- a/test/lua/unit/selectors.combined.lua +++ b/test/lua/unit/selectors.combined.lua @@ -33,25 +33,25 @@ context("Selectors test", function() local cases = { ["rcpts + weekend"] = { - selector = "rcpts:addr.take_n(5).lower;time('message', '!%w').in(6, 7).id('weekends')", + selector = "rcpts:addr.take_n(5).lower;time('message', '!%w').in(6, 7).id('weekends')", expect = { - "nobody@example.com:weekends", + "nobody@example.com:weekends", "no-one@example.com:weekends"}}, ["weekend + rcpts"] = { - selector = "time('message', '!%w').in(6, 7).id('weekends');rcpts:addr.take_n(5).lower", + selector = "time('message', '!%w').in(6, 7).id('weekends');rcpts:addr.take_n(5).lower", expect = { - "weekends:nobody@example.com", + "weekends:nobody@example.com", "weekends:no-one@example.com"}}, ["id(rcpt) + rcpts + weekend"] = { - selector = "id('rcpt');rcpts:addr.take_n(5).lower;time('message', '!%w').in(6, 7).id('weekends')", + selector = "id('rcpt');rcpts:addr.take_n(5).lower;time('message', '!%w').in(6, 7).id('weekends')", expect = { "rcpt:nobody@example.com:weekends", "rcpt:no-one@example.com:weekends"}}, ["id(rcpt) + id(2) rcpts + weekend"] = { - selector = "id('rcpt'); id(2); rcpts:addr.take_n(5).lower; time('message', '!%w').in(6, 7).id('weekends')", + selector = "id('rcpt'); id(2); rcpts:addr.take_n(5).lower; time('message', '!%w').in(6, 7).id('weekends')", expect = { "rcpt:2:nobody@example.com:weekends", "rcpt:2:no-one@example.com:weekends"}}, @@ -59,7 +59,7 @@ context("Selectors test", function() -- There are two rcpts but only one url in the message -- resulting table size is the size of the smallest table ["id(rcpt) + id(2) + rcpts and urls + weekend"] = { - selector = "id('rcpt'); id(2); rcpts:addr.take_n(5).lower; id('urls'); urls:get_host; time('message', '!%w').in(6, 7).id('weekends')", + selector = "id('rcpt'); id(2); rcpts:addr.take_n(5).lower; id('urls'); urls:get_host; time('message', '!%w').in(6, 7).id('weekends')", expect = { "rcpt:2:nobody@example.com:urls:example.net:weekends"}}, } diff --git a/test/lua/unit/selectors.negative.lua b/test/lua/unit/selectors.negative.lua index c7d3f1e7a..426240066 100644 --- a/test/lua/unit/selectors.negative.lua +++ b/test/lua/unit/selectors.negative.lua @@ -5,9 +5,9 @@ context("Selectors test", function() local lua_selectors = require "lua_selectors" local ffi = require "ffi" local cfg = rspamd_config - + local task - + ffi.cdef[[ void rspamd_url_init (const char *tld_file); ]]