]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Improve pdf magic detection
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 11 May 2020 15:31:30 +0000 (16:31 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 11 May 2020 15:31:30 +0000 (16:31 +0100)
lualib/lua_magic/heuristics.lua
lualib/lua_magic/init.lua
lualib/lua_magic/patterns.lua

index c60824bd80503fb8780b6e5e802e21ac47a6658f..678ca1b6ddb975a6d17bf6ebc7fc93bea9098463 100644 (file)
@@ -147,7 +147,7 @@ end
 -- Call immediately on require
 compile_tries()
 
-local function detect_ole_format(input, log_obj)
+local function detect_ole_format(input, log_obj, _, part)
   local inplen = #input
   if inplen < 0x31 + 4 then
     lua_util.debugm(N, log_obj, "short length: %s", inplen)
@@ -245,7 +245,7 @@ local function process_top_detected(res)
   return nil
 end
 
-local function detect_archive_flaw(part, arch, log_obj)
+local function detect_archive_flaw(part, arch, log_obj, _)
   local arch_type = arch:get_type()
   local res = {
     docx = 0,
@@ -312,7 +312,7 @@ local function detect_archive_flaw(part, arch, log_obj)
   return arch_type:lower(),40
 end
 
-exports.mime_part_heuristic = function(part, log_obj)
+exports.mime_part_heuristic = function(part, log_obj, _)
   if part:is_archive() then
     local arch = part:get_archive()
     return detect_archive_flaw(part, arch, log_obj)
@@ -321,7 +321,7 @@ exports.mime_part_heuristic = function(part, log_obj)
   return nil
 end
 
-exports.text_part_heuristic = function(part, log_obj)
+exports.text_part_heuristic = function(part, log_obj, _)
   -- We get some span of data and check it
   local function is_span_text(span)
     local function rough_utf8_check(bytes, idx, remain)
@@ -436,4 +436,19 @@ exports.text_part_heuristic = function(part, log_obj)
   end
 end
 
+exports.pdf_format_heuristic = function(input, log_obj, pos, part)
+  local weight = 10
+  local ext = string.match(part:get_filename() or '', '%.([^.]+)$')
+  -- If we found a pattern at the beginning
+  if pos <= 10 then
+    weight = weight + 30
+  end
+  -- If the announced extension is `pdf`
+  if ext and ext:lower() == 'pdf' then
+    weight = weight + 30
+  end
+
+  return 'pdf',weight
+end
+
 return exports
\ No newline at end of file
index 2bfc067e365633ccd8b799a20f302a5c7b68a30f..890afea6820bb47d2d55e6ebf5771fcf7ec597c2 100644 (file)
@@ -152,7 +152,7 @@ end
 
 process_patterns(rspamd_config)
 
-local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res)
+local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part)
   local matches = trie:match(chunk)
 
   local last = tlen
@@ -210,7 +210,7 @@ local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_
             pattern.ext, pos, offset)
         if match_position(pos + offset, position) then
           if match.heuristic then
-            local ext,weight = match.heuristic(input, log_obj)
+            local ext,weight = match.heuristic(input, log_obj, pos + offset, part)
 
             if ext then
               add_result(weight, ext)
@@ -225,6 +225,7 @@ local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_
     elseif match.positions then
       -- Match all positions
       local all_right = true
+      local matched_pos = 0
       for _,position in ipairs(match.positions) do
         local matched = false
         for _,pos in ipairs(matched_positions) do
@@ -232,6 +233,7 @@ local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_
               pattern.ext, pos, offset)
           if not match_position(pos + offset, position) then
             matched = true
+            matched_pos = pos
             break
           end
         end
@@ -243,7 +245,7 @@ local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_
 
       if all_right then
         if match.heuristic then
-          local ext,weight = match.heuristic(input, log_obj)
+          local ext,weight = match.heuristic(input, log_obj, matched_pos + offset, part)
 
           if ext then
             add_result(weight, ext)
@@ -273,8 +275,9 @@ local function process_detected(res)
   return nil
 end
 
-exports.detect = function(input, log_obj)
+exports.detect = function(part, log_obj)
   if not log_obj then log_obj = rspamd_config end
+  local input = part:get_content()
 
   local res = {}
 
@@ -291,13 +294,13 @@ exports.detect = function(input, log_obj)
     if inplen > min_tail_offset then
       local tail = input:span(inplen - min_tail_offset, min_tail_offset)
       match_chunk(tail, input, inplen, inplen - min_tail_offset,
-          compiled_tail_patterns, tail_patterns, log_obj, res)
+          compiled_tail_patterns, tail_patterns, log_obj, res, part)
     end
 
     -- Try short match
     local head = input:span(1, math.min(max_short_offset, inplen))
     match_chunk(head, input, inplen, 0,
-        compiled_short_patterns, short_patterns, log_obj, res)
+        compiled_short_patterns, short_patterns, log_obj, res, part)
 
     -- Check if we have enough data or go to long patterns
     local extensions,confidence = process_detected(res)
@@ -316,13 +319,13 @@ exports.detect = function(input, log_obj)
       local offset1, offset2 = 0, inplen - exports.chunk_size
 
       match_chunk(chunk1, input, inplen,
-          offset1, compiled_patterns, processed_patterns, log_obj, res)
+          offset1, compiled_patterns, processed_patterns, log_obj, res, part)
       match_chunk(chunk2, input, inplen,
-          offset2, compiled_patterns, processed_patterns, log_obj, res)
+          offset2, compiled_patterns, processed_patterns, log_obj, res, part)
     else
       -- Input is short enough to match it at all
       match_chunk(input, input, inplen, 0,
-          compiled_patterns, processed_patterns, log_obj, res)
+          compiled_patterns, processed_patterns, log_obj, res, part)
     end
   else
     -- Table input is NYI
@@ -346,7 +349,7 @@ exports.detect_mime_part = function(part, log_obj)
     return ext,types[ext]
   end
 
-  ext = exports.detect(part:get_content(), log_obj)
+  ext = exports.detect(part, log_obj)
 
   if ext then
     return ext,types[ext]
index c1ea896599da8373882b4c0c7a9eea2afc5a4830..87583c9de22614715f1825200f39c5c512da706a 100644 (file)
@@ -27,18 +27,21 @@ local patterns = {
     matches = {
       {
         string = [[^%PDF-\d]],
-        position = 6, -- must be end of the match, as that's how hyperscan works (or use relative_position)
+        position = {'<=', 1024},
         weight = 60,
+        heuristic = heuristics.pdf_format_heuristic
       },
       {
         string = [[^\012%PDF-\d]],
-        position = 7,
+        position = {'<=', 1024},
         weight = 60,
+        heuristic = heuristics.pdf_format_heuristic
       },
       {
         string = [[^%FDF-\d]],
-        position = 6,
+        position = {'<=', 1024},
         weight = 60,
+        heuristic = heuristics.pdf_format_heuristic
       },
     },
   },