mirror of
https://github.com/rspamd/rspamd.git
synced 2024-08-22 05:54:43 +02:00
[Project] Lua_magic: Improve short patterns performance
This commit is contained in:
parent
5aed65dc5c
commit
055640c105
@ -31,17 +31,43 @@ local N = "lua_magic"
|
||||
local exports = {}
|
||||
-- trie object
|
||||
local compiled_patterns
|
||||
local compiled_short_patterns -- short patterns
|
||||
-- {<str>, <match_object>, <pattern_object>} indexed by pattern number
|
||||
local processed_patterns = {}
|
||||
local short_patterns = {}
|
||||
|
||||
local short_match_limit = 128
|
||||
local max_short_offset = -1
|
||||
|
||||
local function process_patterns(log_obj)
|
||||
-- Add pattern to either short patterns or to normal patterns
|
||||
local function add_processed(str, match, pattern)
|
||||
if match.position and type(match.position) == 'number' and
|
||||
match.position < short_match_limit then
|
||||
short_patterns[#short_patterns + 1] = {
|
||||
str, match, pattern
|
||||
}
|
||||
|
||||
if max_short_offset < match.position then
|
||||
max_short_offset = match.position
|
||||
end
|
||||
else
|
||||
processed_patterns[#processed_patterns + 1] = {
|
||||
str, match, pattern
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
local function process_patterns()
|
||||
if not compiled_patterns then
|
||||
for _,pattern in ipairs(patterns) do
|
||||
for ext,pattern in pairs(patterns) do
|
||||
assert(types[ext])
|
||||
pattern.ext = ext
|
||||
for _,match in ipairs(pattern.matches) do
|
||||
if match.string then
|
||||
processed_patterns[#processed_patterns + 1] = {
|
||||
match.string, match, pattern
|
||||
}
|
||||
if match.relative_position and not match.position then
|
||||
match.position = match.relative_position + #match.string
|
||||
end
|
||||
add_processed(match.string, match, pattern)
|
||||
elseif match.hex then
|
||||
local hex_table = {}
|
||||
|
||||
@ -49,9 +75,11 @@ local function process_patterns()
|
||||
local subc = match.hex:sub(i, i + 1)
|
||||
hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
|
||||
end
|
||||
processed_patterns[#processed_patterns + 1] = {
|
||||
table.concat(hex_table), match, pattern
|
||||
}
|
||||
|
||||
if match.relative_position and not match.position then
|
||||
match.position = match.relative_position + #match.hex / 2
|
||||
end
|
||||
add_processed(table.concat(hex_table), match, pattern)
|
||||
end
|
||||
end
|
||||
end
|
||||
@ -60,16 +88,19 @@ local function process_patterns()
|
||||
fun.map(function(t) return t[1] end, processed_patterns)),
|
||||
rspamd_trie.flags.re
|
||||
)
|
||||
compiled_short_patterns = rspamd_trie.create(fun.totable(
|
||||
fun.map(function(t) return t[1] end, short_patterns)),
|
||||
rspamd_trie.flags.re
|
||||
)
|
||||
|
||||
lua_util.debugm(N, rspamd_config, 'compiled %s patterns',
|
||||
#processed_patterns)
|
||||
lua_util.debugm(N, log_obj,
|
||||
'compiled %s (%s short and %s long) patterns',
|
||||
#processed_patterns + #short_patterns, #short_patterns, #processed_patterns)
|
||||
end
|
||||
end
|
||||
|
||||
local function match_chunk(input, offset, log_obj, res)
|
||||
local matches = compiled_patterns:match(input)
|
||||
|
||||
if not log_obj then log_obj = rspamd_config end
|
||||
local function match_chunk(input, offset, trie, processed_tbl, log_obj, res)
|
||||
local matches = trie:match(input)
|
||||
|
||||
local function add_result(match, pattern)
|
||||
if not res[pattern.ext] then
|
||||
@ -86,7 +117,7 @@ local function match_chunk(input, offset, log_obj, res)
|
||||
end
|
||||
|
||||
for npat,matched_positions in pairs(matches) do
|
||||
local pat_data = processed_patterns[npat]
|
||||
local pat_data = processed_tbl[npat]
|
||||
local pattern = pat_data[3]
|
||||
local match = pat_data[2]
|
||||
|
||||
@ -132,30 +163,8 @@ local function match_chunk(input, offset, log_obj, res)
|
||||
end
|
||||
end
|
||||
end
|
||||
exports.detect = function(input, log_obj)
|
||||
process_patterns()
|
||||
local res = {}
|
||||
|
||||
if type(input) == 'string' then
|
||||
-- Convert to rspamd_text
|
||||
input = rspamd_text.fromstring(input)
|
||||
end
|
||||
|
||||
if type(input) == 'userdata' and #input > exports.chunk_size * 3 then
|
||||
-- Split by chunks
|
||||
local chunk1, chunk2, chunk3 =
|
||||
input:span(1, exports.chunk_size),
|
||||
input:span(exports.chunk_size, exports.chunk_size),
|
||||
input:span(#input - exports.chunk_size, exports.chunk_size)
|
||||
local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
|
||||
|
||||
match_chunk(chunk1, offset1, log_obj, res)
|
||||
match_chunk(chunk2, offset2, log_obj, res)
|
||||
match_chunk(chunk3, offset3, log_obj, res)
|
||||
else
|
||||
match_chunk(input, 0, log_obj, res)
|
||||
end
|
||||
|
||||
local function process_detected(res)
|
||||
local extensions = lua_util.keys(res)
|
||||
|
||||
if #extensions > 0 then
|
||||
@ -163,6 +172,60 @@ exports.detect = function(input, log_obj)
|
||||
return res[ex1] > res[ex2]
|
||||
end)
|
||||
|
||||
return extensions,res[extensions[1]]
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
exports.detect = function(input, log_obj)
|
||||
if not log_obj then log_obj = rspamd_config end
|
||||
process_patterns(log_obj)
|
||||
|
||||
local res = {}
|
||||
|
||||
if type(input) == 'string' then
|
||||
-- Convert to rspamd_text
|
||||
input = rspamd_text.fromstring(input)
|
||||
end
|
||||
|
||||
|
||||
if type(input) == 'userdata' then
|
||||
-- Try short match
|
||||
local head = input:span(1, math.min(max_short_offset, #input))
|
||||
match_chunk(head, 0, compiled_short_patterns, short_patterns, log_obj, res)
|
||||
|
||||
local extensions,confidence = process_detected(res)
|
||||
|
||||
if extensions and #extensions > 0 and confidence > 30 then
|
||||
-- We are done on short patterns
|
||||
return extensions[1],types[extensions[1]]
|
||||
end
|
||||
|
||||
if #input > exports.chunk_size * 3 then
|
||||
-- Chunked version as input is too long
|
||||
local chunk1, chunk2, chunk3 =
|
||||
input:span(1, exports.chunk_size),
|
||||
input:span(exports.chunk_size, exports.chunk_size),
|
||||
input:span(#input - exports.chunk_size, exports.chunk_size)
|
||||
local offset1, offset2, offset3 = 0, exports.chunk_size, #input - exports.chunk_size
|
||||
|
||||
match_chunk(chunk1, offset1, compiled_patterns, processed_patterns, log_obj, res)
|
||||
match_chunk(chunk2, offset2, compiled_patterns, processed_patterns, log_obj, res)
|
||||
match_chunk(chunk3, offset3, compiled_patterns, processed_patterns, log_obj, res)
|
||||
else
|
||||
-- Input is short enough to match it at all
|
||||
match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
|
||||
end
|
||||
else
|
||||
-- Input is a table so just try to match it all...
|
||||
match_chunk(input, 0, compiled_short_patterns, short_patterns, log_obj, res)
|
||||
match_chunk(input, 0, compiled_patterns, processed_patterns, log_obj, res)
|
||||
end
|
||||
|
||||
local extensions = process_detected(res)
|
||||
|
||||
if extensions and #extensions > 0 then
|
||||
return extensions[1],types[extensions[1]]
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user