You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_compress.c 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. /*-
  2. * Copyright 2021 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "unix-std.h"
  18. #include <zlib.h>
  19. #ifdef SYS_ZSTD
  20. # include "zstd.h"
  21. # include "zstd_errors.h"
  22. #else
  23. # include "contrib/zstd/zstd.h"
  24. # include "contrib/zstd/error_public.h"
  25. #endif
  26. /***
  27. * @module rspamd_compress
  28. * This module contains compression/decompression routines (zstd and zlib currently)
  29. */
  /***
   * @function zstd.compress_ctx()
   * Creates new compression ctx
   * @return {compress_ctx} new compress ctx
   */
  LUA_FUNCTION_DEF (zstd, compress_ctx);
  /***
   * @function zstd.decompress_ctx()
   * Creates new decompression ctx
   * @return {decompress_ctx} new decompress ctx
   */
  LUA_FUNCTION_DEF (zstd, decompress_ctx);
  /***
   * @method compress_ctx:stream(data[, op])
   * Feeds `data` (text or string) to the compression context.
   * `op` is one of "continue" (default), "flush" or "end".
   * @return {text} compressed chunk, or nil and an error string on failure
   */
  LUA_FUNCTION_DEF (zstd_compress, stream);
  /* Registered as the `__gc` metamethod of the compression context */
  LUA_FUNCTION_DEF (zstd_compress, dtor);
  /***
   * @method decompress_ctx:stream(data)
   * Feeds `data` (text or string) to the decompression context.
   * @return {text} decompressed chunk, or nil and an error string on failure
   */
  LUA_FUNCTION_DEF (zstd_decompress, stream);
  /* Registered as the `__gc` metamethod of the decompression context */
  LUA_FUNCTION_DEF (zstd_decompress, dtor);
  46. static const struct luaL_reg zstd_compress_lib_f[] = {
  47. LUA_INTERFACE_DEF (zstd, compress_ctx),
  48. LUA_INTERFACE_DEF (zstd, decompress_ctx),
  49. {NULL, NULL}
  50. };
  51. static const struct luaL_reg zstd_compress_lib_m[] = {
  52. LUA_INTERFACE_DEF (zstd_compress, stream),
  53. {"__gc", lua_zstd_compress_dtor},
  54. {NULL, NULL}
  55. };
  56. static const struct luaL_reg zstd_decompress_lib_m[] = {
  57. LUA_INTERFACE_DEF (zstd_decompress, stream),
  58. {"__gc", lua_zstd_decompress_dtor},
  59. {NULL, NULL}
  60. };
  61. static ZSTD_CStream *
  62. lua_check_zstd_compress_ctx (lua_State *L, gint pos)
  63. {
  64. void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_compress}");
  65. luaL_argcheck (L, ud != NULL, pos, "'zstd_compress' expected");
  66. return ud ? *(ZSTD_CStream **)ud : NULL;
  67. }
  68. static ZSTD_DStream *
  69. lua_check_zstd_decompress_ctx (lua_State *L, gint pos)
  70. {
  71. void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_decompress}");
  72. luaL_argcheck (L, ud != NULL, pos, "'zstd_decompress' expected");
  73. return ud ? *(ZSTD_DStream **)ud : NULL;
  74. }
  75. int
  76. lua_zstd_push_error (lua_State *L, int err)
  77. {
  78. lua_pushnil (L);
  79. lua_pushfstring (L, "zstd error %d (%s)", err, ZSTD_getErrorString (err));
  80. return 2;
  81. }
  82. gint
  83. lua_compress_zstd_compress (lua_State *L)
  84. {
  85. LUA_TRACE_POINT;
  86. struct rspamd_lua_text *t = NULL, *res;
  87. gsize sz, r;
  88. gint comp_level = 1;
  89. t = lua_check_text_or_string (L,1);
  90. if (t == NULL || t->start == NULL) {
  91. return luaL_error (L, "invalid arguments");
  92. }
  93. if (lua_type (L, 2) == LUA_TNUMBER) {
  94. comp_level = lua_tointeger (L, 2);
  95. }
  96. sz = ZSTD_compressBound (t->len);
  97. if (ZSTD_isError (sz)) {
  98. msg_err ("cannot compress data: %s", ZSTD_getErrorName (sz));
  99. lua_pushnil (L);
  100. return 1;
  101. }
  102. res = lua_newuserdata (L, sizeof (*res));
  103. res->start = g_malloc (sz);
  104. res->flags = RSPAMD_TEXT_FLAG_OWN;
  105. rspamd_lua_setclass (L, "rspamd{text}", -1);
  106. r = ZSTD_compress ((void *)res->start, sz, t->start, t->len, comp_level);
  107. if (ZSTD_isError (r)) {
  108. msg_err ("cannot compress data: %s", ZSTD_getErrorName (r));
  109. lua_pop (L, 1); /* Text will be freed here */
  110. lua_pushnil (L);
  111. return 1;
  112. }
  113. res->len = r;
  114. return 1;
  115. }
  116. gint
  117. lua_compress_zstd_decompress (lua_State *L)
  118. {
  119. LUA_TRACE_POINT;
  120. struct rspamd_lua_text *t = NULL, *res;
  121. gsize outlen, r;
  122. ZSTD_DStream *zstream;
  123. ZSTD_inBuffer zin;
  124. ZSTD_outBuffer zout;
  125. gchar *out;
  126. t = lua_check_text_or_string (L,1);
  127. if (t == NULL || t->start == NULL) {
  128. return luaL_error (L, "invalid arguments");
  129. }
  130. zstream = ZSTD_createDStream ();
  131. ZSTD_initDStream (zstream);
  132. zin.pos = 0;
  133. zin.src = t->start;
  134. zin.size = t->len;
  135. if ((outlen = ZSTD_getDecompressedSize (zin.src, zin.size)) == 0) {
  136. outlen = ZSTD_DStreamOutSize ();
  137. }
  138. out = g_malloc (outlen);
  139. zout.dst = out;
  140. zout.pos = 0;
  141. zout.size = outlen;
  142. while (zin.pos < zin.size) {
  143. r = ZSTD_decompressStream (zstream, &zout, &zin);
  144. if (ZSTD_isError (r)) {
  145. msg_err ("cannot decompress data: %s", ZSTD_getErrorName (r));
  146. ZSTD_freeDStream (zstream);
  147. g_free (out);
  148. lua_pushstring (L, ZSTD_getErrorName (r));
  149. lua_pushnil (L);
  150. return 2;
  151. }
  152. if (zin.pos < zin.size && zout.pos == zout.size) {
  153. /* We need to extend output buffer */
  154. zout.size = zout.size * 2;
  155. out = g_realloc (zout.dst, zout.size);
  156. zout.dst = out;
  157. }
  158. }
  159. ZSTD_freeDStream (zstream);
  160. lua_pushnil (L); /* Error */
  161. res = lua_newuserdata (L, sizeof (*res));
  162. res->start = out;
  163. res->flags = RSPAMD_TEXT_FLAG_OWN;
  164. rspamd_lua_setclass (L, "rspamd{text}", -1);
  165. res->len = zout.pos;
  166. return 2;
  167. }
  168. gint
  169. lua_compress_zlib_decompress (lua_State *L, bool is_gzip)
  170. {
  171. LUA_TRACE_POINT;
  172. struct rspamd_lua_text *t = NULL, *res;
  173. gsize sz;
  174. z_stream strm;
  175. gint rc;
  176. guchar *p;
  177. gsize remain;
  178. gssize size_limit = -1;
  179. int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS);
  180. t = lua_check_text_or_string (L,1);
  181. if (t == NULL || t->start == NULL) {
  182. return luaL_error (L, "invalid arguments");
  183. }
  184. if (lua_type (L, 2) == LUA_TNUMBER) {
  185. size_limit = lua_tointeger (L, 2);
  186. if (size_limit <= 0) {
  187. return luaL_error (L, "invalid arguments (size_limit)");
  188. }
  189. sz = MIN (t->len * 2, size_limit);
  190. }
  191. else {
  192. sz = t->len * 2;
  193. }
  194. memset (&strm, 0, sizeof (strm));
  195. /* windowBits +16 to decode gzip, zlib 1.2.0.4+ */
  196. /* Here are dragons to distinguish between raw deflate and zlib */
  197. if (windowBits == MAX_WBITS && t->len > 0) {
  198. if ((int)(unsigned char)((t->start[0] << 4)) != 0x80) {
  199. /* Assume raw deflate */
  200. windowBits = -windowBits;
  201. }
  202. }
  203. rc = inflateInit2 (&strm, windowBits);
  204. if (rc != Z_OK) {
  205. return luaL_error (L, "cannot init zlib");
  206. }
  207. strm.avail_in = t->len;
  208. strm.next_in = (guchar *)t->start;
  209. res = lua_newuserdata (L, sizeof (*res));
  210. res->start = g_malloc (sz);
  211. res->flags = RSPAMD_TEXT_FLAG_OWN;
  212. rspamd_lua_setclass (L, "rspamd{text}", -1);
  213. p = (guchar *)res->start;
  214. remain = sz;
  215. while (strm.avail_in != 0) {
  216. strm.avail_out = remain;
  217. strm.next_out = p;
  218. rc = inflate (&strm, Z_NO_FLUSH);
  219. if (rc != Z_OK && rc != Z_BUF_ERROR) {
  220. if (rc == Z_STREAM_END) {
  221. break;
  222. }
  223. else {
  224. msg_err ("cannot decompress data: %s (last error: %s)",
  225. zError (rc), strm.msg);
  226. lua_pop (L, 1); /* Text will be freed here */
  227. lua_pushnil (L);
  228. inflateEnd (&strm);
  229. return 1;
  230. }
  231. }
  232. res->len = strm.total_out;
  233. if (strm.avail_out == 0 && strm.avail_in != 0) {
  234. if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) {
  235. if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) {
  236. lua_pop (L, 1); /* Text will be freed here */
  237. lua_pushnil (L);
  238. inflateEnd (&strm);
  239. return 1;
  240. }
  241. }
  242. /* Need to allocate more */
  243. remain = res->len;
  244. res->start = g_realloc ((gpointer)res->start, res->len * 2);
  245. sz = res->len * 2;
  246. p = (guchar *)res->start + remain;
  247. remain = sz - remain;
  248. }
  249. }
  250. inflateEnd (&strm);
  251. res->len = strm.total_out;
  252. return 1;
  253. }
  254. gint
  255. lua_compress_zlib_compress (lua_State *L)
  256. {
  257. LUA_TRACE_POINT;
  258. struct rspamd_lua_text *t = NULL, *res;
  259. gsize sz;
  260. z_stream strm;
  261. gint rc, comp_level = Z_DEFAULT_COMPRESSION;
  262. guchar *p;
  263. gsize remain;
  264. t = lua_check_text_or_string (L,1);
  265. if (t == NULL || t->start == NULL) {
  266. return luaL_error (L, "invalid arguments");
  267. }
  268. if (lua_isnumber (L, 2)) {
  269. comp_level = lua_tointeger (L, 2);
  270. if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) {
  271. return luaL_error (L, "invalid arguments: compression level must be between %d and %d",
  272. Z_BEST_SPEED, Z_BEST_COMPRESSION);
  273. }
  274. }
  275. memset (&strm, 0, sizeof (strm));
  276. rc = deflateInit2 (&strm, comp_level, Z_DEFLATED,
  277. MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY);
  278. if (rc != Z_OK) {
  279. return luaL_error (L, "cannot init zlib: %s", zError (rc));
  280. }
  281. sz = deflateBound (&strm, t->len);
  282. strm.avail_in = t->len;
  283. strm.next_in = (guchar *) t->start;
  284. res = lua_newuserdata (L, sizeof (*res));
  285. res->start = g_malloc (sz);
  286. res->flags = RSPAMD_TEXT_FLAG_OWN;
  287. rspamd_lua_setclass (L, "rspamd{text}", -1);
  288. p = (guchar *) res->start;
  289. remain = sz;
  290. while (strm.avail_in != 0) {
  291. strm.avail_out = remain;
  292. strm.next_out = p;
  293. rc = deflate (&strm, Z_FINISH);
  294. if (rc != Z_OK && rc != Z_BUF_ERROR) {
  295. if (rc == Z_STREAM_END) {
  296. break;
  297. }
  298. else {
  299. msg_err ("cannot compress data: %s (last error: %s)",
  300. zError (rc), strm.msg);
  301. lua_pop (L, 1); /* Text will be freed here */
  302. lua_pushnil (L);
  303. deflateEnd (&strm);
  304. return 1;
  305. }
  306. }
  307. res->len = strm.total_out;
  308. if (strm.avail_out == 0 && strm.avail_in != 0) {
  309. /* Need to allocate more */
  310. remain = res->len;
  311. res->start = g_realloc ((gpointer) res->start, strm.avail_in + sz);
  312. sz = strm.avail_in + sz;
  313. p = (guchar *) res->start + remain;
  314. remain = sz - remain;
  315. }
  316. }
  317. deflateEnd (&strm);
  318. res->len = strm.total_out;
  319. return 1;
  320. }
  /* Stream API interface for Zstd: both compression and decompression */

  /*
   * Operation names accepted by the zstd stream methods. The index selected by
   * luaL_checkoption is handed unchanged to ZSTD_compressStream2 as its end
   * directive, so the order of the entries must not change.
   */
  static const char *const zstd_stream_op[] = {
      [0] = "continue",
      [1] = "flush",
      [2] = "end",
      [3] = NULL /* terminator */
  };
  329. static gint
  330. lua_zstd_compress_ctx (lua_State *L)
  331. {
  332. ZSTD_CCtx *ctx, **pctx;
  333. pctx = lua_newuserdata (L, sizeof (*pctx));
  334. ctx = ZSTD_createCCtx ();
  335. if (!ctx) {
  336. return luaL_error (L, "context create failed");
  337. }
  338. *pctx = ctx;
  339. rspamd_lua_setclass (L, "rspamd{zstd_compress}", -1);
  340. return 1;
  341. }
  342. static gint
  343. lua_zstd_compress_dtor (lua_State *L)
  344. {
  345. ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
  346. if (ctx) {
  347. ZSTD_freeCCtx (ctx);
  348. }
  349. return 0;
  350. }
  351. static gint
  352. lua_zstd_compress_reset (lua_State *L)
  353. {
  354. ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
  355. if (ctx) {
  356. ZSTD_CCtx_reset (ctx, ZSTD_reset_session_and_parameters);
  357. }
  358. else {
  359. return luaL_error (L, "invalid arguments");
  360. }
  361. return 0;
  362. }
  363. static gint
  364. lua_zstd_compress_stream (lua_State *L)
  365. {
  366. ZSTD_CStream *ctx = lua_check_zstd_compress_ctx (L, 1);
  367. struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
  368. int op = luaL_checkoption (L, 3, zstd_stream_op[0], zstd_stream_op);
  369. int err = 0;
  370. ZSTD_inBuffer inb;
  371. ZSTD_outBuffer onb;
  372. if (ctx && t) {
  373. gsize dlen = 0;
  374. inb.size = t->len;
  375. inb.pos = 0;
  376. inb.src = (const void*)t->start;
  377. onb.pos = 0;
  378. onb.size = ZSTD_CStreamInSize (); /* Initial guess */
  379. onb.dst = NULL;
  380. for (;;) {
  381. if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
  382. return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
  383. }
  384. dlen = onb.size;
  385. int res = ZSTD_compressStream2 (ctx, &onb, &inb, op);
  386. if (res == 0) {
  387. /* All done */
  388. break;
  389. }
  390. if ((err = ZSTD_getErrorCode (res))) {
  391. break;
  392. }
  393. onb.size *= 2;
  394. res += dlen; /* Hint returned by compression routine */
  395. /* Either double the buffer, or use the hint provided */
  396. if (onb.size < res) {
  397. onb.size = res;
  398. }
  399. }
  400. }
  401. else {
  402. return luaL_error (L, "invalid arguments");
  403. }
  404. if (err) {
  405. return lua_zstd_push_error (L, err);
  406. }
  407. lua_new_text (L, onb.dst, onb.pos, TRUE);
  408. return 1;
  409. }
  410. static gint
  411. lua_zstd_decompress_dtor (lua_State *L)
  412. {
  413. ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
  414. if (ctx) {
  415. ZSTD_freeDStream (ctx);
  416. }
  417. return 0;
  418. }
  419. static gint
  420. lua_zstd_decompress_ctx (lua_State *L)
  421. {
  422. ZSTD_DStream *ctx, **pctx;
  423. pctx = lua_newuserdata (L, sizeof (*pctx));
  424. ctx = ZSTD_createDStream ();
  425. if (!ctx) {
  426. return luaL_error (L, "context create failed");
  427. }
  428. *pctx = ctx;
  429. rspamd_lua_setclass (L, "rspamd{zstd_decompress}", -1);
  430. return 1;
  431. }
  432. static gint
  433. lua_zstd_decompress_stream (lua_State *L)
  434. {
  435. ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
  436. struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
  437. int err = 0;
  438. ZSTD_inBuffer inb;
  439. ZSTD_outBuffer onb;
  440. if (ctx && t) {
  441. gsize dlen = 0;
  442. if (t->len == 0) {
  443. return lua_zstd_push_error (L, ZSTD_error_init_missing);
  444. }
  445. inb.size = t->len;
  446. inb.pos = 0;
  447. inb.src = (const void*)t->start;
  448. onb.pos = 0;
  449. onb.size = ZSTD_DStreamInSize (); /* Initial guess */
  450. onb.dst = NULL;
  451. for (;;) {
  452. if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
  453. return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
  454. }
  455. dlen = onb.size;
  456. int res = ZSTD_decompressStream (ctx, &onb, &inb);
  457. if (res == 0) {
  458. /* All done */
  459. break;
  460. }
  461. if ((err = ZSTD_getErrorCode (res))) {
  462. break;
  463. }
  464. onb.size *= 2;
  465. res += dlen; /* Hint returned by compression routine */
  466. /* Either double the buffer, or use the hint provided */
  467. if (onb.size < res) {
  468. onb.size = res;
  469. }
  470. }
  471. }
  472. else {
  473. return luaL_error (L, "invalid arguments");
  474. }
  475. if (err) {
  476. return lua_zstd_push_error (L, err);
  477. }
  478. lua_new_text (L, onb.dst, onb.pos, TRUE);
  479. return 1;
  480. }
  481. static gint
  482. lua_load_zstd (lua_State * L)
  483. {
  484. lua_newtable (L);
  485. luaL_register (L, NULL, zstd_compress_lib_f);
  486. return 1;
  487. }
  488. void
  489. luaopen_compress (lua_State *L)
  490. {
  491. rspamd_lua_new_class (L, "rspamd{zstd_compress}", zstd_compress_lib_m);
  492. rspamd_lua_new_class (L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
  493. lua_pop (L, 2);
  494. rspamd_lua_add_preload (L, "rspamd_zstd", lua_load_zstd);
  495. }