]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Lua_magic: Support tail patterns
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 6 Sep 2019 16:14:47 +0000 (17:14 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 6 Sep 2019 16:14:47 +0000 (17:14 +0100)
lualib/lua_magic/init.lua

index a2b2c98820d57df2444c29a74065f4cda2372c12..4ecc66afa744bc3b8ee326cb562d182e1aa46c68 100644 (file)
@@ -29,27 +29,39 @@ local rspamd_trie = require "rspamd_trie"
 
 local N = "lua_magic"
 local exports = {}
--- trie object
+-- trie objects
 local compiled_patterns
-local compiled_short_patterns -- short patterns
+local compiled_short_patterns
+local compiled_tail_patterns
 -- {<str>, <match_object>, <pattern_object>} indexed by pattern number
 local processed_patterns = {}
 local short_patterns = {}
+local tail_patterns = {}
 
 local short_match_limit = 128
 local max_short_offset = -1
+local min_tail_offset = math.huge
 
 local function process_patterns(log_obj)
   -- Add pattern to either short patterns or to normal patterns
   local function add_processed(str, match, pattern)
-    if match.position and type(match.position) == 'number' and
-        match.position < short_match_limit then
-      short_patterns[#short_patterns + 1] = {
-        str, match, pattern
-      }
+    if match.position and type(match.position) == 'number' then
+      if match.tail then
+        -- Tail pattern
+        tail_patterns[#tail_patterns + 1] = {
+          str, match, pattern
+        }
+        if min_tail_offset > match.tail then
+          min_tail_offset = match.tail
+        end
+      elseif match.position < short_match_limit then
+        short_patterns[#short_patterns + 1] = {
+          str, match, pattern
+        }
 
-      if max_short_offset < match.position then
-        max_short_offset = match.position
+        if max_short_offset < match.position then
+          max_short_offset = match.position
+        end
       end
     else
       processed_patterns[#processed_patterns + 1] = {
@@ -92,15 +104,21 @@ local function process_patterns(log_obj)
         fun.map(function(t) return t[1] end, short_patterns)),
         rspamd_trie.flags.re
     )
+    compiled_tail_patterns = rspamd_trie.create(fun.totable(
+        fun.map(function(t) return t[1] end, tail_patterns)),
+        rspamd_trie.flags.re
+    )
 
     lua_util.debugm(N, log_obj,
-        'compiled %s (%s short and %s long) patterns',
-        #processed_patterns + #short_patterns, #short_patterns, #processed_patterns)
+        'compiled %s (%s short; %s long; %s tail) patterns',
+        #processed_patterns + #short_patterns + #tail_patterns,
+        #short_patterns, #processed_patterns, #tail_patterns)
   end
 end
 
-local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
+local function match_chunk(input, tlen, offset, trie, processed_tbl, log_obj, res)
   local matches = trie:match(input)
+  local last = tlen
 
   local function add_result(match, pattern)
     if not res[pattern.ext] then
@@ -139,6 +157,11 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
         expected = expected[2]
       end
 
+      -- Tail match
+      if expected < 0 then
+        expected = last + expected + 1
+      end
+
       return cmp(pos, expected)
     end
     -- Single position
@@ -146,19 +169,33 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
       local position = match.position
 
       for _,pos in ipairs(matched_positions) do
+        lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+            pattern.ext, pos, offset)
         if match_position(pos + offset, position) then
           add_result(match, pattern)
+          break
         end
       end
     end
     -- Match all positions
     if match.positions then
+      local all_right = true
       for _,position in ipairs(match.positions) do
+        local matched = false
         for _,pos in ipairs(matched_positions) do
-          if match_position(pos, position) then
-            add_result(match, pattern)
+          if not match_position(pos + offset, position) then
+            matched = true
+            break
           end
         end
+        if not matched then
+          all_right = false
+          break
+        end
+      end
+
+      if all_right then
+        add_result(match, pattern)
       end
     end
   end
@@ -191,10 +228,21 @@ exports.detect = function(input, log_obj)
 
 
   if type(input) == 'userdata' then
+    local inplen = #input
+
+    -- Check tail matches
+    if inplen > min_tail_offset then
+      local tail = input:span(inplen - min_tail_offset, min_tail_offset)
+      match_chunk(tail, inplen, inplen - min_tail_offset,
+          compiled_tail_patterns, tail_patterns, log_obj, res)
+    end
+
     -- Try short match
-    local head = input:span(1, math.min(max_short_offset, #input))
-    match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res)
+    local head = input:span(1, math.min(max_short_offset, inplen))
+    match_chunk(head, inplen, 0,
+        compiled_short_patterns, short_patterns, log_obj, res)
 
+    -- Check if we have enough data or go to long patterns
     local extensions,confidence = process_detected(res)
 
     if extensions and #extensions > 0 and confidence > 30 then
@@ -207,20 +255,22 @@ exports.detect = function(input, log_obj)
       local chunk1, chunk2, chunk3 =
       input:span(1, exports.chunk_size),
       input:span(exports.chunk_size, exports.chunk_size),
-      input:span(#input - exports.chunk_size, exports.chunk_size)
-      local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
-
-      match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res)
-      match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res)
-      match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res)
+      input:span(inplen - exports.chunk_size, exports.chunk_size)
+      local offset1, offset2, offset3 = 0, exports.chunk_size, inplen - exports.chunk_size
+
+      match_chunk(chunk1, inplen,
+          offset1, compiled_patterns, processed_patterns, log_obj, res)
+      match_chunk(chunk2, inplen,
+          offset2, compiled_patterns, processed_patterns, log_obj, res)
+      match_chunk(chunk3, inplen,
+          offset3, compiled_patterns, processed_patterns, log_obj, res)
     else
       -- Input is short enough to match it at all
-      match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+      match_chunk(input, inplen, 0,
+          compiled_patterns, processed_patterns, log_obj, res)
     end
   else
-    -- Input is a table so just try to match it all...
-    match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res)
-    match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+    assert(0)
   end
 
   local extensions = process_detected(res)