]> source.dussan.org Git - rspamd.git/commitdiff
Completely rewrite multimap plugin.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 23 Feb 2015 17:42:29 +0000 (17:42 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 23 Feb 2015 17:42:29 +0000 (17:42 +0000)
src/plugins/lua/multimap.lua

index 6af38a40cea7ed7d1a95627cf67effc34cc66b84..a5624979cbb9072f29c330e6dc3b95d8b165f02e 100644 (file)
@@ -28,237 +28,173 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 local rules = {}
 local rspamd_logger = require "rspamd_logger"
-local rspamd_cdb = require "rspamd_cdb"
+local cdb = require "rspamd_cdb"
+local _ = require "fun"
+--local dumper = require 'pl.pretty'.dump
 
 local function ip_to_rbl(ip, rbl)
-       return table.concat(ip:inversed_str_octets(), ".") .. '.' .. rbl
+  return table.concat(ip:inversed_str_octets(), ".") .. '.' .. rbl
 end
 
 local function check_multimap(task)
-       local function multimap_rbl_cb(resolver, to_resolve, results, err, rbl)
-               task:inc_dns_req()
-               if results then
-                       -- Get corresponding rule by rbl name
-                       for _,rule in pairs(rules) do
-                               if rule == rbl then
-                                       task:insert_result(rule['symbol'], 1, rule['map'])
-                                       return
-                               end
-                       end
-               end
-       end
+  -- Generate dns callback closure
+  local function dns_cb_generator(r)
+    local cb = function (resolver, to_resolve, results, err, rbl)
+      task:inc_dns_req()
+      if results then
+        task:insert_result(r['symbol'], 1, r['map'])
+      end
+    end
+    return cb
+  end
 
-       for _,rule in pairs(rules) do
-               if rule['type'] == 'ip' then
-                       if rule['cdb'] then
-                               local ip = task:get_from_ip()
-                               if ip:is_valid() and rule['hash']:lookup(ip:to_string()) then
-                                       task:insert_result(rule['symbol'], 1)
-                               end
-                       else
-                               local ip = task:get_from_ip()
-                               if ip:is_valid() and rule['ips'] and rule['ips']:get_key(ip) then
-                                       task:insert_result(rule['symbol'], 1)
-                               end
-                       end
-               elseif rule['type'] == 'header' then
-                       local headers = task:get_header_full(rule['header'])
-                       if headers then
-                               for _,hv in ipairs(headers) do
-                                       if rule['pattern'] then
-                                               -- extract a part from header
-                                               local _,_,ext = string.find(hv['decoded'], rule['pattern'])
-                                               if ext then
-                                                       if rule['cdb'] then
-                                                               if rule['hash']:lookup(ext) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       else
-                                                               if rule['hash']:get_key(ext) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       end
-                                               end
-                                       else
-                                               if rule['cdb'] then
-                                                       if rule['hash']:lookup(hv['decoded']) then
-                                                               task:insert_result(rule['symbol'], 1)
-                                                       end
-                                               else
-                                                       if rule['hash']:get_key(hv['decoded']) then
-                                                               task:insert_result(rule['symbol'], 1)
-                                                       end
-                                               end
-                                       end
-                               end
-                       end
-               elseif rule['type'] == 'dnsbl' then
-                       local ip = task:get_from_ip()
-                       if ip:is_valid() then
-                               if ip:get_version() == 6 and rule['ipv6'] then
-                                       task:get_resolver():resolve_a(task:get_session(), task:get_mempool(),
-                                               ip_to_rbl(ip, rule['map']), multimap_rbl_cb, rule['map'])
-                               elseif ip:get_version() == 4 then
-                                       task:get_resolver():resolve_a(task:get_session(), task:get_mempool(),
-                                               ip_to_rbl(ip, rule['map']), multimap_rbl_cb, rule['map'])
-                               end
-                       end
-               elseif rule['type'] == 'rcpt' then
-                       -- First try to get rcpt field
-                       local rcpts = task:get_recipients()
-                       if rcpts then
-                               for _,r in ipairs(rcpts) do
-                                       if r['addr'] then
-                                               if rule['pattern'] then
-                                                       -- extract a part from header
-                                                       local _,_,ext = string.find(r['addr'], rule['pattern'])
-                                                       if ext then
-                                                               if rule['cdb'] then
-                                                                       if rule['hash']:lookup(ext) then
-                                                                               task:insert_result(rule['symbol'], 1)
-                                                                       end
-                                                               else
-                                                                       if rule['hash']:get_key(ext) then
-                                                                               task:insert_result(rule['symbol'], 1)
-                                                                       end
-                                                               end
-                                                       end
-                                               else
-                                                       if rule['cdb'] then
-                                                               if rule['hash']:lookup(r['addr']) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       else
-                                                               if rule['hash']:get_key(r['addr']) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       end
-                                               end
-                                       end     
-                               end
-                       end
-               elseif rule['type'] == 'from' then
-                       -- First try to get from field
-                       local from = task:get_from()
-                       if from then
-                               for _,r in ipairs(from) do
-                                       if r['addr'] then
-                                               if rule['pattern'] then
-                                                       -- extract a part from header
-                                                       local _,_,ext = string.find(r['addr'], rule['pattern'])
-                                                       if ext then
-                                                               if rule['cdb'] then
-                                                                       if rule['hash']:lookup(ext) then
-                                                                               task:insert_result(rule['symbol'], 1)
-                                                                       end
-                                                               else
-                                                                       if rule['hash']:get_key(ext) then
-                                                                               task:insert_result(rule['symbol'], 1)
-                                                                       end
-                                                               end
-                                                       end
-                                               else
-                                                       if rule['cdb'] then
-                                                               if rule['hash']:lookup(r['addr']) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       else
-                                                               if rule['hash']:get_key(r['addr']) then
-                                                                       task:insert_result(rule['symbol'], 1)
-                                                               end
-                                                       end
-                                               end
-                                       end     
-                               end
-                       end
-               end
-       end
+  -- Match a single value for against a single rule
+  local function match_rule(r, value)
+    local ret = false
+    if r['cdb'] then
+      local srch = value
+      if r['type'] == 'ip' then
+        srch = value:to_string()
+      end
+
+      ret = r['cdb']:lookup(srch)
+    elseif r['radix'] then
+      ret = r['radix']:get_key(value)
+    elseif r['hash'] then
+      ret = r['hash']:get_key(value)
+    end
+
+    if ret then
+      task:insert_result(r['symbol'], 1)
+    end
+  end
+
+  -- Match list of values according to the field
+  local function match_list(r, ls, field)
+    if ls then
+      if field then
+        _.each(function(e) match_rule(r, e[field]) end, ls)
+      else
+        _.each(function(e) match_rule(r, e) end, ls)
+      end
+    end
+  end
+
+  -- IP rules
+  local ip = task:get_from_ip()
+  if ip:is_valid() then
+    _.each(function(r) match_rule(r, ip) end,
+      _.filter(function(r) return r['type'] == 'ip' end, rules))
+  end
+
+  -- Header rules
+  _.each(function(r)
+    local hv = task:get_header_full(r['header'])
+    match_list(r, hv, 'decoded')
+  end,
+  _.filter(function(r) return r['type'] == 'header' end, rules))
+
+  -- Rcpt rules
+  local rcpts = task:get_recipients()
+  if rcpts then
+    _.each(function(r)
+      match_list(r, rcpts, 'addr')
+    end,
+    _.filter(function(r) return r['type'] == 'rcpt' end, rules))
+  end
+
+  -- From rules
+  local from = task:get_from()
+  if from then
+    _.each(function(r)
+      match_list(r, from, 'addr')
+    end,
+    _.filter(function(r) return r['type'] == 'from' end, rules))
+  end
+
+  -- RBL rules
+  if ip:is_valid() then
+    _.each(function(r)
+      task:get_resolver():resolve_a(task:get_session(), task:get_mempool(),
+        ip_to_rbl(ip, r['map']), dns_cb_generator(r))
+    end,
+    _.filter(function(r) return r['type'] == 'dnsbl' end, rules))
+  end
 end
 
 local function add_multimap_rule(key, newrule)
-       if not newrule['map'] then
-               rspamd_logger.err('incomplete rule')
-               return nil
-       end
-       if not newrule['symbol'] and key then
-               newrule['symbol'] = key
-       elseif not newrule['symbol'] then
-               rspamd_logger.err('incomplete rule')
-               return nil
-       end
-       -- Check cdb flag
-       if string.find(newrule['map'], '^cdb://.*$') then
-               local test = cdb.create(newrule['map'])
-               newrule['hash'] = cdb.create(newrule['map'])
-               newrule['cdb'] = true
-               if newrule['hash'] then
-                       table.insert(rules, newrule)
-                       return newrule
-               else
-                       rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
-               end
-       else
-               if newrule['type'] == 'ip' then
-                       newrule['ips'] = rspamd_config:add_radix_map (newrule['map'], newrule['description'])
-                       if newrule['ips'] then
-                               table.insert(rules, newrule)
-                               return newrule
-                       else
-                               rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
-                       end
-               elseif newrule['type'] == 'header' or newrule['type'] == 'rcpt' or newrule['type'] == 'from' then
-                       newrule['hash'] = rspamd_config:add_hash_map (newrule['map'], newrule['description'])
-                       if newrule['hash'] then
-                               table.insert(rules, newrule)
-                               return newrule
-                       else
-                               rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
-                       end
-               elseif newrule['type'] == 'cdb' then
-                       newrule['hash'] = rspamd_cdb.create(newrule['map'])
-                       if newrule['hash'] then
-                               table.insert(rules, newrule)
-                               return newrule
-                       else
-                               rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
-                       end
-               else
-                       table.insert(rules, newrule)
-                       return newrule
-               end
-       end
-       return nil
+  if not newrule['map'] then
+    rspamd_logger.err('incomplete rule')
+    return nil
+  end
+  if not newrule['symbol'] and key then
+    newrule['symbol'] = key
+  elseif not newrule['symbol'] then
+    rspamd_logger.err('incomplete rule')
+    return nil
+  end
+  -- Check cdb flag
+  if string.find(newrule['map'], '^cdb://.*$') then
+    local test = cdb.create(newrule['map'])
+    newrule['cdb'] = cdb.create(newrule['map'])
+    if newrule['cdb'] then
+      return newrule
+    else
+      rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
+    end
+  else
+    if newrule['type'] == 'ip' then
+      newrule['radix'] = rspamd_config:add_radix_map (newrule['map'], newrule['description'])
+      if newrule['radix'] then
+        table.insert(rules, newrule)
+        return newrule
+      else
+        rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
+      end
+    elseif newrule['type'] == 'header' or newrule['type'] == 'rcpt' or newrule['type'] == 'from' then
+      newrule['hash'] = rspamd_config:add_hash_map (newrule['map'], newrule['description'])
+      if newrule['hash'] then
+        return newrule
+      else
+        rspamd_logger.warn('Cannot add rule: map doesn\'t exists: ' .. newrule['map'])
+      end
+    elseif newrule['type'] == 'dnsbl' then
+      return newrule 
+    end
+  end
+  return nil
 end
 
 -- Registration
 if type(rspamd_config.get_api_version) ~= 'nil' then
-       if rspamd_config:get_api_version() >= 1 then
-               rspamd_config:register_module_option('multimap', 'rule', 'string')
-       end
+  if rspamd_config:get_api_version() >= 1 then
+    rspamd_config:register_module_option('multimap', 'rule', 'string')
+  end
 end
 
 local opts =  rspamd_config:get_all_opt('multimap')
 if opts and type(opts) == 'table' then
-       for k,m in pairs(opts) do
-               if type(m) == 'table' then
-                       local rule = add_multimap_rule(k, m)
-                       if not rule then
-                               rspamd_logger.err('cannot add rule: "'..k..'"')
-                       else
-                               if type(rspamd_config.get_api_version) ~= 'nil' then
-                                       rspamd_config:register_virtual_symbol(rule['symbol'], 1.0)
-                               end
-                       end
-               else
-                       rspamd_logger.err('parameter ' .. k .. ' is invalid, must be an object')
-               end
-       end
-       -- add fake symbol to check all maps inside a single callback
-       if type(rspamd_config.get_api_version) ~= 'nil' then
-               if rspamd_config.get_api_version() >= 4 then
-                       rspamd_config:register_callback_symbol_priority('MULTIMAP', 1.0, -1, check_multimap)
-               else
-                       rspamd_config:register_callback_symbol('MULTIMAP', 1.0, check_multimap)
-               end
-       end
+  for k,m in pairs(opts) do
+    if type(m) == 'table' then
+      local rule = add_multimap_rule(k, m)
+      if not rule then
+        rspamd_logger.err('cannot add rule: "'..k..'"')
+      else
+        table.insert(rules, rule)
+        if type(rspamd_config.get_api_version) ~= 'nil' then
+          rspamd_config:register_virtual_symbol(rule['symbol'], 1.0)
+        end
+      end
+    else
+      rspamd_logger.err('parameter ' .. k .. ' is invalid, must be an object')
+    end
+  end
+  -- add fake symbol to check all maps inside a single callback
+  if type(rspamd_config.get_api_version) ~= 'nil' then
+    if rspamd_config.get_api_version() >= 4 then
+      rspamd_config:register_callback_symbol_priority('MULTIMAP', 1.0, -1, check_multimap)
+    else
+      rspamd_config:register_callback_symbol('MULTIMAP', 1.0, check_multimap)
+    end
+  end
 end