From bd13783018884a90571d1e94754e8bbf81369b82 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 6 Sep 2019 17:14:47 +0100 Subject: [PATCH] [Project] Lua_magic: Support tail patterns --- lualib/lua_magic/init.lua | 102 ++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 26 deletions(-) diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua index a2b2c9882..4ecc66afa 100644 --- a/lualib/lua_magic/init.lua +++ b/lualib/lua_magic/init.lua @@ -29,27 +29,39 @@ local rspamd_trie = require "rspamd_trie" local N = "lua_magic" local exports = {} --- trie object +-- trie objects local compiled_patterns -local compiled_short_patterns -- short patterns +local compiled_short_patterns +local compiled_tail_patterns -- {, , } indexed by pattern number local processed_patterns = {} local short_patterns = {} +local tail_patterns = {} local short_match_limit = 128 local max_short_offset = -1 +local min_tail_offset = math.huge local function process_patterns(log_obj) -- Add pattern to either short patterns or to normal patterns local function add_processed(str, match, pattern) - if match.position and type(match.position) == 'number' and - match.position < short_match_limit then - short_patterns[#short_patterns + 1] = { - str, match, pattern - } + if match.position and type(match.position) == 'number' then + if match.tail then + -- Tail pattern + tail_patterns[#tail_patterns + 1] = { + str, match, pattern + } + if min_tail_offset > match.tail then + min_tail_offset = match.tail + end + elseif match.position < short_match_limit then + short_patterns[#short_patterns + 1] = { + str, match, pattern + } - if max_short_offset < match.position then - max_short_offset = match.position + if max_short_offset < match.position then + max_short_offset = match.position + end end else processed_patterns[#processed_patterns + 1] = { @@ -92,15 +104,21 @@ local function process_patterns(log_obj) fun.map(function(t) return t[1] end, short_patterns)), rspamd_trie.flags.re ) + compiled_tail_patterns = rspamd_trie.create(fun.totable( + fun.map(function(t) return t[1] end, tail_patterns)), + rspamd_trie.flags.re + ) lua_util.debugm(N, log_obj, - 'compiled %s (%s short and %s long) patterns', - #processed_patterns + #short_patterns, #short_patterns, #processed_patterns) + 'compiled %s (%s short; %s long; %s tail) patterns', + #processed_patterns + #short_patterns + #tail_patterns, + #short_patterns, #processed_patterns, #tail_patterns) end end -local function match_chunk(input, offset, trie, processed_tbl, log_obj, res) +local function match_chunk(input, tlen, offset, trie, processed_tbl, log_obj, res) local matches = trie:match(input) + local last = tlen local function add_result(match, pattern) if not res[pattern.ext] then @@ -139,6 +157,11 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res) expected = expected[2] end + -- Tail match + if expected < 0 then + expected = last + expected + 1 + end + return cmp(pos, expected) end -- Single position @@ -146,19 +169,33 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res) local position = match.position for _,pos in ipairs(matched_positions) do + lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)', + pattern.ext, pos, offset) if match_position(pos + offset, position) then add_result(match, pattern) + break end end end -- Match all positions if match.positions then + local all_right = true for _,position in ipairs(match.positions) do + local matched = false for _,pos in ipairs(matched_positions) do - if match_position(pos, position) then - add_result(match, pattern) + if not match_position(pos + offset, position) then + matched = true + break end end + if not matched then + all_right = false + break + end + end + + if all_right then + add_result(match, pattern) end end end @@ -191,10 +228,21 @@ exports.detect = function(input, log_obj) if type(input) == 'userdata' then + local inplen = #input + + -- Check tail matches + if inplen > min_tail_offset then + local tail = input:span(inplen - min_tail_offset, min_tail_offset) + match_chunk(tail, inplen, inplen - min_tail_offset, + compiled_tail_patterns, tail_patterns, log_obj, res) + end + -- Try short match - local head = input:span(1, math.min(max_short_offset, #input)) - match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res) + local head = input:span(1, math.min(max_short_offset, inplen)) + match_chunk(head, inplen, 0, + compiled_short_patterns, short_patterns, log_obj, res) + -- Check if we have enough data or go to long patterns local extensions,confidence = process_detected(res) if extensions and #extensions > 0 and confidence > 30 then @@ -207,20 +255,22 @@ exports.detect = function(input, log_obj) local chunk1, chunk2, chunk3 = input:span(1, exports.chunk_size), input:span(exports.chunk_size, exports.chunk_size), - input:span(#input - exports.chunk_size, exports.chunk_size) - local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size - - match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res) - match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res) - match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res) + input:span(inplen - exports.chunk_size, exports.chunk_size) + local offset1, offset2, offset3 = 0, exports.chunk_size, inplen - exports.chunk_size + + match_chunk(chunk1, inplen, + offset1, compiled_patterns, processed_patterns, log_obj, res) + match_chunk(chunk2, inplen, + offset2, compiled_patterns, processed_patterns, log_obj, res) + match_chunk(chunk3, inplen, + offset3, compiled_patterns, processed_patterns, log_obj, res) else -- Input is short enough to match it at all - match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res) + match_chunk(input, inplen, 0, + compiled_patterns, processed_patterns, log_obj, res) end else - -- Input is a table so just try to match it all... - match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res) - match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res) + assert(0) end local extensions = process_detected(res) -- 2.39.5