123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- --[[
- Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ]]--
-
- --[[[
- -- @module lua_magic
- -- This module contains file types detection logic
- --]]
-
- local patterns = require "lua_magic/patterns"
- local types = require "lua_magic/types"
- local heuristics = require "lua_magic/heuristics"
- local fun = require "fun"
- local lua_util = require "lua_util"
-
- local rspamd_text = require "rspamd_text"
- local rspamd_trie = require "rspamd_trie"
-
local N = "lua_magic"
local exports = {}

-- Compiled rspamd_trie objects; all of them stay nil until
-- process_patterns() has run
local compiled_patterns, compiled_short_patterns, compiled_tail_patterns

-- Pattern registries. Every entry is {<str>, <match_object>, <pattern_object>}
-- and is addressed by the pattern number reported by the corresponding trie.
local processed_patterns, short_patterns, tail_patterns = {}, {}, {}

-- Non-tail patterns whose numeric position is below this limit are put into
-- the short (head) registry
local short_match_limit = 128
-- Largest position among short patterns (-1 while there are none)
local max_short_offset = -1
-- Smallest `tail` value among tail patterns (math.huge while there are none)
local min_tail_offset = math.huge
-
--- Distribute every match from `patterns` into the long, short (head) and
-- tail registries, then compile one rspamd_trie per registry.
-- The work is done only once: as soon as `compiled_patterns` is set, further
-- calls return immediately.
-- @param log_obj object passed to lua_util.debugm for debug logging
local function process_patterns(log_obj)
  if compiled_patterns then
    return
  end

  -- Append a {str, match, pattern} triple to `registry` and log it using `fmt`
  local function register(registry, fmt, str, match, pattern)
    registry[#registry + 1] = { str, match, pattern }
    lua_util.debugm(N, log_obj, fmt, str, pattern.ext)
  end

  -- Route one pattern string to a registry:
  --   numeric position with `tail`           -> tail registry
  --   numeric position < short_match_limit   -> short registry
  --   anything else                          -> long registry
  local function add_processed(str, match, pattern)
    local pos = match.position

    if type(pos) ~= 'number' then
      register(processed_patterns, 'add long pattern %s for ext %s',
          str, match, pattern)
    elseif match.tail then
      register(tail_patterns, 'add tail pattern %s for ext %s',
          str, match, pattern)
      if match.tail < min_tail_offset then
        min_tail_offset = match.tail
      end
    elseif pos < short_match_limit then
      local fmt
      if str:sub(1, 1) == '^' then
        fmt = 'add head pattern %s for ext %s'
      else
        fmt = 'add short pattern %s for ext %s'
      end
      register(short_patterns, fmt, str, match, pattern)
      if pos > max_short_offset then
        max_short_offset = pos
      end
    else
      register(processed_patterns, 'add long pattern %s for ext %s',
          str, match, pattern)
    end
  end

  -- Prepare a `string` match: derive an absolute position from
  -- `relative_position` (when no explicit position is given) and anchor the
  -- expression with '^' when it is relative to offset 0
  local function string_expression(match)
    if match.relative_position and not match.position then
      match.position = match.relative_position + #match.string
      if match.relative_position == 0 and match.string:sub(1, 1) ~= '^' then
        match.string = '^' .. match.string
      end
    end
    return match.string
  end

  -- Prepare a `hex` match: turn each pair of hex digits into a \x{..} escape,
  -- anchoring with '^' when relative_position is 0
  local function hex_expression(match)
    local parts = {}
    if match.relative_position == 0 then
      parts[1] = '^'
    end
    for i = 1, #match.hex, 2 do
      parts[#parts + 1] = string.format('\\x{%s}', match.hex:sub(i, i + 1))
    end
    if match.relative_position and not match.position then
      match.position = match.relative_position + #match.hex / 2
    end
    return table.concat(parts)
  end

  for ext, pattern in pairs(patterns) do
    assert(types[ext], 'not found type: ' .. ext)
    pattern.ext = ext
    for _, match in ipairs(pattern.matches) do
      if match.string then
        add_processed(string_expression(match), match, pattern)
      elseif match.hex then
        add_processed(hex_expression(match), match, pattern)
      end
    end
  end

  local bit = require "bit"
  local flags = rspamd_trie.flags
  local compile_flags = bit.bor(bit.bor(flags.re, flags.dot_all),
      bit.bor(flags.single_match, flags.no_start))

  -- Build a trie from the pattern strings (first element) of a registry
  local function compile(registry)
    local expressions = fun.totable(fun.map(function(entry)
      return entry[1]
    end, registry))
    return rspamd_trie.create(expressions, compile_flags)
  end

  compiled_patterns = compile(processed_patterns)
  compiled_short_patterns = compile(short_patterns)
  compiled_tail_patterns = compile(tail_patterns)

  lua_util.debugm(N, log_obj,
      'compiled %s (%s short; %s long; %s tail) patterns',
      #processed_patterns + #short_patterns + #tail_patterns,
      #short_patterns, #processed_patterns, #tail_patterns)
end
-
-- Build and compile all pattern tries when this chunk is executed, using
-- rspamd_config as the logging object
process_patterns(rspamd_config)
-
-- Comparison operators accepted as the first element of a table-valued match
-- position such as {'>', 0}; unknown or missing operators mean equality
local position_comparators = {
  ['>'] = function(a, b)
    return a > b
  end,
  ['>='] = function(a, b)
    return a >= b
  end,
  ['<'] = function(a, b)
    return a < b
  end,
  ['<='] = function(a, b)
    return a <= b
  end,
  ['!='] = function(a, b)
    return a ~= b
  end,
}

local function position_equals(a, b)
  return a == b
end

--- Run `trie` over `chunk` and accumulate weights of the matching patterns.
-- @param chunk text passed to trie:match()
-- @param input the whole input, passed to match heuristics
-- @param tlen total input length; negative (tail-relative) positions are
--   resolved against it
-- @param offset added to chunk-relative match positions
-- @param trie compiled rspamd_trie for this registry
-- @param processed_tbl registry of {str, match, pattern} indexed by pattern number
-- @param log_obj object used for debug logging
-- @param res table ext -> accumulated weight, updated in place
-- @param part passed to match heuristics
local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part)
  local matches = trie:match(chunk)

  -- Add `weight` (1 when absent) to the score of `ext`
  local function add_result(weight, ext)
    res[ext] = (res[ext] or 0) + (weight or 1)

    lua_util.debugm(N, log_obj, 'add pattern for %s, weight %s, total weight %s',
        ext, weight, res[ext])
  end

  -- Check an absolute position against an expected one: a number (exact
  -- match) or {op, value}. Negative values count from the end of the input.
  local function match_position(pos, expected)
    local cmp = position_equals
    if type(expected) == 'table' then
      cmp = position_comparators[expected[1]] or position_equals
      expected = expected[2]
    end

    -- Tail match
    if expected < 0 then
      expected = tlen + expected + 1
    end
    return cmp(pos, expected)
  end

  for npat, matched_positions in pairs(matches) do
    local pat_data = processed_tbl[npat]
    local match, pattern = pat_data[2], pat_data[3]

    if match.position then
      -- Single position: the first satisfying occurrence decides
      for _, pos in ipairs(matched_positions) do
        lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
            pattern.ext, pos, offset)
        if match_position(pos + offset, match.position) then
          if match.heuristic then
            local ext, weight = match.heuristic(input, log_obj, pos + offset, part)

            if ext then
              add_result(weight, ext)
              break
            end
          else
            add_result(match.weight, pattern.ext)
            break
          end
        end
      end
    elseif match.positions then
      -- Every expected position must be satisfied by some occurrence
      local all_right = true
      local matched_pos = 0
      for _, position in ipairs(match.positions) do
        local matched = false
        for _, pos in ipairs(matched_positions) do
          lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
              pattern.ext, pos, offset)
          -- Previously inverted (`if not match_position(...)`), which treated
          -- a mismatch as a match
          if match_position(pos + offset, position) then
            matched = true
            matched_pos = pos
            break
          end
        end
        if not matched then
          all_right = false
          break
        end
      end

      -- No `break` here: it used to abort the outer loop and skip all other
      -- matched patterns, whose visiting order via pairs() is unspecified
      if all_right then
        if match.heuristic then
          local ext, weight = match.heuristic(input, log_obj, matched_pos + offset, part)

          if ext then
            add_result(weight, ext)
          end
        else
          add_result(match.weight, pattern.ext)
        end
      end
    end
  end
end
-
--- Rank detected extensions by their accumulated weight.
-- @param res table ext -> weight
-- @return list of extensions sorted by descending weight and the top weight,
--   or nil when `res` yields no keys
local function process_detected(res)
  local ranked = lua_util.keys(res)

  if #ranked == 0 then
    return nil
  end

  table.sort(ranked, function(a, b)
    return res[a] > res[b]
  end)

  return ranked, res[ranked[1]]
end
-
--- Detect the file type of a part by matching its content against patterns.
-- Tail and short (head) patterns are tried first; if their best weight is
-- above 30, that result is returned. Otherwise long patterns are matched
-- against the whole input, or, when the input exceeds 3 * exports.chunk_size,
-- against the first 2 chunks and the last chunk only.
-- @param part object providing get_content()
-- @param log_obj optional logging object (defaults to rspamd_config)
-- @return extension and its entry in the types table, or nil if nothing matched
exports.detect = function(part, log_obj)
  if not log_obj then
    log_obj = rspamd_config
  end
  local input = part:get_content()

  local res = {}

  if type(input) == 'string' then
    -- Convert to rspamd_text
    input = rspamd_text.fromstring(input)
  end

  if type(input) ~= 'userdata' then
    -- Table (or any other non-text) input is NYI. The former
    -- `assert(0, ...)` never fired because 0 is truthy in Lua, so such input
    -- always ended with no result; keep that outcome, but explicitly.
    lua_util.debugm(N, log_obj, 'cannot detect type of non-text input (%s)',
        type(input))
    return nil
  end

  local inplen = #input
  local chunk_size = exports.chunk_size

  -- Check tail matches
  -- NOTE(review): head matching uses span(1, ...), so span looks 1-based; if
  -- so, this tail span (and chunk2 below) starts one byte early and misses
  -- the last byte. Confirm against the tail pattern offsets before changing.
  if inplen > min_tail_offset then
    local tail = input:span(inplen - min_tail_offset, min_tail_offset)
    match_chunk(tail, input, inplen, inplen - min_tail_offset,
        compiled_tail_patterns, tail_patterns, log_obj, res, part)
  end

  -- Try short match
  local head = input:span(1, math.min(max_short_offset, inplen))
  match_chunk(head, input, inplen, 0,
      compiled_short_patterns, short_patterns, log_obj, res, part)

  -- Check if we have enough data or go to long patterns
  local extensions, confidence = process_detected(res)

  if extensions and #extensions > 0 and confidence > 30 then
    -- We are done on short patterns
    return extensions[1], types[extensions[1]]
  end

  -- Not confident yet: check data in chunks, or the whole input if small enough
  if inplen > chunk_size * 3 then
    -- Chunked version as input is too long
    local chunk1 = input:span(1, chunk_size * 2)
    local chunk2 = input:span(inplen - chunk_size, chunk_size)

    match_chunk(chunk1, input, inplen, 0,
        compiled_patterns, processed_patterns, log_obj, res, part)
    match_chunk(chunk2, input, inplen, inplen - chunk_size,
        compiled_patterns, processed_patterns, log_obj, res, part)
  else
    -- Input is short enough to match it at all
    match_chunk(input, input, inplen, 0,
        compiled_patterns, processed_patterns, log_obj, res, part)
  end

  extensions = process_detected(res)

  if extensions and #extensions > 0 then
    return extensions[1], types[extensions[1]]
  end

  -- Nothing found
  return nil
end
-
--- Detect the type of a MIME part by trying, in order: the MIME part
-- heuristic (accepted when its weight exceeds 20), content detection via
-- exports.detect, and the text part heuristic (weight above 20).
-- @param part MIME part to inspect
-- @param log_obj optional logging object
-- @return extension and its types entry; nothing when no stage succeeds
exports.detect_mime_part = function(part, log_obj)
  -- A heuristic result is accepted only with a weight above 20
  local function is_confident(ext, weight)
    return ext and weight and weight > 20
  end

  local mime_ext, mime_weight = heuristics.mime_part_heuristic(part, log_obj)
  if is_confident(mime_ext, mime_weight) then
    return mime_ext, types[mime_ext]
  end

  local detected = exports.detect(part, log_obj)
  if detected then
    return detected, types[detected]
  end

  -- Text/html and other parts
  local text_ext, text_weight = heuristics.text_part_heuristic(part, log_obj)
  if is_confident(text_ext, text_weight) then
    return text_ext, types[text_ext]
  end
end
-
-- Number of bytes per chunk used by exports.detect() when matching long
-- patterns: inputs longer than 3 * chunk_size are scanned only in their first
-- 2 chunks and their last chunk; shorter inputs are scanned in full.
exports.chunk_size = 32768

-- Expose the file types table to module users
exports.types = types

return exports
|