]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Lua_magic: Improve short patterns performance
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 6 Sep 2019 13:05:31 +0000 (14:05 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 6 Sep 2019 13:05:31 +0000 (14:05 +0100)
lualib/lua_magic/init.lua

index 1ba899b0698c944366348510b29c115654b58de7..a2b2c98820d57df2444c29a74065f4cda2372c12 100644 (file)
@@ -31,17 +31,43 @@ local N = "lua_magic"
 local exports = {}
 -- trie object
 local compiled_patterns
+local compiled_short_patterns -- short patterns
 -- {<str>, <match_object>, <pattern_object>} indexed by pattern number
 local processed_patterns = {}
+local short_patterns = {}
+
+local short_match_limit = 128
+local max_short_offset = -1
+
+local function process_patterns(log_obj)
+  -- Add pattern to either short patterns or to normal patterns
+  local function add_processed(str, match, pattern)
+    if match.position and type(match.position) == 'number' and
+        match.position < short_match_limit then
+      short_patterns[#short_patterns + 1] = {
+        str, match, pattern
+      }
+
+      if max_short_offset < match.position then
+        max_short_offset = match.position
+      end
+    else
+      processed_patterns[#processed_patterns + 1] = {
+        str, match, pattern
+      }
+    end
+  end
 
-local function process_patterns()
   if not compiled_patterns then
-    for _,pattern in ipairs(patterns) do
+    for ext,pattern in pairs(patterns) do
+      assert(types[ext])
+      pattern.ext = ext
       for _,match in ipairs(pattern.matches) do
         if match.string then
-          processed_patterns[#processed_patterns + 1] = {
-            match.string, match, pattern
-          }
+          if match.relative_position and not match.position then
+            match.position = match.relative_position + #match.string
+          end
+          add_processed(match.string, match, pattern)
         elseif match.hex then
           local hex_table = {}
 
@@ -49,9 +75,11 @@ local function process_patterns()
             local subc = match.hex:sub(i, i + 1)
             hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
           end
-          processed_patterns[#processed_patterns + 1] = {
-            table.concat(hex_table), match, pattern
-          }
+
+          if match.relative_position and not match.position then
+            match.position = match.relative_position + #match.hex / 2
+          end
+          add_processed(table.concat(hex_table), match, pattern)
         end
       end
     end
@@ -60,16 +88,19 @@ local function process_patterns()
         fun.map(function(t) return t[1] end, processed_patterns)),
         rspamd_trie.flags.re
     )
+    compiled_short_patterns = rspamd_trie.create(fun.totable(
+        fun.map(function(t) return t[1] end, short_patterns)),
+        rspamd_trie.flags.re
+    )
 
-    lua_util.debugm(N, rspamd_config, 'compiled %s patterns',
-        #processed_patterns)
+    lua_util.debugm(N, log_obj,
+        'compiled %s (%s short and %s long) patterns',
+        #processed_patterns + #short_patterns, #short_patterns, #processed_patterns)
   end
 end
 
-local function match_chunk(input, offset, log_obj, res)
-  local matches = compiled_patterns:match(input)
-
-  if not log_obj then log_obj = rspamd_config end
+local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
+  local matches = trie:match(input)
 
   local function add_result(match, pattern)
     if not res[pattern.ext] then
@@ -86,7 +117,7 @@ local function match_chunk(input, offset, log_obj, res)
   end
 
   for npat,matched_positions in pairs(matches) do
-    local pat_data = processed_patterns[npat]
+    local pat_data = processed_tbl[npat]
     local pattern = pat_data[3]
     local match = pat_data[2]
 
@@ -132,8 +163,25 @@ local function match_chunk(input, offset, log_obj, res)
     end
   end
 end
+
+local function process_detected(res)
+  local extensions = lua_util.keys(res)
+
+  if #extensions > 0 then
+    table.sort(extensions, function(ex1, ex2)
+      return res[ex1] > res[ex2]
+    end)
+
+    return extensions,res[extensions[1]]
+  end
+
+  return nil
+end
+
 exports.detect = function(input, log_obj)
-  process_patterns()
+  if not log_obj then log_obj = rspamd_config end
+  process_patterns(log_obj)
+
   local res = {}
 
   if type(input) == 'string' then
@@ -141,28 +189,43 @@ exports.detect = function(input, log_obj)
     input = rspamd_text.fromstring(input)
   end
 
-  if type(input) == 'userdata' and #input > exports.chunk_size * 3 then
-    -- Split by chunks
-    local chunk1, chunk2, chunk3 =
-    input:span(1, exports.chunk_size),
-    input:span(exports.chunk_size, exports.chunk_size),
-    input:span(#input - exports.chunk_size, exports.chunk_size)
-    local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
-
-    match_chunk(chunk1, offset1, log_obj, res)
-    match_chunk(chunk2, offset2, log_obj, res)
-    match_chunk(chunk3, offset3, log_obj, res)
+
+  if type(input) == 'userdata' then
+    -- Try short match
+    local head = input:span(1, math.min(max_short_offset, #input))
+    match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res)
+
+    local extensions,confidence = process_detected(res)
+
+    if extensions and #extensions > 0 and confidence > 30 then
+      -- We are done on short patterns
+      return extensions[1],types[extensions[1]]
+    end
+
+    if #input > exports.chunk_size * 3 then
+      -- Chunked version as input is too long
+      local chunk1, chunk2, chunk3 =
+      input:span(1, exports.chunk_size),
+      input:span(exports.chunk_size, exports.chunk_size),
+      input:span(#input - exports.chunk_size, exports.chunk_size)
+      local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
+
+      match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res)
+      match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res)
+      match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res)
+    else
+      -- Input is short enough to match it at all
+      match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+    end
   else
-    match_chunk(input, 0, log_obj, res)
+    -- Input is a table so just try to match it all...
+    match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res)
+    match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
   end
 
-  local extensions = lua_util.keys(res)
-
-  if #extensions > 0 then
-    table.sort(extensions, function(ex1, ex2)
-      return res[ex1] > res[ex2]
-    end)
+  local extensions = process_detected(res)
 
+  if extensions and #extensions > 0 then
     return extensions[1],types[extensions[1]]
   end