aboutsummaryrefslogtreecommitdiffstats
path: root/lualib
diff options
context:
space:
mode:
Diffstat (limited to 'lualib')
-rw-r--r--lualib/lua_magic/init.lua102
1 files changed, 76 insertions, 26 deletions
diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua
index a2b2c9882..4ecc66afa 100644
--- a/lualib/lua_magic/init.lua
+++ b/lualib/lua_magic/init.lua
@@ -29,27 +29,39 @@ local rspamd_trie = require "rspamd_trie"
local N = "lua_magic"
local exports = {}
--- trie object
+-- trie objects
local compiled_patterns
-local compiled_short_patterns -- short patterns
+local compiled_short_patterns
+local compiled_tail_patterns
-- {<str>, <match_object>, <pattern_object>} indexed by pattern number
local processed_patterns = {}
local short_patterns = {}
+local tail_patterns = {}
local short_match_limit = 128
local max_short_offset = -1
+local min_tail_offset = math.huge
local function process_patterns(log_obj)
-- Add pattern to either short patterns or to normal patterns
local function add_processed(str, match, pattern)
- if match.position and type(match.position) == 'number' and
- match.position < short_match_limit then
- short_patterns[#short_patterns + 1] = {
- str, match, pattern
- }
+ if match.position and type(match.position) == 'number' then
+ if match.tail then
+ -- Tail pattern
+ tail_patterns[#tail_patterns + 1] = {
+ str, match, pattern
+ }
+ if min_tail_offset > match.tail then
+ min_tail_offset = match.tail
+ end
+ elseif match.position < short_match_limit then
+ short_patterns[#short_patterns + 1] = {
+ str, match, pattern
+ }
- if max_short_offset < match.position then
- max_short_offset = match.position
+ if max_short_offset < match.position then
+ max_short_offset = match.position
+ end
end
else
processed_patterns[#processed_patterns + 1] = {
@@ -92,15 +104,21 @@ local function process_patterns(log_obj)
fun.map(function(t) return t[1] end, short_patterns)),
rspamd_trie.flags.re
)
+ compiled_tail_patterns = rspamd_trie.create(fun.totable(
+ fun.map(function(t) return t[1] end, tail_patterns)),
+ rspamd_trie.flags.re
+ )
lua_util.debugm(N, log_obj,
- 'compiled %s (%s short and %s long) patterns',
- #processed_patterns + #short_patterns, #short_patterns, #processed_patterns)
+ 'compiled %s (%s short; %s long; %s tail) patterns',
+ #processed_patterns + #short_patterns + #tail_patterns,
+ #short_patterns, #processed_patterns, #tail_patterns)
end
end
-local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
+local function match_chunk(input, tlen, offset, trie, processed_tbl, log_obj, res)
local matches = trie:match(input)
+ local last = tlen
local function add_result(match, pattern)
if not res[pattern.ext] then
@@ -139,6 +157,11 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
expected = expected[2]
end
+ -- Tail match
+ if expected < 0 then
+ expected = last + expected + 1
+ end
+
return cmp(pos, expected)
end
-- Single position
@@ -146,19 +169,33 @@ local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
local position = match.position
for _,pos in ipairs(matched_positions) do
+ lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+ pattern.ext, pos, offset)
if match_position(pos + offset, position) then
add_result(match, pattern)
+ break
end
end
end
-- Match all positions
if match.positions then
+ local all_right = true
for _,position in ipairs(match.positions) do
+ local matched = false
for _,pos in ipairs(matched_positions) do
- if match_position(pos, position) then
- add_result(match, pattern)
+ if not match_position(pos + offset, position) then
+ matched = true
+ break
end
end
+ if not matched then
+ all_right = false
+ break
+ end
+ end
+
+ if all_right then
+ add_result(match, pattern)
end
end
end
@@ -191,10 +228,21 @@ exports.detect = function(input, log_obj)
if type(input) == 'userdata' then
+ local inplen = #input
+
+ -- Check tail matches
+ if inplen > min_tail_offset then
+ local tail = input:span(inplen - min_tail_offset, min_tail_offset)
+ match_chunk(tail, inplen, inplen - min_tail_offset,
+ compiled_tail_patterns, tail_patterns, log_obj, res)
+ end
+
-- Try short match
- local head = input:span(1, math.min(max_short_offset, #input))
- match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res)
+ local head = input:span(1, math.min(max_short_offset, inplen))
+ match_chunk(head, inplen, 0,
+ compiled_short_patterns, short_patterns, log_obj, res)
+ -- Check if we have enough data or go to long patterns
local extensions,confidence = process_detected(res)
if extensions and #extensions > 0 and confidence > 30 then
@@ -207,20 +255,22 @@ exports.detect = function(input, log_obj)
local chunk1, chunk2, chunk3 =
input:span(1, exports.chunk_size),
input:span(exports.chunk_size, exports.chunk_size),
- input:span(#input - exports.chunk_size, exports.chunk_size)
- local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
-
- match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res)
- match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res)
- match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res)
+ input:span(inplen - exports.chunk_size, exports.chunk_size)
+ local offset1, offset2, offset3 = 0, exports.chunk_size, inplen - exports.chunk_size
+
+ match_chunk(chunk1, inplen,
+ offset1, compiled_patterns, processed_patterns, log_obj, res)
+ match_chunk(chunk2, inplen,
+ offset2, compiled_patterns, processed_patterns, log_obj, res)
+ match_chunk(chunk3, inplen,
+ offset3, compiled_patterns, processed_patterns, log_obj, res)
else
-- Input is short enough to match it at all
- match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+ match_chunk(input, inplen, 0,
+ compiled_patterns, processed_patterns, log_obj, res)
end
else
- -- Input is a table so just try to match it all...
- match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res)
- match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
+ assert(0)
end
local extensions = process_detected(res)