You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_magic
  15. -- This module contains file types detection logic
  16. --]]
  17. local patterns = require "lua_magic/patterns"
  18. local types = require "lua_magic/types"
  19. local heuristics = require "lua_magic/heuristics"
  20. local fun = require "fun"
  21. local lua_util = require "lua_util"
  22. local rspamd_text = require "rspamd_text"
  23. local rspamd_trie = require "rspamd_trie"
  -- Tag passed to lua_util.debugm for every debug message of this module
  local N = "lua_magic"
  -- Module table; the public API is attached to it and it is returned at the
  -- end of the file
  local exports = {}

  -- Trie objects, one per pattern bucket (assigned by process_patterns)
  local compiled_patterns, compiled_short_patterns, compiled_tail_patterns

  -- Pattern buckets: each entry is {<str>, <match_object>, <pattern_object>}
  -- and is indexed by pattern number
  local processed_patterns, short_patterns, tail_patterns = {}, {}, {}

  -- A match with a numeric position below this limit is filed as a short pattern
  local short_match_limit = 128
  -- Largest position among the registered short patterns (-1 while there are none)
  local max_short_offset = -1
  -- Smallest `tail` value among the registered tail patterns
  -- (math.huge while there are none)
  local min_tail_offset = math.huge
  --- Sorts every configured pattern into the tail, short or long bucket and
  -- compiles one trie per bucket.
  -- The work is done once: if `compiled_patterns` is already set the function
  -- returns without touching anything.
  -- @param log_obj object handed to lua_util.debugm for debug output
  local function process_patterns(log_obj)
    if compiled_patterns then
      return
    end

    -- Files one expression into the bucket selected by its match rules
    local function add_processed(str, match, pattern)
      local entry = { str, match, pattern }
      local has_numeric_pos = match.position and type(match.position) == 'number'

      if has_numeric_pos and match.tail then
        -- Tail pattern: also track the smallest `tail` value seen so far
        table.insert(tail_patterns, entry)
        if match.tail < min_tail_offset then
          min_tail_offset = match.tail
        end
        lua_util.debugm(N, log_obj, 'add tail pattern %s for ext %s',
            str, pattern.ext)
        return
      end

      if has_numeric_pos and match.position < short_match_limit then
        -- Short pattern; an expression anchored with '^' is logged as a head pattern
        table.insert(short_patterns, entry)
        local message = 'add short pattern %s for ext %s'
        if str:sub(1, 1) == '^' then
          message = 'add head pattern %s for ext %s'
        end
        lua_util.debugm(N, log_obj, message, str, pattern.ext)
        if match.position > max_short_offset then
          max_short_offset = match.position
        end
        return
      end

      -- No numeric position, or a position at/after the short limit: long pattern
      table.insert(processed_patterns, entry)
      lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
          str, pattern.ext)
    end

    for ext, pattern in pairs(patterns) do
      assert(types[ext], 'not found type: ' .. ext)
      pattern.ext = ext

      for _, match in ipairs(pattern.matches) do
        -- A relative position is converted to an absolute one only when no
        -- explicit position was given
        local derive_position = match.relative_position and not match.position

        if match.string then
          if derive_position then
            match.position = match.relative_position + #match.string
            -- A pattern relative to offset 0 is anchored to the start
            if match.relative_position == 0 and match.string:sub(1, 1) ~= '^' then
              match.string = '^' .. match.string
            end
          end
          add_processed(match.string, match, pattern)
        elseif match.hex then
          -- Turn each pair of hex digits into a \x{..} escape
          local escaped = {}
          for i = 1, #match.hex, 2 do
            escaped[#escaped + 1] = string.format('\\x{%s}', match.hex:sub(i, i + 1))
          end
          if derive_position then
            match.position = match.relative_position + #match.hex / 2
          end
          -- Unlike the string case, the anchor is added whenever the relative
          -- position is 0, even if an explicit position is present
          local anchor = match.relative_position == 0 and '^' or ''
          add_processed(anchor .. table.concat(escaped), match, pattern)
        end
      end
    end

    local bit = require "bit"
    local trie_flags = rspamd_trie.flags
    local compile_flags = bit.bor(trie_flags.re, trie_flags.dot_all,
        trie_flags.single_match, trie_flags.no_start)

    -- Picks the expression string out of a bucket entry
    local function first_field(t)
      return t[1]
    end

    -- Compiles the expressions of one bucket into a trie
    local function compile_bucket(bucket)
      return rspamd_trie.create(fun.totable(fun.map(first_field, bucket)),
          compile_flags)
    end

    compiled_patterns = compile_bucket(processed_patterns)
    compiled_short_patterns = compile_bucket(short_patterns)
    compiled_tail_patterns = compile_bucket(tail_patterns)

    lua_util.debugm(N, log_obj,
        'compiled %s (%s short; %s long; %s tail) patterns',
        #processed_patterns + #short_patterns + #tail_patterns,
        #short_patterns, #processed_patterns, #tail_patterns)
  end

  -- Build the tries when the module is loaded, logging through the global rspamd_config
  process_patterns(rspamd_config)
  --- Runs one trie over one chunk of the input and accumulates a weight per
  -- extension for every pattern whose position rules are satisfied.
  -- @param chunk text the trie is matched against
  -- @param input the complete input; only forwarded to heuristics
  -- @param tlen total input length, used to resolve negative (tail) positions
  -- @param offset value added to each reported match position before it is compared
  -- @param trie compiled trie that corresponds to `processed_tbl`
  -- @param processed_tbl bucket of {str, match, pattern} entries indexed by pattern number
  -- @param log_obj object handed to lua_util.debugm
  -- @param res table mapping extension -> accumulated weight; updated in place
  -- @param part forwarded to heuristics
  local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part)
    local matches = trie:match(chunk)

    local function equals(a, b)
      return a == b
    end

    -- Comparison operators accepted in position specs such as {'>', 0}
    local comparators = {
      ['>'] = function(a, b) return a > b end,
      ['>='] = function(a, b) return a >= b end,
      ['<'] = function(a, b) return a < b end,
      ['<='] = function(a, b) return a <= b end,
      ['!='] = function(a, b) return a ~= b end,
    }

    -- Adds `weight` (1 when no weight is given) to the score of `ext`
    local function add_result(weight, ext)
      res[ext] = (res[ext] or 0) + (weight or 1)
      lua_util.debugm(N, log_obj, 'add pattern for %s, weight %s, total weight %s',
          ext, weight, res[ext])
    end

    -- Checks a match position against an expected one. `expected` is either a
    -- number (plain equality) or a table {<operator>, <number>}; an unknown
    -- operator falls back to equality.
    local function match_position(pos, expected)
      local cmp = equals
      if type(expected) == 'table' then
        cmp = comparators[expected[1]] or equals
        expected = expected[2]
      end

      -- Tail match: a negative position is counted from the end of the input
      if expected < 0 then
        expected = tlen + expected + 1
      end

      return cmp(pos, expected)
    end

    -- Scores an accepted match, through the heuristic when the match has one.
    -- Returns false only when the heuristic declines to name an extension.
    local function score(match, pattern, pos)
      if match.heuristic then
        local ext, weight = match.heuristic(input, log_obj, pos + offset, part)
        if not ext then
          return false
        end
        add_result(weight, ext)
        return true
      end

      add_result(match.weight, pattern.ext)
      return true
    end

    for npat, matched_positions in pairs(matches) do
      local pat_data = processed_tbl[npat]
      local match, pattern = pat_data[2], pat_data[3]

      if match.position then
        -- Single position: the first hit that satisfies it (and is accepted
        -- by the heuristic, if any) scores the pattern once
        for _, pos in ipairs(matched_positions) do
          lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
              pattern.ext, pos, offset)
          if match_position(pos + offset, match.position) and score(match, pattern, pos) then
            break
          end
        end
      elseif match.positions then
        -- Match all positions: every expected position needs at least one hit
        local all_right = true
        local matched_pos = 0

        for _, position in ipairs(match.positions) do
          local matched = false
          for _, pos in ipairs(matched_positions) do
            lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
                pattern.ext, pos, offset)
            -- A hit counts when it SATISFIES the expected position
            if match_position(pos + offset, position) then
              matched = true
              matched_pos = pos
              break
            end
          end

          if not matched then
            all_right = false
            break
          end
        end

        if all_right then
          -- No `break` here: leaving the loop at this point would skip the
          -- remaining matched patterns, and which ones get skipped would
          -- depend on the unspecified iteration order of pairs()
          score(match, pattern, matched_pos)
        end
      end
    end
  end
  --- Ranks the detected extensions by their accumulated weight.
  -- @param res table mapping extension -> weight
  -- @return the keys of `res` (as produced by lua_util.keys) sorted by
  -- descending weight, followed by the weight of the first one; nil when
  -- there are no keys
  local function process_detected(res)
    local ranked = lua_util.keys(res)
    if #ranked == 0 then
      return nil
    end

    -- Strict ordering: the heavier extension goes first
    local function heavier_first(lhs, rhs)
      return res[rhs] < res[lhs]
    end
    table.sort(ranked, heavier_first)

    local best = ranked[1]
    return ranked, res[best]
  end
  --- Detects the file type of a part from its content.
  -- Tail and short (head) patterns are tried first. If the best candidate then
  -- weighs more than 30 it is returned immediately; otherwise the long patterns
  -- are run as well, over the whole input or, for inputs longer than three
  -- chunks, over the first two chunks and the last chunk only.
  -- @param part object providing `get_content()`
  -- @param log_obj optional logging object; the global rspamd_config when omitted
  -- @return the best extension and `types[extension]`, or nil when nothing
  -- matched, the input is empty, or the content type is not supported
  exports.detect = function(part, log_obj)
    log_obj = log_obj or rspamd_config

    local input = part:get_content()
    if type(input) == 'string' then
      -- Convert to rspamd_text
      input = rspamd_text.fromstring(input)
    end

    if type(input) ~= 'userdata' then
      -- Table input is NYI. A guard written as `assert(0, ...)` can never fire
      -- because 0 is truthy in Lua, so callers have always received nil here;
      -- keep that non-fatal contract, but make it explicit and visible in logs.
      lua_util.debugm(N, log_obj, 'unsupported input type %s for match',
          type(input))
      return nil
    end

    local inplen = #input
    if inplen == 0 then
      -- Nothing can match in an empty buffer, and the head span below would
      -- start past its end (presumably rejected by rspamd_text - TODO confirm)
      return nil
    end

    local res = {}

    -- Check tail matches
    if inplen > min_tail_offset then
      local tail_start = inplen - min_tail_offset
      local tail = input:span(tail_start, min_tail_offset)
      match_chunk(tail, input, inplen, tail_start,
          compiled_tail_patterns, tail_patterns, log_obj, res, part)
    end

    -- Try short match
    local head = input:span(1, math.min(max_short_offset, inplen))
    match_chunk(head, input, inplen, 0,
        compiled_short_patterns, short_patterns, log_obj, res, part)

    -- Check if we have enough data or go to long patterns
    local extensions, confidence = process_detected(res)
    if extensions and #extensions > 0 and confidence > 30 then
      -- We are done on short patterns
      return extensions[1], types[extensions[1]]
    end

    -- No way, let's check data in chunks or just the whole input if it is small enough
    local chunk_size = exports.chunk_size
    if inplen > chunk_size * 3 then
      -- Chunked version as input is too long
      local last_start = inplen - chunk_size
      local first_chunks = input:span(1, chunk_size * 2)
      local last_chunk = input:span(last_start, chunk_size)

      match_chunk(first_chunks, input, inplen,
          0, compiled_patterns, processed_patterns, log_obj, res, part)
      match_chunk(last_chunk, input, inplen,
          last_start, compiled_patterns, processed_patterns, log_obj, res, part)
    else
      -- Input is short enough to match it as a whole
      match_chunk(input, input, inplen, 0,
          compiled_patterns, processed_patterns, log_obj, res, part)
    end

    extensions = process_detected(res)
    if extensions and #extensions > 0 then
      return extensions[1], types[extensions[1]]
    end

    -- Nothing found
    return nil
  end
  --- Detects the type of a MIME part by trying three detectors in order: the
  -- MIME part heuristic, content-based detection (`exports.detect`) and the
  -- text part heuristic.
  -- @param part part that is passed to every detector
  -- @param log_obj logging object that is passed to every detector
  -- @return extension and `types[extension]` from the first detector that
  -- succeeds; nothing when none of them does
  exports.detect_mime_part = function(part, log_obj)
    -- A heuristic verdict is accepted only when its weight is above 20
    local function accepted(candidate, candidate_weight)
      return candidate and candidate_weight and candidate_weight > 20
    end

    local guess, guess_weight = heuristics.mime_part_heuristic(part, log_obj)
    if accepted(guess, guess_weight) then
      return guess, types[guess]
    end

    -- Only the extension returned by detect() is used; the type is looked up again
    local detected = exports.detect(part, log_obj)
    if detected then
      return detected, types[detected]
    end

    -- Text/html and other parts
    guess, guess_weight = heuristics.text_part_heuristic(part, log_obj)
    if accepted(guess, guess_weight) then
      return guess, types[guess]
    end
  end
  -- Expose the type table next to the detection functions
  exports.types = types

  -- Number of bytes in one scanned chunk of the input.
  -- Rspamd checks 2 chunks at start and 1 chunk at the end
  exports.chunk_size = 32768

  return exports