#include "lua_common.h"
#include "unix-std.h"
#include "contrib/zstd/zstd.h"
+#include "contrib/zstd/error_public.h"
#include <zlib.h>
/***
*/
LUA_FUNCTION_DEF (zstd, decompress_ctx);
+LUA_FUNCTION_DEF (zstd_compress, stream);
+LUA_FUNCTION_DEF (zstd_compress, dtor);
+
+LUA_FUNCTION_DEF (zstd_decompress, stream);
+LUA_FUNCTION_DEF (zstd_decompress, dtor);
+
+static const struct luaL_reg zstd_compress_lib_f[] = {
+ LUA_INTERFACE_DEF (zstd, compress_ctx),
+ LUA_INTERFACE_DEF (zstd, decompress_ctx),
+ {NULL, NULL}
+};
+
+static const struct luaL_reg zstd_compress_lib_m[] = {
+ LUA_INTERFACE_DEF (zstd_compress, stream),
+ {"__gc", lua_zstd_compress_dtor},
+ {NULL, NULL}
+};
+
+static const struct luaL_reg zstd_decompress_lib_m[] = {
+ LUA_INTERFACE_DEF (zstd_decompress, stream),
+ {"__gc", lua_zstd_decompress_dtor},
+ {NULL, NULL}
+};
+
+static ZSTD_CStream *
+lua_check_zstd_compress_ctx (lua_State *L, gint pos)
+{
+ void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_compress}");
+ luaL_argcheck (L, ud != NULL, pos, "'zstd_compress' expected");
+ return ud ? *(ZSTD_CStream **)ud : NULL;
+}
+
+static ZSTD_DStream *
+lua_check_zstd_decompress_ctx (lua_State *L, gint pos)
+{
+ void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_decompress}");
+ luaL_argcheck (L, ud != NULL, pos, "'zstd_decompress' expected");
+ return ud ? *(ZSTD_DStream **)ud : NULL;
+}
+
+int
+lua_zstd_push_error (lua_State *L, int err)
+{
+ lua_pushnil (L);
+ lua_pushfstring (L, "zstd error %d (%s)", err, ZSTD_getErrorString (err));
+
+ return 2;
+}
gint
lua_compress_zstd_compress (lua_State *L)
res->len = strm.total_out;
return 1;
+}
+
+/* Stream API interface for Zstd: both compression and decompression */
+
+/* Operations allowed by zstd stream methods */
+static const char *const zstd_stream_op[] = {
+ "continue",
+ "flush",
+ "end",
+ NULL
+};
+
+static gint
+lua_zstd_compress_ctx (lua_State *L)
+{
+ ZSTD_CCtx *ctx, **pctx;
+
+ pctx = lua_newuserdata (L, sizeof (*pctx));
+ ctx = ZSTD_createCCtx ();
+
+ if (!ctx) {
+ return luaL_error (L, "context create failed");
+ }
+
+ *pctx = ctx;
+ rspamd_lua_setclass (L, "rspamd{zstd_compress}", -1);
+ return 1;
+}
+
+static gint
+lua_zstd_compress_dtor (lua_State *L)
+{
+ ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
+
+ if (ctx) {
+ ZSTD_freeCCtx (ctx);
+ }
+
+ return 0;
+}
+
+static gint
+lua_zstd_compress_reset (lua_State *L)
+{
+ ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
+
+ if (ctx) {
+ ZSTD_CCtx_reset (ctx, ZSTD_reset_session_and_parameters);
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ return 0;
+}
+
+static gint
+lua_zstd_compress_stream (lua_State *L)
+{
+ ZSTD_CStream *ctx = lua_check_zstd_compress_ctx (L, 1);
+ struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
+ int op = luaL_checkoption (L, 3, zstd_stream_op[0], zstd_stream_op);
+ int err = 0;
+ ZSTD_inBuffer inb;
+ ZSTD_outBuffer onb;
+
+ if (ctx && t) {
+ gsize dlen = 0;
+
+ inb.size = t->len;
+ inb.pos = 0;
+ inb.src = (const void*)t->start;
+
+ onb.pos = 0;
+ onb.size = ZSTD_CStreamInSize (); /* Initial guess */
+ onb.dst = NULL;
+
+ for (;;) {
+ if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
+ return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
+ }
+
+ dlen = onb.size;
+
+ int res = ZSTD_compressStream2 (ctx, &onb, &inb, op);
+
+ if (res == 0) {
+ /* All done */
+ break;
+ }
+
+ if ((err = ZSTD_getErrorCode (res))) {
+ break;
+ }
+
+ onb.size *= 2;
+ res += dlen; /* Hint returned by compression routine */
+
+ /* Either double the buffer, or use the hint provided */
+ if (onb.size < res) {
+ onb.size = res;
+ }
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ if (err) {
+ return lua_zstd_push_error (L, err);
+ }
+
+ lua_new_text (L, onb.dst, onb.pos, TRUE);
+
+ return 1;
+}
+
+static gint
+lua_zstd_decompress_dtor (lua_State *L)
+{
+ ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
+
+ if (ctx) {
+ ZSTD_freeDStream (ctx);
+ }
+
+ return 0;
+}
+
+
+static gint
+lua_zstd_decompress_ctx (lua_State *L)
+{
+ ZSTD_DStream *ctx, **pctx;
+
+ pctx = lua_newuserdata (L, sizeof (*pctx));
+ ctx = ZSTD_createDStream ();
+
+ if (!ctx) {
+ return luaL_error (L, "context create failed");
+ }
+
+ *pctx = ctx;
+ rspamd_lua_setclass (L, "rspamd{zstd_decompress}", -1);
+ return 1;
+}
+
+static gint
+lua_zstd_decompress_stream (lua_State *L)
+{
+ ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
+ struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
+ int err = 0;
+ ZSTD_inBuffer inb;
+ ZSTD_outBuffer onb;
+
+ if (ctx && t) {
+ gsize dlen = 0;
+
+ inb.size = t->len;
+ inb.pos = 0;
+ inb.src = (const void*)t->start;
+
+ onb.pos = 0;
+ onb.size = ZSTD_DStreamInSize (); /* Initial guess */
+ onb.dst = NULL;
+
+ for (;;) {
+ if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
+ return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
+ }
+
+ dlen = onb.size;
+
+ int res = ZSTD_decompressStream (ctx, &onb, &inb);
+
+ if (res == 0) {
+ /* All done */
+ break;
+ }
+
+ if ((err = ZSTD_getErrorCode (res))) {
+ break;
+ }
+
+ onb.size *= 2;
+ res += dlen; /* Hint returned by compression routine */
+
+ /* Either double the buffer, or use the hint provided */
+ if (onb.size < res) {
+ onb.size = res;
+ }
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ if (err) {
+ return lua_zstd_push_error (L, err);
+ }
+
+ lua_new_text (L, onb.dst, onb.pos, TRUE);
+
+ return 1;
+}
+
+static gint
+lua_load_zstd (lua_State * L)
+{
+ lua_newtable (L);
+ luaL_register (L, NULL, zstd_compress_lib_f);
+
+ return 1;
+}
+
+void
+luaopen_compress (lua_State *L)
+{
+ rspamd_lua_new_class (L, "rspamd{zstd_compress}", zstd_compress_lib_m);
+ rspamd_lua_new_class (L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
+ lua_pop (L, 2);
+
+ rspamd_lua_add_preload (L, "rspamd_zstd", lua_load_zstd);
}
\ No newline at end of file