You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_compress.c 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. /*-
  2. * Copyright 2021 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "unix-std.h"
  18. #include <zlib.h>
  19. #ifdef SYS_ZSTD
  20. #include "zstd.h"
  21. #include "zstd_errors.h"
  22. #else
  23. #include "contrib/zstd/zstd.h"
  24. #include "contrib/zstd/error_public.h"
  25. #endif
  26. /***
  27. * @module rspamd_compress
  28. * This module contains compression/decompression routines (zstd and zlib currently)
  29. */
  30. /***
  31. * @function zstd.compress_ctx()
  32. * Creates new compression ctx
  33. * @return {compress_ctx} new compress ctx
  34. */
  35. LUA_FUNCTION_DEF(zstd, compress_ctx);
  36. /***
  37. * @function zstd.compress_ctx()
  38. * Creates new compression ctx
  39. * @return {compress_ctx} new compress ctx
  40. */
  41. LUA_FUNCTION_DEF(zstd, decompress_ctx);
  42. LUA_FUNCTION_DEF(zstd_compress, stream);
  43. LUA_FUNCTION_DEF(zstd_compress, dtor);
  44. LUA_FUNCTION_DEF(zstd_decompress, stream);
  45. LUA_FUNCTION_DEF(zstd_decompress, dtor);
  46. static const struct luaL_reg zstd_compress_lib_f[] = {
  47. LUA_INTERFACE_DEF(zstd, compress_ctx),
  48. LUA_INTERFACE_DEF(zstd, decompress_ctx),
  49. {NULL, NULL}};
  50. static const struct luaL_reg zstd_compress_lib_m[] = {
  51. LUA_INTERFACE_DEF(zstd_compress, stream),
  52. {"__gc", lua_zstd_compress_dtor},
  53. {NULL, NULL}};
  54. static const struct luaL_reg zstd_decompress_lib_m[] = {
  55. LUA_INTERFACE_DEF(zstd_decompress, stream),
  56. {"__gc", lua_zstd_decompress_dtor},
  57. {NULL, NULL}};
  58. static ZSTD_CStream *
  59. lua_check_zstd_compress_ctx(lua_State *L, int pos)
  60. {
  61. void *ud = rspamd_lua_check_udata(L, pos, rspamd_zstd_compress_classname);
  62. luaL_argcheck(L, ud != NULL, pos, "'zstd_compress' expected");
  63. return ud ? *(ZSTD_CStream **) ud : NULL;
  64. }
  65. static ZSTD_DStream *
  66. lua_check_zstd_decompress_ctx(lua_State *L, int pos)
  67. {
  68. void *ud = rspamd_lua_check_udata(L, pos, rspamd_zstd_decompress_classname);
  69. luaL_argcheck(L, ud != NULL, pos, "'zstd_decompress' expected");
  70. return ud ? *(ZSTD_DStream **) ud : NULL;
  71. }
  72. int lua_zstd_push_error(lua_State *L, int err)
  73. {
  74. lua_pushnil(L);
  75. lua_pushfstring(L, "zstd error %d (%s)", err, ZSTD_getErrorString(err));
  76. return 2;
  77. }
  78. int lua_compress_zstd_compress(lua_State *L)
  79. {
  80. LUA_TRACE_POINT;
  81. struct rspamd_lua_text *t = NULL, *res;
  82. gsize sz, r;
  83. int comp_level = 1;
  84. t = lua_check_text_or_string(L, 1);
  85. if (t == NULL || t->start == NULL) {
  86. return luaL_error(L, "invalid arguments");
  87. }
  88. if (lua_type(L, 2) == LUA_TNUMBER) {
  89. comp_level = lua_tointeger(L, 2);
  90. }
  91. sz = ZSTD_compressBound(t->len);
  92. if (ZSTD_isError(sz)) {
  93. msg_err("cannot compress data: %s", ZSTD_getErrorName(sz));
  94. lua_pushnil(L);
  95. return 1;
  96. }
  97. res = lua_newuserdata(L, sizeof(*res));
  98. res->start = g_malloc(sz);
  99. res->flags = RSPAMD_TEXT_FLAG_OWN;
  100. rspamd_lua_setclass(L, rspamd_text_classname, -1);
  101. r = ZSTD_compress((void *) res->start, sz, t->start, t->len, comp_level);
  102. if (ZSTD_isError(r)) {
  103. msg_err("cannot compress data: %s", ZSTD_getErrorName(r));
  104. lua_pop(L, 1); /* Text will be freed here */
  105. lua_pushnil(L);
  106. return 1;
  107. }
  108. res->len = r;
  109. return 1;
  110. }
  111. int lua_compress_zstd_decompress(lua_State *L)
  112. {
  113. LUA_TRACE_POINT;
  114. struct rspamd_lua_text *t = NULL, *res;
  115. gsize outlen, r;
  116. ZSTD_DStream *zstream;
  117. ZSTD_inBuffer zin;
  118. ZSTD_outBuffer zout;
  119. char *out;
  120. t = lua_check_text_or_string(L, 1);
  121. if (t == NULL || t->start == NULL) {
  122. return luaL_error(L, "invalid arguments");
  123. }
  124. zstream = ZSTD_createDStream();
  125. ZSTD_initDStream(zstream);
  126. zin.pos = 0;
  127. zin.src = t->start;
  128. zin.size = t->len;
  129. if ((outlen = ZSTD_getDecompressedSize(zin.src, zin.size)) == 0) {
  130. outlen = ZSTD_DStreamOutSize();
  131. }
  132. out = g_malloc(outlen);
  133. zout.dst = out;
  134. zout.pos = 0;
  135. zout.size = outlen;
  136. while (zin.pos < zin.size) {
  137. r = ZSTD_decompressStream(zstream, &zout, &zin);
  138. if (ZSTD_isError(r)) {
  139. msg_err("cannot decompress data: %s", ZSTD_getErrorName(r));
  140. ZSTD_freeDStream(zstream);
  141. g_free(out);
  142. lua_pushstring(L, ZSTD_getErrorName(r));
  143. lua_pushnil(L);
  144. return 2;
  145. }
  146. if (zin.pos < zin.size && zout.pos == zout.size) {
  147. /* We need to extend output buffer */
  148. zout.size = zout.size * 2;
  149. out = g_realloc(zout.dst, zout.size);
  150. zout.dst = out;
  151. }
  152. }
  153. ZSTD_freeDStream(zstream);
  154. lua_pushnil(L); /* Error */
  155. res = lua_newuserdata(L, sizeof(*res));
  156. res->start = out;
  157. res->flags = RSPAMD_TEXT_FLAG_OWN;
  158. rspamd_lua_setclass(L, rspamd_text_classname, -1);
  159. res->len = zout.pos;
  160. return 2;
  161. }
  162. int lua_compress_zlib_decompress(lua_State *L, bool is_gzip)
  163. {
  164. LUA_TRACE_POINT;
  165. struct rspamd_lua_text *t = NULL, *res;
  166. gsize sz;
  167. z_stream strm;
  168. int rc;
  169. unsigned char *p;
  170. gsize remain;
  171. gssize size_limit = -1;
  172. int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS);
  173. t = lua_check_text_or_string(L, 1);
  174. if (t == NULL || t->start == NULL) {
  175. return luaL_error(L, "invalid arguments");
  176. }
  177. if (lua_type(L, 2) == LUA_TNUMBER) {
  178. size_limit = lua_tointeger(L, 2);
  179. if (size_limit <= 0) {
  180. return luaL_error(L, "invalid arguments (size_limit)");
  181. }
  182. sz = MIN(t->len * 2, size_limit);
  183. }
  184. else {
  185. sz = t->len * 2;
  186. }
  187. memset(&strm, 0, sizeof(strm));
  188. /* windowBits +16 to decode gzip, zlib 1.2.0.4+ */
  189. /* Here are dragons to distinguish between raw deflate and zlib */
  190. if (windowBits == MAX_WBITS && t->len > 0) {
  191. if ((int) (unsigned char) ((t->start[0] << 4)) != 0x80) {
  192. /* Assume raw deflate */
  193. windowBits = -windowBits;
  194. }
  195. }
  196. rc = inflateInit2(&strm, windowBits);
  197. if (rc != Z_OK) {
  198. return luaL_error(L, "cannot init zlib");
  199. }
  200. strm.avail_in = t->len;
  201. strm.next_in = (unsigned char *) t->start;
  202. res = lua_newuserdata(L, sizeof(*res));
  203. res->start = g_malloc(sz);
  204. res->flags = RSPAMD_TEXT_FLAG_OWN;
  205. rspamd_lua_setclass(L, rspamd_text_classname, -1);
  206. p = (unsigned char *) res->start;
  207. remain = sz;
  208. while (strm.avail_in != 0) {
  209. strm.avail_out = remain;
  210. strm.next_out = p;
  211. rc = inflate(&strm, Z_NO_FLUSH);
  212. if (rc != Z_OK && rc != Z_BUF_ERROR) {
  213. if (rc == Z_STREAM_END) {
  214. break;
  215. }
  216. else {
  217. msg_err("cannot decompress data: %s (last error: %s)",
  218. zError(rc), strm.msg);
  219. lua_pop(L, 1); /* Text will be freed here */
  220. lua_pushnil(L);
  221. inflateEnd(&strm);
  222. return 1;
  223. }
  224. }
  225. res->len = strm.total_out;
  226. if (strm.avail_out == 0 && strm.avail_in != 0) {
  227. if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) {
  228. if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) {
  229. lua_pop(L, 1); /* Text will be freed here */
  230. lua_pushnil(L);
  231. inflateEnd(&strm);
  232. return 1;
  233. }
  234. }
  235. /* Need to allocate more */
  236. remain = res->len;
  237. res->start = g_realloc((gpointer) res->start, res->len * 2);
  238. sz = res->len * 2;
  239. p = (unsigned char *) res->start + remain;
  240. remain = sz - remain;
  241. }
  242. }
  243. inflateEnd(&strm);
  244. res->len = strm.total_out;
  245. return 1;
  246. }
  247. int lua_compress_zlib_compress(lua_State *L)
  248. {
  249. LUA_TRACE_POINT;
  250. struct rspamd_lua_text *t = NULL, *res;
  251. gsize sz;
  252. z_stream strm;
  253. int rc, comp_level = Z_DEFAULT_COMPRESSION;
  254. unsigned char *p;
  255. gsize remain;
  256. t = lua_check_text_or_string(L, 1);
  257. if (t == NULL || t->start == NULL) {
  258. return luaL_error(L, "invalid arguments");
  259. }
  260. if (lua_isnumber(L, 2)) {
  261. comp_level = lua_tointeger(L, 2);
  262. if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) {
  263. return luaL_error(L, "invalid arguments: compression level must be between %d and %d",
  264. Z_BEST_SPEED, Z_BEST_COMPRESSION);
  265. }
  266. }
  267. memset(&strm, 0, sizeof(strm));
  268. rc = deflateInit2(&strm, comp_level, Z_DEFLATED,
  269. MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY);
  270. if (rc != Z_OK) {
  271. return luaL_error(L, "cannot init zlib: %s", zError(rc));
  272. }
  273. sz = deflateBound(&strm, t->len);
  274. strm.avail_in = t->len;
  275. strm.next_in = (unsigned char *) t->start;
  276. res = lua_newuserdata(L, sizeof(*res));
  277. res->start = g_malloc(sz);
  278. res->flags = RSPAMD_TEXT_FLAG_OWN;
  279. rspamd_lua_setclass(L, rspamd_text_classname, -1);
  280. p = (unsigned char *) res->start;
  281. remain = sz;
  282. while (strm.avail_in != 0) {
  283. strm.avail_out = remain;
  284. strm.next_out = p;
  285. rc = deflate(&strm, Z_FINISH);
  286. if (rc != Z_OK && rc != Z_BUF_ERROR) {
  287. if (rc == Z_STREAM_END) {
  288. break;
  289. }
  290. else {
  291. msg_err("cannot compress data: %s (last error: %s)",
  292. zError(rc), strm.msg);
  293. lua_pop(L, 1); /* Text will be freed here */
  294. lua_pushnil(L);
  295. deflateEnd(&strm);
  296. return 1;
  297. }
  298. }
  299. res->len = strm.total_out;
  300. if (strm.avail_out == 0 && strm.avail_in != 0) {
  301. /* Need to allocate more */
  302. remain = res->len;
  303. res->start = g_realloc((gpointer) res->start, strm.avail_in + sz);
  304. sz = strm.avail_in + sz;
  305. p = (unsigned char *) res->start + remain;
  306. remain = sz - remain;
  307. }
  308. }
  309. deflateEnd(&strm);
  310. res->len = strm.total_out;
  311. return 1;
  312. }
  /* Stream API interface for Zstd: both compression and decompression */
/*
 * Operation names accepted by the stream methods via luaL_checkoption().
 * The index of the chosen name is passed to ZSTD_compressStream2() as its
 * ZSTD_EndDirective: 0 = ZSTD_e_continue, 1 = ZSTD_e_flush, 2 = ZSTD_e_end.
 * The array is NULL-terminated, as luaL_checkoption requires.
 */
static const char *const zstd_stream_op[] = {"continue", "flush", "end", NULL};
  320. static int
  321. lua_zstd_compress_ctx(lua_State *L)
  322. {
  323. ZSTD_CCtx *ctx, **pctx;
  324. pctx = lua_newuserdata(L, sizeof(*pctx));
  325. ctx = ZSTD_createCCtx();
  326. if (!ctx) {
  327. return luaL_error(L, "context create failed");
  328. }
  329. *pctx = ctx;
  330. rspamd_lua_setclass(L, rspamd_zstd_compress_classname, -1);
  331. return 1;
  332. }
  333. static int
  334. lua_zstd_compress_dtor(lua_State *L)
  335. {
  336. ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1);
  337. if (ctx) {
  338. ZSTD_freeCCtx(ctx);
  339. }
  340. return 0;
  341. }
  342. static int
  343. lua_zstd_compress_reset(lua_State *L)
  344. {
  345. ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1);
  346. if (ctx) {
  347. ZSTD_CCtx_reset(ctx, ZSTD_reset_session_and_parameters);
  348. }
  349. else {
  350. return luaL_error(L, "invalid arguments");
  351. }
  352. return 0;
  353. }
  354. static int
  355. lua_zstd_compress_stream(lua_State *L)
  356. {
  357. ZSTD_CStream *ctx = lua_check_zstd_compress_ctx(L, 1);
  358. struct rspamd_lua_text *t = lua_check_text_or_string(L, 2);
  359. int op = luaL_checkoption(L, 3, zstd_stream_op[0], zstd_stream_op);
  360. int err = 0;
  361. ZSTD_inBuffer inb;
  362. ZSTD_outBuffer onb;
  363. if (ctx && t) {
  364. gsize dlen = 0;
  365. inb.size = t->len;
  366. inb.pos = 0;
  367. inb.src = (const void *) t->start;
  368. onb.pos = 0;
  369. onb.size = ZSTD_CStreamInSize(); /* Initial guess */
  370. onb.dst = NULL;
  371. for (;;) {
  372. if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) {
  373. return lua_zstd_push_error(L, ZSTD_error_memory_allocation);
  374. }
  375. dlen = onb.size;
  376. int res = ZSTD_compressStream2(ctx, &onb, &inb, op);
  377. if (res == 0) {
  378. /* All done */
  379. break;
  380. }
  381. if ((err = ZSTD_getErrorCode(res))) {
  382. break;
  383. }
  384. onb.size *= 2;
  385. res += dlen; /* Hint returned by compression routine */
  386. /* Either double the buffer, or use the hint provided */
  387. if (onb.size < res) {
  388. onb.size = res;
  389. }
  390. }
  391. }
  392. else {
  393. return luaL_error(L, "invalid arguments");
  394. }
  395. if (err) {
  396. return lua_zstd_push_error(L, err);
  397. }
  398. lua_new_text(L, onb.dst, onb.pos, TRUE);
  399. return 1;
  400. }
  401. static int
  402. lua_zstd_decompress_dtor(lua_State *L)
  403. {
  404. ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1);
  405. if (ctx) {
  406. ZSTD_freeDStream(ctx);
  407. }
  408. return 0;
  409. }
  410. static int
  411. lua_zstd_decompress_ctx(lua_State *L)
  412. {
  413. ZSTD_DStream *ctx, **pctx;
  414. pctx = lua_newuserdata(L, sizeof(*pctx));
  415. ctx = ZSTD_createDStream();
  416. if (!ctx) {
  417. return luaL_error(L, "context create failed");
  418. }
  419. *pctx = ctx;
  420. rspamd_lua_setclass(L, rspamd_zstd_decompress_classname, -1);
  421. return 1;
  422. }
  423. static int
  424. lua_zstd_decompress_stream(lua_State *L)
  425. {
  426. ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1);
  427. struct rspamd_lua_text *t = lua_check_text_or_string(L, 2);
  428. int err = 0;
  429. ZSTD_inBuffer inb;
  430. ZSTD_outBuffer onb;
  431. if (ctx && t) {
  432. gsize dlen = 0;
  433. if (t->len == 0) {
  434. return lua_zstd_push_error(L, ZSTD_error_init_missing);
  435. }
  436. inb.size = t->len;
  437. inb.pos = 0;
  438. inb.src = (const void *) t->start;
  439. onb.pos = 0;
  440. onb.size = ZSTD_DStreamInSize(); /* Initial guess */
  441. onb.dst = NULL;
  442. for (;;) {
  443. if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) {
  444. return lua_zstd_push_error(L, ZSTD_error_memory_allocation);
  445. }
  446. dlen = onb.size;
  447. int res = ZSTD_decompressStream(ctx, &onb, &inb);
  448. if (res == 0) {
  449. /* All done */
  450. break;
  451. }
  452. if ((err = ZSTD_getErrorCode(res))) {
  453. break;
  454. }
  455. onb.size *= 2;
  456. res += dlen; /* Hint returned by compression routine */
  457. /* Either double the buffer, or use the hint provided */
  458. if (onb.size < res) {
  459. onb.size = res;
  460. }
  461. }
  462. }
  463. else {
  464. return luaL_error(L, "invalid arguments");
  465. }
  466. if (err) {
  467. return lua_zstd_push_error(L, err);
  468. }
  469. lua_new_text(L, onb.dst, onb.pos, TRUE);
  470. return 1;
  471. }
  472. static int
  473. lua_load_zstd(lua_State *L)
  474. {
  475. lua_newtable(L);
  476. luaL_register(L, NULL, zstd_compress_lib_f);
  477. return 1;
  478. }
  479. void luaopen_compress(lua_State *L)
  480. {
  481. rspamd_lua_new_class(L, rspamd_zstd_compress_classname, zstd_compress_lib_m);
  482. rspamd_lua_new_class(L, rspamd_zstd_decompress_classname, zstd_decompress_lib_m);
  483. lua_pop(L, 2);
  484. rspamd_lua_add_preload(L, "rspamd_zstd", lua_load_zstd);
  485. }