]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] lua_scanners - move and extend mime_part matching
authorCarsten Rosenberg <c.rosenberg@heinlein-support.de>
Sun, 13 Jan 2019 13:25:14 +0000 (14:25 +0100)
committerCarsten Rosenberg <c.rosenberg@heinlein-support.de>
Sun, 13 Jan 2019 13:25:14 +0000 (14:25 +0100)
lualib/lua_scanners/common.lua
src/plugins/lua/antivirus.lua
src/plugins/lua/external_services.lua

index 6364a7e8e2e08e21fb488d272378e18645037ef3..1fe82fabb5c1061dac3e6c14dccf0be9bf7868b2 100644 (file)
@@ -1,5 +1,6 @@
 --[[
 Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
+Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -20,6 +21,7 @@ limitations under the License.
 --]]
 
 local rspamd_logger = require "rspamd_logger"
+local rspamd_regexp = require "rspamd_regexp"
 local lua_util = require "lua_util"
 local lua_redis = require "lua_redis"
 local fun = require "fun"
@@ -178,22 +180,119 @@ local function save_av_cache(task, digest, rule, to_save, dyn_weight)
   return false
 end
 
-local function text_parts_min_words(task, min_words)
-  local filter_func = function(p)
-    return p:get_words_count() >= min_words
+local function create_regex_table(task, patterns)
+  local regex_table = {}
+  if patterns[1] then
+    for i, p in ipairs(patterns) do
+      if type(p) == 'table' then
+        local new_set = {}
+        for k, v in pairs(p) do
+          new_set[k] = rspamd_regexp.create_cached(v)
+        end
+        regex_table[i] = new_set
+      else
+        regex_table[i] = {}
+      end
+    end
+  else
+    for k, v in pairs(patterns) do
+      regex_table[k] = rspamd_regexp.create_cached(v)
+    end
   end
+  return regex_table
+end
 
-  return fun.any(filter_func, task:get_text_parts())
+local function match_filter(task, found, patterns)
+  if type(patterns) ~= 'table' then return false end
+  if not patterns[1] then
+    for _, pat in pairs(patterns) do
+      if pat:match(found) then
+        return true
+      end
+    end
+    return false
+  else
+    for _, p in ipairs(patterns) do
+      for _, pat in ipairs(p) do
+        if pat:match(found) then
+          return true
+        end
+      end
+    end
+    return false
+  end
+end
+
+-- borrowed from mime_types.lua
+-- ext is the last extension, LOWERCASED
+-- ext2 is the one before last extension LOWERCASED
+local function gen_extension(fname)
+  local filename_parts = rspamd_str_split(fname, '.')
 
+  local ext = {}
+  for n = 1, 2 do
+      ext[n] = #filename_parts > n and string.lower(filename_parts[#filename_parts + 1 - n]) or nil
+  end
+  return ext[1],ext[2],filename_parts
 end
 
+local function check_parts_match(task, rule)
+
+  local filter_func = function(p)
+    local content_type,content_subtype = p:get_type()
+    local fname = p:get_filename()
+    local ext, ext2, part_table
+    local extension_check = false
+    local content_type_check = false
+    local text_part_min_words_check = true
+
+    if rule.scan_all_mime_parts == false then
+    -- check file extension and filename regex matching
+      if fname ~= nil then
+        ext,ext2,part_table = gen_extension(fname)
+        lua_util.debugm(rule.module_name, task, '%s: extension found: %s - 2.ext: %s - parts: %s',
+          rule.log_prefix, ext, ext2, part_table)
+        if match_filter(task, ext, rule.mime_parts_filter_ext)
+          or match_filter(task, ext2, rule.mime_parts_filter_ext) then
+          lua_util.debugm(rule.module_name, task, '%s: extension matched: %s', rule.log_prefix, ext)
+          extension_check = true
+        end
+        if match_filter(task, fname, rule.mime_parts_filter_regex) then
+          content_type_check = true
+        end
+      end
+      -- check content type regex matching
+      if content_type ~= nil and content_subtype ~= nil then
+        if match_filter(task, content_type..'/'..content_subtype, rule.mime_parts_filter_regex) then
+          lua_util.debugm(rule.module_name, task, '%s: regex ct: %s', rule.log_prefix,
+            content_type..'/'..content_subtype)
+          content_type_check = true
+        end
+      end
+    end
+
+    -- check text_part has more words than text_part_min_words_check
+    if rule.text_part_min_words and p:is_text() then
+      text_part_min_words_check = p:get_words_count() >= tonumber(rule.text_part_min_words)
+    end
+
+    return (rule.scan_image_mime and p:is_image())
+        or (rule.scan_text_mime and text_part_min_words_check)
+        or (p:is_attachment() and rule.scan_all_mime_parts ~= false)
+        or extension_check
+        or content_type_check
+  end
+
+  return fun.filter(filter_func, task:get_parts())
+end
 
 exports.yield_result = yield_result
 exports.match_patterns = match_patterns
 exports.need_av_check = need_av_check
 exports.check_av_cache = check_av_cache
 exports.save_av_cache = save_av_cache
-exports.text_parts_min_words = text_parts_min_words
+exports.create_regex_table = create_regex_table
+exports.check_parts_match = check_parts_match
 
 setmetatable(exports, {
   __call = function(t, override)
index 07dbba2f02973653d75ea23c246a089abca6094e..0515e3376f7c27aba56c37e379660a2124e13bc1 100644 (file)
@@ -19,6 +19,7 @@ local rspamd_regexp = require "rspamd_regexp"
 local lua_util = require "lua_util"
 local fun = require "fun"
 local lua_antivirus = require("lua_scanners").filter('antivirus')
+local common = require "lua_scanners/common"
 local redis_params
 
 local N = "antivirus"
@@ -107,26 +108,7 @@ local function add_antivirus_rule(sym, opts)
     return nil
   end
 
-  if type(opts['patterns']) == 'table' then
-    rule['patterns'] = {}
-    if opts['patterns'][1] then
-      for i, p in ipairs(opts['patterns']) do
-        if type(p) == 'table' then
-          local new_set = {}
-          for k, v in pairs(p) do
-            new_set[k] = rspamd_regexp.create_cached(v)
-          end
-          rule['patterns'][i] = new_set
-        else
-          rule['patterns'][i] = {}
-        end
-      end
-    else
-      for k, v in pairs(opts['patterns']) do
-        rule['patterns'][k] = rspamd_regexp.create_cached(v)
-      end
-    end
-  end
+  rule.patterns = common.create_regex_table(task, opts.patterns or {})
 
   if opts['whitelist'] then
     rule['whitelist'] = rspamd_config:add_hash_map(opts['whitelist'])
@@ -134,21 +116,13 @@ local function add_antivirus_rule(sym, opts)
 
   return function(task)
     if rule.scan_mime_parts then
-      local parts = task:get_parts() or {}
-
-      local filter_func = function(p)
-        return (rule.scan_image_mime and p:is_image())
-            or (rule.scan_text_mime and p:is_text())
-            or (p:is_attachment())
-      end
 
       fun.each(function(p)
         local content = p:get_content()
-
         if content and #content > 0 then
           cfg.check(task, content, p:get_digest(), rule)
         end
-      end, fun.filter(filter_func, parts))
+      end, common.check_parts_match(task, rule))
 
     else
       cfg.check(task, task:get_content(), task:get_digest(), rule)
index 192f15f515335b9efa128cf5140912074cb96a85..3a22c16d527b09fb9c2d5cd4442d4cf7aaffc3ed 100644 (file)
@@ -16,10 +16,10 @@ limitations under the License.
 ]] --
 
 local rspamd_logger = require "rspamd_logger"
-local rspamd_regexp = require "rspamd_regexp"
 local lua_util = require "lua_util"
 local fun = require "fun"
 local lua_scanners = require("lua_scanners").filter('scanner')
+local common = require "lua_scanners/common"
 local redis_params
 
 local N = "external_services"
@@ -62,7 +62,7 @@ local function add_scanner_rule(sym, opts)
   local cfg = lua_scanners[opts['type']]
 
   if not cfg then
-    rspamd_logger.errx(rspamd_config, 'unknown antivirus type: %s',
+    rspamd_logger.errx(rspamd_config, 'unknown external scanner type: %s',
         opts['type'])
     return nil
   end
@@ -82,126 +82,31 @@ local function add_scanner_rule(sym, opts)
     return nil
   end
 
-  local function create_regex_table(task, patterns)
-    local regex_table = {}
-    if patterns[1] then
-      for i, p in ipairs(patterns) do
-        if type(p) == 'table' then
-          local new_set = {}
-          for k, v in pairs(p) do
-            new_set[k] = rspamd_regexp.create_cached(v)
-          end
-          regex_table[i] = new_set
-        else
-          regex_table[i] = {}
-        end
-      end
-    else
-      for k, v in pairs(patterns) do
-        regex_table[k] = rspamd_regexp.create_cached(v)
-      end
-    end
-    return regex_table
-  end
-
-  if opts['mime_parts_filter_regex'] ~= nil
-    or opts['mime_parts_filter_ext'] ~= nil then
+  -- if any mime_part filter defined, do not scan all attachments
+  if opts.mime_parts_filter_regex ~= nil
+    or opts.mime_parts_filter_ext ~= nil then
       rule.scan_all_mime_parts = false
   end
 
-  rule['patterns'] = create_regex_table(task, opts['patterns'] or {})
+  rule.patterns = common.create_regex_table(task, opts.patterns or {})
 
-  rule['mime_parts_filter_regex'] = create_regex_table(task, opts['mime_parts_filter_regex'] or {})
+  rule.mime_parts_filter_regex = common.create_regex_table(task, opts.mime_parts_filter_regex or {})
 
-  rule['mime_parts_filter_ext'] = create_regex_table(task, opts['mime_parts_filter_ext'] or {})
+  rule.mime_parts_filter_ext = common.create_regex_table(task, opts.mime_parts_filter_ext or {})
 
   if opts['whitelist'] then
     rule['whitelist'] = rspamd_config:add_hash_map(opts['whitelist'])
   end
 
-  local function match_filter(task, found, patterns)
-    if type(patterns) ~= 'table' then
-      lua_util.debugm(N, task, '%s: pattern not table %s', rule.log_prefix, type(patterns))
-      return false
-    end
-    if not patterns[1] then
-      --lua_util.debugm(N, task, '%s: in not pattern[1]', rule['symbol'], rule['type'])
-      for _, pat in pairs(patterns) do
-        if pat:match(found) then
-          return true
-        end
-      end
-      return false
-    else
-      for _, p in ipairs(patterns) do
-        for _, pat in ipairs(p) do
-          if pat:match(found) then
-            return true
-          end
-        end
-      end
-      return false
-    end
-  end
-
-  -- borrowed from mime_types.lua
-  -- ext is the last extension, LOWERCASED
-  -- ext2 is the one before last extension LOWERCASED
-  local function gen_extension(fname)
-    local filename_parts = rspamd_str_split(fname, '.')
-
-    local ext = {}
-    for n = 1, 2 do
-        ext[n] = #filename_parts > n and string.lower(filename_parts[#filename_parts + 1 - n]) or nil
-    end
-  --lua_util.debugm(N, task, '%s: extension found: %s', rule.log_prefix, ext[1])
-    return ext[1],ext[2],filename_parts
-  end
-
   return function(task)
     if rule.scan_mime_parts then
-      local parts = task:get_parts() or {}
-
-      local filter_func = function(p)
-        local content_type,content_subtype = p:get_type()
-        local fname = p:get_filename()
-        local ext,ext2,part_table
-        local extension_check = false
-        local content_type_check = false
-        if fname ~= nil then
-          ext,ext2,part_table = gen_extension(fname)
-          lua_util.debugm(N, task, '%s: extension found: %s - 2.ext: %s - parts: %s',
-            rule.log_prefix, ext, ext2, part_table)
-          if match_filter(task, ext, rule['mime_parts_filter_ext'])
-            or match_filter(task, ext2, rule['mime_parts_filter_ext']) then
-            lua_util.debugm(N, task, '%s: extension matched: %s', rule.log_prefix, ext)
-            extension_check = true
-          end
-          if match_filter(task, fname, rule['mime_parts_filter_regex']) then
-            --lua_util.debugm(N, task, '%s: regex fname: %s', rule.log_prefix, fname)
-            content_type_check = true
-          end
-        end
-        if content_type ~=nil and content_subtype ~= nil then
-          if match_filter(task, content_type..'/'..content_subtype, rule['mime_parts_filter_regex']) then
-            lua_util.debugm(N, task, '%s: regex ct: %s', rule.log_prefix, content_type..'/'..content_subtype)
-            content_type_check = true
-          end
-        end
-
-        return (rule.scan_image_mime and p:is_image())
-            or (rule.scan_text_mime and p:is_text())
-            or (p:get_filename() and rule.scan_all_mime_parts ~= false)
-            or extension_check
-            or content_type_check
-      end
 
       fun.each(function(p)
         local content = p:get_content()
         if content and #content > 0 then
           cfg.check(task, content, p:get_digest(), rule)
         end
-      end, fun.filter(filter_func, parts))
+      end, common.check_parts_match(task, rule))
 
     else
       cfg.check(task, task:get_content(), task:get_digest(), rule)