You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_trie.c 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. /*
  2. * Copyright 2024 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "message.h"
  18. #include "libutil/multipattern.h"
  19. /***
  20. * @module rspamd_trie
  21. * The rspamd trie module provides a data structure suitable for searching for many
  22. * patterns in arbitrary texts (or binary chunks). The algorithmic complexity of
  23. * this algorithm is at most O(n + m + z), where `n` is the length of the text, `m` is the total length of the patterns and `z` is the number of pattern occurrences in the text.
  24. *
  25. * Here is a typical example of trie usage:
  26. * @example
  27. local rspamd_trie = require "rspamd_trie"
  28. local patterns = {'aab', 'ab', 'bcd\0ef'}
  29. local trie = rspamd_trie.create(patterns)
  30. local function trie_callback(number, pos)
  31. print('Matched pattern number ' .. tostring(number) .. ' at pos: ' .. tostring(pos))
  32. end
  33. trie:match('some big text', trie_callback)
  34. */
  35. /* Suffix trie */
  36. LUA_FUNCTION_DEF(trie, create);
  37. LUA_FUNCTION_DEF(trie, has_hyperscan);
  38. LUA_FUNCTION_DEF(trie, match);
  39. LUA_FUNCTION_DEF(trie, search_mime);
  40. LUA_FUNCTION_DEF(trie, search_rawmsg);
  41. LUA_FUNCTION_DEF(trie, search_rawbody);
  42. LUA_FUNCTION_DEF(trie, destroy);
  43. static const struct luaL_reg trielib_m[] = {
  44. LUA_INTERFACE_DEF(trie, match),
  45. LUA_INTERFACE_DEF(trie, search_mime),
  46. LUA_INTERFACE_DEF(trie, search_rawmsg),
  47. LUA_INTERFACE_DEF(trie, search_rawbody),
  48. {"__tostring", rspamd_lua_class_tostring},
  49. {"__gc", lua_trie_destroy},
  50. {NULL, NULL}};
  51. static const struct luaL_reg trielib_f[] = {
  52. LUA_INTERFACE_DEF(trie, create),
  53. LUA_INTERFACE_DEF(trie, has_hyperscan),
  54. {NULL, NULL}};
  55. static struct rspamd_multipattern *
  56. lua_check_trie(lua_State *L, int idx)
  57. {
  58. void *ud = rspamd_lua_check_udata(L, 1, rspamd_trie_classname);
  59. luaL_argcheck(L, ud != NULL, 1, "'trie' expected");
  60. return ud ? *((struct rspamd_multipattern **) ud) : NULL;
  61. }
  62. static int
  63. lua_trie_destroy(lua_State *L)
  64. {
  65. struct rspamd_multipattern *trie = lua_check_trie(L, 1);
  66. if (trie) {
  67. rspamd_multipattern_destroy(trie);
  68. }
  69. return 0;
  70. }
  71. /***
  72. * function trie.has_hyperscan()
  73. * Checks for hyperscan support
  74. *
  75. * @return {bool} true if hyperscan is supported
  76. */
  77. static int
  78. lua_trie_has_hyperscan(lua_State *L)
  79. {
  80. lua_pushboolean(L, rspamd_multipattern_has_hyperscan());
  81. return 1;
  82. }
  83. /***
  84. * function trie.create(patterns, [flags])
  85. * Creates new trie data structure
  86. * @param {table} array of string patterns
  87. * @return {trie} new trie object
  88. */
  89. static int
  90. lua_trie_create(lua_State *L)
  91. {
  92. struct rspamd_multipattern *trie, **ptrie;
  93. int npat = 0, flags = RSPAMD_MULTIPATTERN_ICASE | RSPAMD_MULTIPATTERN_GLOB;
  94. GError *err = NULL;
  95. if (lua_isnumber(L, 2)) {
  96. flags = lua_tointeger(L, 2);
  97. }
  98. if (!lua_istable(L, 1)) {
  99. return luaL_error(L, "lua trie expects array of patterns for now");
  100. }
  101. else {
  102. lua_pushvalue(L, 1);
  103. lua_pushnil(L);
  104. while (lua_next(L, -2) != 0) {
  105. if (lua_isstring(L, -1)) {
  106. npat++;
  107. }
  108. lua_pop(L, 1);
  109. }
  110. trie = rspamd_multipattern_create_sized(npat, flags);
  111. lua_pushnil(L);
  112. while (lua_next(L, -2) != 0) {
  113. if (lua_isstring(L, -1)) {
  114. const char *pat;
  115. gsize patlen;
  116. pat = lua_tolstring(L, -1, &patlen);
  117. rspamd_multipattern_add_pattern_len(trie, pat, patlen, flags);
  118. }
  119. lua_pop(L, 1);
  120. }
  121. lua_pop(L, 1); /* table */
  122. if (!rspamd_multipattern_compile(trie, 0, &err)) {
  123. msg_err("cannot compile multipattern: %e", err);
  124. g_error_free(err);
  125. rspamd_multipattern_destroy(trie);
  126. lua_pushnil(L);
  127. }
  128. else {
  129. ptrie = lua_newuserdata(L, sizeof(void *));
  130. rspamd_lua_setclass(L, rspamd_trie_classname, -1);
  131. *ptrie = trie;
  132. }
  133. }
  134. return 1;
  135. }
  /*
   * Pushes a single match description onto the Lua stack:
   *  - the integer `end` when `report_start` is false;
   *  - a two-element array `{start, end}` when `report_start` is true.
   */
  #define PUSH_TRIE_MATCH(L, start, end, report_start) \
      do {                                             \
          if (!(report_start)) {                       \
              lua_pushinteger(L, (end));               \
          }                                            \
          else {                                       \
              lua_createtable(L, 2, 0);                \
              lua_pushinteger(L, (start));             \
              lua_rawseti(L, -2, 1);                   \
              lua_pushinteger(L, (end));               \
              lua_rawseti(L, -2, 2);                   \
          }                                            \
      } while (0)
  149. /* Normal callback type */
  150. static int
  151. lua_trie_lua_cb_callback(struct rspamd_multipattern *mp,
  152. unsigned int strnum,
  153. int match_start,
  154. int textpos,
  155. const char *text,
  156. gsize len,
  157. void *context)
  158. {
  159. lua_State *L = context;
  160. int ret;
  161. gboolean report_start = lua_toboolean(L, -1);
  162. /* Function */
  163. lua_pushvalue(L, 3);
  164. lua_pushinteger(L, strnum + 1);
  165. PUSH_TRIE_MATCH(L, match_start, textpos, report_start);
  166. if (lua_pcall(L, 2, 1, 0) != 0) {
  167. msg_info("call to trie callback has failed: %s",
  168. lua_tostring(L, -1));
  169. lua_pop(L, 1);
  170. return 1;
  171. }
  172. ret = lua_tonumber(L, -1);
  173. lua_pop(L, 1);
  174. return ret;
  175. }
  176. /* Table like callback, expect result table on top of the stack */
  177. static int
  178. lua_trie_table_callback(struct rspamd_multipattern *mp,
  179. unsigned int strnum,
  180. int match_start,
  181. int textpos,
  182. const char *text,
  183. gsize len,
  184. void *context)
  185. {
  186. lua_State *L = context;
  187. int report_start = lua_toboolean(L, -2);
  188. /* Set table, indexed by pattern number */
  189. lua_rawgeti(L, -1, strnum + 1);
  190. if (lua_istable(L, -1)) {
  191. /* Already have table, add offset */
  192. gsize last = rspamd_lua_table_size(L, -1);
  193. PUSH_TRIE_MATCH(L, match_start, textpos, report_start);
  194. lua_rawseti(L, -2, last + 1);
  195. /* Remove table from the stack */
  196. lua_pop(L, 1);
  197. }
  198. else {
  199. /* Pop none */
  200. lua_pop(L, 1);
  201. /* New table */
  202. lua_newtable(L);
  203. PUSH_TRIE_MATCH(L, match_start, textpos, report_start);
  204. lua_rawseti(L, -2, 1);
  205. lua_rawseti(L, -2, strnum + 1);
  206. }
  207. return 0;
  208. }
  209. /*
  210. * We assume that callback argument is at pos 3 and icase is in position 4
  211. */
  212. static int
  213. lua_trie_search_str(lua_State *L, struct rspamd_multipattern *trie,
  214. const char *str, gsize len, rspamd_multipattern_cb_t cb)
  215. {
  216. int ret;
  217. unsigned int nfound = 0;
  218. if ((ret = rspamd_multipattern_lookup(trie, str, len,
  219. cb, L, &nfound)) == 0) {
  220. return nfound;
  221. }
  222. return ret;
  223. }
  224. /***
  225. * @method trie:match(input, [cb][, report_start])
  226. * Search for patterns in `input` invoking `cb` optionally ignoring case
  227. * @param {table or string} input one or several (if `input` is an array) strings of input text
  228. * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends
  229. * @param {boolean} report_start report both start and end offset when matching patterns
  230. * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number
  231. */
  232. static int
  233. lua_trie_match(lua_State *L)
  234. {
  235. LUA_TRACE_POINT;
  236. struct rspamd_multipattern *trie = lua_check_trie(L, 1);
  237. const char *text;
  238. gsize len;
  239. gboolean found = FALSE, report_start = FALSE;
  240. struct rspamd_lua_text *t;
  241. rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
  242. int old_top = lua_gettop(L);
  243. if (trie) {
  244. if (lua_type(L, 3) != LUA_TFUNCTION) {
  245. if (lua_isboolean(L, 3)) {
  246. report_start = lua_toboolean(L, 3);
  247. }
  248. lua_pushboolean(L, report_start);
  249. /* Table like match */
  250. lua_newtable(L);
  251. cb = lua_trie_table_callback;
  252. }
  253. else {
  254. if (lua_isboolean(L, 4)) {
  255. report_start = lua_toboolean(L, 4);
  256. }
  257. lua_pushboolean(L, report_start);
  258. }
  259. if (lua_type(L, 2) == LUA_TTABLE) {
  260. lua_pushvalue(L, 2);
  261. lua_pushnil(L);
  262. while (lua_next(L, -2) != 0) {
  263. if (lua_isstring(L, -1)) {
  264. text = lua_tolstring(L, -1, &len);
  265. if (lua_trie_search_str(L, trie, text, len, cb)) {
  266. found = TRUE;
  267. }
  268. }
  269. else if (lua_isuserdata(L, -1)) {
  270. t = lua_check_text(L, -1);
  271. if (t) {
  272. if (lua_trie_search_str(L, trie, t->start, t->len, cb)) {
  273. found = TRUE;
  274. }
  275. }
  276. }
  277. lua_pop(L, 1);
  278. }
  279. }
  280. else if (lua_type(L, 2) == LUA_TSTRING) {
  281. text = lua_tolstring(L, 2, &len);
  282. if (lua_trie_search_str(L, trie, text, len, cb)) {
  283. found = TRUE;
  284. }
  285. }
  286. else if (lua_type(L, 2) == LUA_TUSERDATA) {
  287. t = lua_check_text(L, 2);
  288. if (t && lua_trie_search_str(L, trie, t->start, t->len, cb)) {
  289. found = TRUE;
  290. }
  291. }
  292. }
  293. if (lua_type(L, 3) == LUA_TFUNCTION) {
  294. lua_settop(L, old_top);
  295. lua_pushboolean(L, found);
  296. }
  297. else {
  298. lua_remove(L, -2);
  299. }
  300. return 1;
  301. }
  302. /***
  303. * @method trie:search_mime(task, cb)
  304. * This is a helper mehthod to search pattern within text parts of a message in rspamd task
  305. * @param {task} task object
  306. * @param {function} cb callback called on each pattern match @see trie:match
  307. * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
  308. * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
  309. */
  310. static int
  311. lua_trie_search_mime(lua_State *L)
  312. {
  313. LUA_TRACE_POINT;
  314. struct rspamd_multipattern *trie = lua_check_trie(L, 1);
  315. struct rspamd_task *task = lua_check_task(L, 2);
  316. struct rspamd_mime_text_part *part;
  317. const char *text;
  318. gsize len, i;
  319. gboolean found = FALSE;
  320. rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
  321. if (trie && task) {
  322. PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
  323. {
  324. if (!IS_TEXT_PART_EMPTY(part) && part->utf_content.len > 0) {
  325. text = part->utf_content.begin;
  326. len = part->utf_content.len;
  327. if (lua_trie_search_str(L, trie, text, len, cb) != 0) {
  328. found = TRUE;
  329. }
  330. }
  331. }
  332. }
  333. lua_pushboolean(L, found);
  334. return 1;
  335. }
  336. /***
  337. * @method trie:search_rawmsg(task, cb[, caseless])
  338. * This is a helper mehthod to search pattern within the whole undecoded content of rspamd task
  339. * @param {task} task object
  340. * @param {function} cb callback called on each pattern match @see trie:match
  341. * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
  342. * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
  343. */
  344. static int
  345. lua_trie_search_rawmsg(lua_State *L)
  346. {
  347. LUA_TRACE_POINT;
  348. struct rspamd_multipattern *trie = lua_check_trie(L, 1);
  349. struct rspamd_task *task = lua_check_task(L, 2);
  350. const char *text;
  351. gsize len;
  352. gboolean found = FALSE;
  353. if (trie && task) {
  354. text = task->msg.begin;
  355. len = task->msg.len;
  356. if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
  357. found = TRUE;
  358. }
  359. }
  360. lua_pushboolean(L, found);
  361. return 1;
  362. }
  363. /***
  364. * @method trie:search_rawbody(task, cb[, caseless])
  365. * This is a helper mehthod to search pattern within the whole undecoded content of task's body (not including headers)
  366. * @param {task} task object
  367. * @param {function} cb callback called on each pattern match @see trie:match
  368. * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
  369. * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
  370. */
  371. static int
  372. lua_trie_search_rawbody(lua_State *L)
  373. {
  374. LUA_TRACE_POINT;
  375. struct rspamd_multipattern *trie = lua_check_trie(L, 1);
  376. struct rspamd_task *task = lua_check_task(L, 2);
  377. const char *text;
  378. gsize len;
  379. gboolean found = FALSE;
  380. if (trie && task) {
  381. if (MESSAGE_FIELD(task, raw_headers_content).len > 0) {
  382. text = task->msg.begin + MESSAGE_FIELD(task, raw_headers_content).len;
  383. len = task->msg.len - MESSAGE_FIELD(task, raw_headers_content).len;
  384. }
  385. else {
  386. /* Treat as raw message */
  387. text = task->msg.begin;
  388. len = task->msg.len;
  389. }
  390. if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
  391. found = TRUE;
  392. }
  393. }
  394. lua_pushboolean(L, found);
  395. return 1;
  396. }
  397. static int
  398. lua_load_trie(lua_State *L)
  399. {
  400. lua_newtable(L);
  401. /* Flags */
  402. lua_pushstring(L, "flags");
  403. lua_newtable(L);
  404. lua_pushinteger(L, RSPAMD_MULTIPATTERN_GLOB);
  405. lua_setfield(L, -2, "glob");
  406. lua_pushinteger(L, RSPAMD_MULTIPATTERN_RE);
  407. lua_setfield(L, -2, "re");
  408. lua_pushinteger(L, RSPAMD_MULTIPATTERN_ICASE);
  409. lua_setfield(L, -2, "icase");
  410. lua_pushinteger(L, RSPAMD_MULTIPATTERN_UTF8);
  411. lua_setfield(L, -2, "utf8");
  412. lua_pushinteger(L, RSPAMD_MULTIPATTERN_TLD);
  413. lua_setfield(L, -2, "tld");
  414. lua_pushinteger(L, RSPAMD_MULTIPATTERN_DOTALL);
  415. lua_setfield(L, -2, "dot_all");
  416. lua_pushinteger(L, RSPAMD_MULTIPATTERN_SINGLEMATCH);
  417. lua_setfield(L, -2, "single_match");
  418. lua_pushinteger(L, RSPAMD_MULTIPATTERN_NO_START);
  419. lua_setfield(L, -2, "no_start");
  420. lua_settable(L, -3);
  421. /* Main content */
  422. luaL_register(L, NULL, trielib_f);
  423. return 1;
  424. }
  425. void luaopen_trie(lua_State *L)
  426. {
  427. rspamd_lua_new_class(L, rspamd_trie_classname, trielib_m);
  428. lua_pop(L, 1);
  429. rspamd_lua_add_preload(L, "rspamd_trie", lua_load_trie);
  430. }