You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

init.lua 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. --[[
  2. Copyright (c) 2019, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_magic
  15. -- This module contains file types detection logic
  16. --]]
  17. local patterns = require "lua_magic/patterns"
  18. local types = require "lua_magic/types"
  19. local heuristics = require "lua_magic/heuristics"
  20. local fun = require "fun"
  21. local lua_util = require "lua_util"
  22. local rspamd_text = require "rspamd_text"
  23. local rspamd_trie = require "rspamd_trie"
  24. local N = "lua_magic"
  25. local exports = {}
  26. -- trie objects
  27. local compiled_patterns
  28. local compiled_short_patterns
  29. local compiled_tail_patterns
  30. -- {<str>, <match_object>, <pattern_object>} indexed by pattern number
  31. local processed_patterns = {}
  32. local short_patterns = {}
  33. local tail_patterns = {}
  34. local short_match_limit = 128
  35. local max_short_offset = -1
  36. local min_tail_offset = math.huge
  37. local function process_patterns(log_obj)
  38. -- Add pattern to either short patterns or to normal patterns
  39. local function add_processed(str, match, pattern)
  40. if match.position and type(match.position) == 'number' then
  41. if match.tail then
  42. -- Tail pattern
  43. tail_patterns[#tail_patterns + 1] = {
  44. str, match, pattern
  45. }
  46. if min_tail_offset > match.tail then
  47. min_tail_offset = match.tail
  48. end
  49. lua_util.debugm(N, log_obj, 'add tail pattern %s for ext %s',
  50. str, pattern.ext)
  51. elseif match.position < short_match_limit then
  52. short_patterns[#short_patterns + 1] = {
  53. str, match, pattern
  54. }
  55. if str:sub(1, 1) == '^' then
  56. lua_util.debugm(N, log_obj, 'add head pattern %s for ext %s',
  57. str, pattern.ext)
  58. else
  59. lua_util.debugm(N, log_obj, 'add short pattern %s for ext %s',
  60. str, pattern.ext)
  61. end
  62. if max_short_offset < match.position then
  63. max_short_offset = match.position
  64. end
  65. else
  66. processed_patterns[#processed_patterns + 1] = {
  67. str, match, pattern
  68. }
  69. lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
  70. str, pattern.ext)
  71. end
  72. else
  73. processed_patterns[#processed_patterns + 1] = {
  74. str, match, pattern
  75. }
  76. lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
  77. str, pattern.ext)
  78. end
  79. end
  80. if not compiled_patterns then
  81. for ext,pattern in pairs(patterns) do
  82. assert(types[ext], 'not found type: ' .. ext)
  83. pattern.ext = ext
  84. for _,match in ipairs(pattern.matches) do
  85. if match.string then
  86. if match.relative_position and not match.position then
  87. match.position = match.relative_position + #match.string
  88. if match.relative_position == 0 then
  89. if match.string:sub(1, 1) ~= '^' then
  90. match.string = '^' .. match.string
  91. end
  92. end
  93. end
  94. add_processed(match.string, match, pattern)
  95. elseif match.hex then
  96. local hex_table = {}
  97. for i=1,#match.hex,2 do
  98. local subc = match.hex:sub(i, i + 1)
  99. hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
  100. end
  101. if match.relative_position and not match.position then
  102. match.position = match.relative_position + #match.hex / 2
  103. end
  104. if match.relative_position == 0 then
  105. table.insert(hex_table, 1, '^')
  106. end
  107. add_processed(table.concat(hex_table), match, pattern)
  108. end
  109. end
  110. end
  111. local bit = require "bit"
  112. local compile_flags = bit.bor(rspamd_trie.flags.re, rspamd_trie.flags.dot_all)
  113. compile_flags = bit.bor(compile_flags, rspamd_trie.flags.single_match)
  114. compile_flags = bit.bor(compile_flags, rspamd_trie.flags.no_start)
  115. compiled_patterns = rspamd_trie.create(fun.totable(
  116. fun.map(function(t) return t[1] end, processed_patterns)),
  117. compile_flags
  118. )
  119. compiled_short_patterns = rspamd_trie.create(fun.totable(
  120. fun.map(function(t) return t[1] end, short_patterns)),
  121. compile_flags
  122. )
  123. compiled_tail_patterns = rspamd_trie.create(fun.totable(
  124. fun.map(function(t) return t[1] end, tail_patterns)),
  125. compile_flags
  126. )
  127. lua_util.debugm(N, log_obj,
  128. 'compiled %s (%s short; %s long; %s tail) patterns',
  129. #processed_patterns + #short_patterns + #tail_patterns,
  130. #short_patterns, #processed_patterns, #tail_patterns)
  131. end
  132. end
  133. process_patterns(rspamd_config)
  134. local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res)
  135. local matches = trie:match(chunk)
  136. local last = tlen
  137. local function add_result(weight, ext)
  138. if not res[ext] then
  139. res[ext] = 0
  140. end
  141. if weight then
  142. res[ext] = res[ext] + weight
  143. else
  144. res[ext] = res[ext] + 1
  145. end
  146. lua_util.debugm(N, log_obj,'add pattern for %s, weight %s, total weight %s',
  147. ext, weight, res[ext])
  148. end
  149. local function match_position(pos, expected)
  150. local cmp = function(a, b) return a == b end
  151. if type(expected) == 'table' then
  152. -- Something like {'>', 0}
  153. if expected[1] == '>' then
  154. cmp = function(a, b) return a > b end
  155. elseif expected[1] == '>=' then
  156. cmp = function(a, b) return a >= b end
  157. elseif expected[1] == '<' then
  158. cmp = function(a, b) return a < b end
  159. elseif expected[1] == '<=' then
  160. cmp = function(a, b) return a <= b end
  161. elseif expected[1] == '!=' then
  162. cmp = function(a, b) return a ~= b end
  163. end
  164. expected = expected[2]
  165. end
  166. -- Tail match
  167. if expected < 0 then
  168. expected = last + expected + 1
  169. end
  170. return cmp(pos, expected)
  171. end
  172. for npat,matched_positions in pairs(matches) do
  173. local pat_data = processed_tbl[npat]
  174. local pattern = pat_data[3]
  175. local match = pat_data[2]
  176. -- Single position
  177. if match.position then
  178. local position = match.position
  179. for _,pos in ipairs(matched_positions) do
  180. lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
  181. pattern.ext, pos, offset)
  182. if match_position(pos + offset, position) then
  183. if match.heuristic then
  184. local ext,weight = match.heuristic(input, log_obj)
  185. if ext then
  186. add_result(weight, ext)
  187. break
  188. end
  189. else
  190. add_result(match.weight, pattern.ext)
  191. break
  192. end
  193. end
  194. end
  195. elseif match.positions then
  196. -- Match all positions
  197. local all_right = true
  198. for _,position in ipairs(match.positions) do
  199. local matched = false
  200. for _,pos in ipairs(matched_positions) do
  201. lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
  202. pattern.ext, pos, offset)
  203. if not match_position(pos + offset, position) then
  204. matched = true
  205. break
  206. end
  207. end
  208. if not matched then
  209. all_right = false
  210. break
  211. end
  212. end
  213. if all_right then
  214. if match.heuristic then
  215. local ext,weight = match.heuristic(input, log_obj)
  216. if ext then
  217. add_result(weight, ext)
  218. break
  219. end
  220. else
  221. add_result(match.weight, pattern.ext)
  222. break
  223. end
  224. end
  225. end
  226. end
  227. end
  228. local function process_detected(res)
  229. local extensions = lua_util.keys(res)
  230. if #extensions > 0 then
  231. table.sort(extensions, function(ex1, ex2)
  232. return res[ex1] > res[ex2]
  233. end)
  234. return extensions,res[extensions[1]]
  235. end
  236. return nil
  237. end
  238. exports.detect = function(input, log_obj)
  239. if not log_obj then log_obj = rspamd_config end
  240. local res = {}
  241. if type(input) == 'string' then
  242. -- Convert to rspamd_text
  243. input = rspamd_text.fromstring(input)
  244. end
  245. if type(input) == 'userdata' then
  246. local inplen = #input
  247. -- Check tail matches
  248. if inplen > min_tail_offset then
  249. local tail = input:span(inplen - min_tail_offset, min_tail_offset)
  250. match_chunk(tail, input, inplen, inplen - min_tail_offset,
  251. compiled_tail_patterns, tail_patterns, log_obj, res)
  252. end
  253. -- Try short match
  254. local head = input:span(1, math.min(max_short_offset, inplen))
  255. match_chunk(head, input, inplen, 0,
  256. compiled_short_patterns, short_patterns, log_obj, res)
  257. -- Check if we have enough data or go to long patterns
  258. local extensions,confidence = process_detected(res)
  259. if extensions and #extensions > 0 and confidence > 30 then
  260. -- We are done on short patterns
  261. return extensions[1],types[extensions[1]]
  262. end
  263. -- No way, let's check data in chunks or just the whole input if it is small enough
  264. if #input > exports.chunk_size * 3 then
  265. -- Chunked version as input is too long
  266. local chunk1, chunk2 =
  267. input:span(1, exports.chunk_size * 2),
  268. input:span(inplen - exports.chunk_size, exports.chunk_size)
  269. local offset1, offset2 = 0, inplen - exports.chunk_size
  270. match_chunk(chunk1, input, inplen,
  271. offset1, compiled_patterns, processed_patterns, log_obj, res)
  272. match_chunk(chunk2, input, inplen,
  273. offset2, compiled_patterns, processed_patterns, log_obj, res)
  274. else
  275. -- Input is short enough to match it at all
  276. match_chunk(input, input, inplen, 0,
  277. compiled_patterns, processed_patterns, log_obj, res)
  278. end
  279. else
  280. -- Table input is NYI
  281. assert(0, 'table input for match')
  282. end
  283. local extensions = process_detected(res)
  284. if extensions and #extensions > 0 then
  285. return extensions[1],types[extensions[1]]
  286. end
  287. -- Nothing found
  288. return nil
  289. end
  290. exports.detect_mime_part = function(part, log_obj)
  291. local ext,weight = heuristics.mime_part_heuristic(part, log_obj)
  292. if ext and weight and weight > 20 then
  293. return ext,types[ext]
  294. end
  295. ext = exports.detect(part:get_content(), log_obj)
  296. if ext then
  297. return ext,types[ext]
  298. end
  299. -- Text/html and other parts
  300. ext,weight = heuristics.text_part_heuristic(part, log_obj)
  301. if ext and weight and weight > 20 then
  302. return ext,types[ext]
  303. end
  304. end
  305. -- This parameter specifies how many bytes are checked in the input
  306. -- Rspamd checks 2 chunks at start and 1 chunk at the end
  307. exports.chunk_size = 32768
  308. exports.types = types
  309. return exports