From 055640c105492d4aaa8a75f973ce208ffd8cc045 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 6 Sep 2019 14:05:31 +0100 Subject: [PATCH] [Project] Lua_magic: Improve short patterns performance --- lualib/lua_magic/init.lua | 131 ++++++++++++++++++++++++++++---------- 1 file changed, 97 insertions(+), 34 deletions(-) diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua index 1ba899b06..a2b2c9882 100644 --- a/lualib/lua_magic/init.lua +++ b/lualib/lua_magic/init.lua @@ -31,17 +31,43 @@ local N = "lua_magic" local exports = {} -- trie object local compiled_patterns +local compiled_short_patterns -- short patterns -- {<str>, <match_object>, <pattern_object>} indexed by pattern number local processed_patterns = {} +local short_patterns = {} + +local short_match_limit = 128 +local max_short_offset = -1 + +local function process_patterns(log_obj) + -- Add pattern to either short patterns or to normal patterns + local function add_processed(str, match, pattern) + if match.position and type(match.position) == 'number' and + match.position < short_match_limit then + short_patterns[#short_patterns + 1] = { + str, match, pattern + } + + if max_short_offset < match.position then + max_short_offset = match.position + end + else + processed_patterns[#processed_patterns + 1] = { + str, match, pattern + } + end + end -local function process_patterns() if not compiled_patterns then - for _,pattern in ipairs(patterns) do + for ext,pattern in pairs(patterns) do + assert(types[ext]) + pattern.ext = ext for _,match in ipairs(pattern.matches) do if match.string then - processed_patterns[#processed_patterns + 1] = { - match.string, match, pattern - } + if match.relative_position and not match.position then + match.position = match.relative_position + #match.string + end + add_processed(match.string, match, pattern) elseif match.hex then local hex_table = {} @@ -49,9 +75,11 @@ local function process_patterns() local subc = match.hex:sub(i, i + 1) hex_table[#hex_table + 1] = string.format('\\x{%s}', 
subc) end - processed_patterns[#processed_patterns + 1] = { - table.concat(hex_table), match, pattern - } + + if match.relative_position and not match.position then + match.position = match.relative_position + #match.hex / 2 + end + add_processed(table.concat(hex_table), match, pattern) end end end @@ -60,16 +88,19 @@ local function process_patterns() fun.map(function(t) return t[1] end, processed_patterns)), rspamd_trie.flags.re ) + compiled_short_patterns = rspamd_trie.create(fun.totable( + fun.map(function(t) return t[1] end, short_patterns)), + rspamd_trie.flags.re + ) - lua_util.debugm(N, rspamd_config, 'compiled %s patterns', - #processed_patterns) + lua_util.debugm(N, log_obj, + 'compiled %s (%s short and %s long) patterns', + #processed_patterns + #short_patterns, #short_patterns, #processed_patterns) end end -local function match_chunk(input, offset, log_obj, res) - local matches = compiled_patterns:match(input) - - if not log_obj then log_obj = rspamd_config end +local function match_chunk(input, offset, trie, processed_tbl, log_obj, res) + local matches = trie:match(input) local function add_result(match, pattern) if not res[pattern.ext] then @@ -86,7 +117,7 @@ local function match_chunk(input, offset, log_obj, res) end for npat,matched_positions in pairs(matches) do - local pat_data = processed_patterns[npat] + local pat_data = processed_tbl[npat] local pattern = pat_data[3] local match = pat_data[2] @@ -132,8 +163,25 @@ local function match_chunk(input, offset, log_obj, res) end end end + +local function process_detected(res) + local extensions = lua_util.keys(res) + + if #extensions > 0 then + table.sort(extensions, function(ex1, ex2) + return res[ex1] > res[ex2] + end) + + return extensions,res[extensions[1]] + end + + return nil +end + exports.detect = function(input, log_obj) - process_patterns() + if not log_obj then log_obj = rspamd_config end + process_patterns(log_obj) + local res = {} if type(input) == 'string' then @@ -141,28 +189,43 @@ 
exports.detect = function(input, log_obj) input = rspamd_text.fromstring(input) end - if type(input) == 'userdata' and #input > exports.chunk_size * 3 then - -- Split by chunks - local chunk1, chunk2, chunk3 = - input:span(1, exports.chunk_size), - input:span(exports.chunk_size, exports.chunk_size), - input:span(#input - exports.chunk_size, exports.chunk_size) - local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size - - match_chunk(chunk1, offset1, log_obj, res) - match_chunk(chunk2, offset2, log_obj, res) - match_chunk(chunk3, offset3, log_obj, res) + + if type(input) == 'userdata' then + -- Try short match + local head = input:span(1, math.min(max_short_offset, #input)) + match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res) + + local extensions,confidence = process_detected(res) + + if extensions and #extensions > 0 and confidence > 30 then + -- We are done on short patterns + return extensions[1],types[extensions[1]] + end + + if #input > exports.chunk_size * 3 then + -- Chunked version as input is too long + local chunk1, chunk2, chunk3 = + input:span(1, exports.chunk_size), + input:span(exports.chunk_size, exports.chunk_size), + input:span(#input - exports.chunk_size, exports.chunk_size) + local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size + + match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res) + match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res) + match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res) + else + -- Input is short enough to match it at all + match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res) + end else - match_chunk(input, 0, log_obj, res) + -- Input is a table so just try to match it all... 
+ match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res) + match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res) end - local extensions = lua_util.keys(res) - - if #extensions > 0 then - table.sort(extensions, function(ex1, ex2) - return res[ex1] > res[ex2] - end) + local extensions = process_detected(res) + if extensions and #extensions > 0 then return extensions[1],types[extensions[1]] end -- 2.39.5