summaryrefslogtreecommitdiffstats
path: root/lualib/lua_maps_expressions.lua
blob: bf3215d392c4487845425e90624f68cf4e43ba2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
--[[[
-- @module lua_maps_expressions
-- This module contains routines to combine maps, selectors and expressions
-- in a generic framework
@example
whitelist_ip_from = {
  rules {
    ip {
      selector = "ip";
      map = "/path/to/whitelist_ip.map";
    }
    from {
      selector = "from(smtp)";
      map = "/path/to/whitelist_from.map";
    }
  }
  expression = "ip & from";
}
--]]

--[[
Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

local lua_selectors = require "lua_selectors"
local lua_maps = require "lua_maps"
local rspamd_expression = require "rspamd_expression"
local rspamd_logger = require "rspamd_logger"
local fun = require "fun"

local exports = {}

-- Evaluates a maps combination against a task. It is installed as the
-- `process` field of objects returned by `create`, so it is normally called
-- as `obj:process(task)`.
--
-- Each atom of `elt.expr` is a rule name. An atom evaluates to 1.0 when any
-- value produced by the rule's selector is found in the rule's map, and to 0
-- otherwise.
--
-- Params:
--   elt  - object built by `create` (uses `elt.rules` and `elt.expr`)
--   task - task passed to each rule's selector
-- Returns:
--   expression result and a table `rule_name -> {matched = selector value,
--   value = map result}` when the expression result is positive; nil otherwise
local function process_func(elt, task)
  local matched = {}

  local function process_atom(atom)
    local rule = elt.rules[atom]
    local res = 0

    local function match_rule(val)
      local map_match = rule.map:get_key(val)
      if map_match then
        res = 1.0
        matched[rule.name] = {
          matched = val,
          value = map_match
        }
      end
    end

    local values = rule.selector(task)

    if values then
      if type(values) == 'table' then
        -- Only the first matching value is recorded, so stop once we have it
        for _,val in ipairs(values) do
          match_rule(val)
          if res ~= 0 then break end
        end
      else
        match_rule(values)
      end
    end

    return res
  end

  local res = elt.expr:process(process_atom)

  -- Expression results are numbers and 0 is truthy in Lua, so compare
  -- explicitly: a false (zero) result must yield nil
  if res and res > 0 then
    return res,matched
  end

  return nil
end

-- Guesses a map type from a rule name; used when a rule has no explicit `type`.
-- NOTE(review): this is plain substring matching, so names such as
-- `recipients` or `zip` also contain `ip` and are guessed as radix maps;
-- set `type` explicitly in such cases.
local function guess_map_type(name)
  if name:find('ip', 1, true) then
    -- also covers `ipnet`
    return 'radix'
  elseif name:find('regexp', 1, true) or name:find('re_', 1, true) then
    return 'regexp'
  elseif name:find('glob', 1, true) then
    -- glob-named rules are loaded as regexp maps
    return 'regexp'
  end

  return 'set'
end

--[[[
-- @function lua_maps_expressions.create(cfg, obj, module_name)
-- Creates a new maps combination from `obj` for `module_name`.
-- The input should be a table with the following fields:
--
-- * `rules` - kv map of rules where each rule has mandatory `map` and `selector` attributes, and optionally `type` for the map type, e.g. `regexp`.
--   If `type` is omitted it is guessed from the rule name (`ip` -> radix, `regexp`/`re_`/`glob` -> regexp, otherwise set)
--   and stored back into the rule
-- * `expression` - Rspamd expression where elements are names from the `rules` field, e.g. `ip & from`
-- * `description` - optional description passed to the maps (defaults to `module_name`)
-- * `symbol` - optional symbol name, registered as a `virtual,ghost` symbol with zero score
--
-- This function returns an object with public method `process(task)` that checks
-- a task for the conditions defined in `expression` and `rules` and returns 2 values:
--
-- 1. value returned by the expression (e.g. 1 or 0)
-- 2. a map (rule_name -> table) of matches, where each element has the following fields:
--   * `matched` - selector's value
--   * `value` - map's result
--
-- If the expression evaluates to false, `nil` is returned instead.
-- @param {rspamd_config} cfg rspamd config
-- @param {table} obj configuration table
-- @param {string} module_name module name used in logs and as default map description (defaults to `lua_maps_expressions`)
-- @return {table|nil} maps combination object, or nil if any element cannot be created
--]]
local function create(cfg, obj, module_name)
  if not module_name then module_name = 'lua_maps_expressions' end

  if not obj or not obj.rules or not obj.expression then
    rspamd_logger.errx(cfg, 'cannot add maps combination for module %s: required elements are missing',
        module_name)
    return nil
  end

  local ret = {
    process = process_func,
    rules = {},
    module_name = module_name
  }

  for name,rule in pairs(obj.rules) do
    local sel = lua_selectors.create_selector_closure(cfg, rule.selector)

    if not sel then
      rspamd_logger.errx(cfg, 'cannot add selector for element %s in module %s',
          name, module_name)
      -- The whole combination is rejected, so do not register a map for it
      return nil
    end

    if not rule.type then
      rule.type = guess_map_type(name)
    end

    local map = lua_maps.map_add_from_ucl(rule.map, rule.type,
        obj.description or module_name)
    if not map then
      rspamd_logger.errx(cfg, 'cannot add map for element %s in module %s',
          name, module_name)
      return nil
    end

    ret.rules[name] = {
      selector = sel,
      map = map,
      name = name,
    }
  end

  -- Atom parser for rspamd_expression: takes the longest prefix of `str` that
  -- contains no expression delimiter and accepts it only if it names a rule.
  local function parse_atom(str)
    -- Delimiters are listed inside a pattern character class, where all of
    -- them are literal; matching them one by one as patterns would break on
    -- `(`, `)`, `%`, `[` and mis-handle `.`, `^`, `$`
    local atom = str:match('^[^, \t()><+!|&\n]*')

    if ret.rules[atom] then
      return atom
    end

    rspamd_logger.errx(cfg, 'use of undefined element "%s" when parsing maps expression for %s',
        atom, module_name)

    return nil
  end

  local expr, err = rspamd_expression.create(obj.expression, parse_atom,
      cfg:get_mempool())

  if not expr then
    rspamd_logger.errx(cfg, 'cannot add map expression for module %s: %s',
        module_name, err or 'unknown error')
    return nil
  end

  ret.expr = expr

  if obj.symbol then
    cfg:register_symbol{
      type = 'virtual,ghost',
      name = obj.symbol,
      score = 0.0,
    }
  end

  ret.symbol = obj.symbol

  return ret
end

-- Public interface of the module: only the `create` constructor is exported
exports['create'] = create
return exports