]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add zstd streaming API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Apr 2021 14:04:36 +0000 (15:04 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Apr 2021 14:04:36 +0000 (15:04 +0100)
src/lua/lua_common.c
src/lua/lua_compress.c
src/lua/lua_compress.h

index 482245ac979bb21638d1c6ba94bec2fec3d50318..06720c9f243bbb9e8746a83f0725a1d37bedff95 100644 (file)
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include "lua_common.h"
+#include "lua_compress.h"
 #include "lptree.h"
 #include "utlist.h"
 #include "unix-std.h"
@@ -981,6 +982,7 @@ rspamd_lua_init (bool wipe_mem)
        luaopen_spf (L);
        luaopen_tensor (L);
        luaopen_parsers (L);
+       luaopen_compress (L);
 #ifndef WITH_LUAJIT
        rspamd_lua_add_preload (L, "bit", luaopen_bit);
        lua_settop (L, 0);
index 8a2f77d04ba0e1cfb714b31d3bcd8eb3959dc4f8..8d2a7e70b095ae38335bb674315fb4ee51ef5b32 100644 (file)
@@ -17,6 +17,7 @@
 #include "lua_common.h"
 #include "unix-std.h"
 #include "contrib/zstd/zstd.h"
+#include "contrib/zstd/error_public.h"
 #include <zlib.h>
 
 /***
@@ -38,6 +39,54 @@ LUA_FUNCTION_DEF (zstd, compress_ctx);
  */
 LUA_FUNCTION_DEF (zstd, decompress_ctx);
 
+LUA_FUNCTION_DEF (zstd_compress, stream);
+LUA_FUNCTION_DEF (zstd_compress, dtor);
+
+LUA_FUNCTION_DEF (zstd_decompress, stream);
+LUA_FUNCTION_DEF (zstd_decompress, dtor);
+
+static const struct luaL_reg zstd_compress_lib_f[] = {
+               LUA_INTERFACE_DEF (zstd, compress_ctx),
+               LUA_INTERFACE_DEF (zstd, decompress_ctx),
+               {NULL, NULL}
+};
+
+static const struct luaL_reg zstd_compress_lib_m[] = {
+               LUA_INTERFACE_DEF (zstd_compress, stream),
+               {"__gc", lua_zstd_compress_dtor},
+               {NULL, NULL}
+};
+
+static const struct luaL_reg zstd_decompress_lib_m[] = {
+               LUA_INTERFACE_DEF (zstd_decompress, stream),
+               {"__gc", lua_zstd_decompress_dtor},
+               {NULL, NULL}
+};
+
+static ZSTD_CStream *
+lua_check_zstd_compress_ctx (lua_State *L, gint pos)
+{
+       void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_compress}");
+       luaL_argcheck (L, ud != NULL, pos, "'zstd_compress' expected");
+       return ud ? *(ZSTD_CStream **)ud : NULL;
+}
+
+static ZSTD_DStream *
+lua_check_zstd_decompress_ctx (lua_State *L, gint pos)
+{
+       void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_decompress}");
+       luaL_argcheck (L, ud != NULL, pos, "'zstd_decompress' expected");
+       return ud ? *(ZSTD_DStream **)ud : NULL;
+}
+
+int
+lua_zstd_push_error (lua_State *L, int err)
+{
+       lua_pushnil (L);
+       lua_pushfstring (L, "zstd error %d (%s)", err, ZSTD_getErrorString (err));
+
+       return 2;
+}
 
 gint
 lua_compress_zstd_compress (lua_State *L)
@@ -337,4 +386,228 @@ lua_compress_zlib_compress (lua_State *L)
        res->len = strm.total_out;
 
        return 1;
+}
+
+/* Stream API interface for Zstd: both compression and decompression */
+
+/* Operations allowed by zstd stream methods */
+static const char *const zstd_stream_op[] = {
+               "continue",
+               "flush",
+               "end",
+               NULL
+};
+
+static gint
+lua_zstd_compress_ctx (lua_State *L)
+{
+       ZSTD_CCtx *ctx, **pctx;
+
+       pctx = lua_newuserdata (L, sizeof (*pctx));
+       ctx = ZSTD_createCCtx ();
+
+       if (!ctx) {
+               return luaL_error (L, "context create failed");
+       }
+
+       *pctx = ctx;
+       rspamd_lua_setclass (L, "rspamd{zstd_compress}", -1);
+       return 1;
+}
+
+static gint
+lua_zstd_compress_dtor (lua_State *L)
+{
+       ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
+
+       if (ctx) {
+               ZSTD_freeCCtx (ctx);
+       }
+
+       return 0;
+}
+
+static gint
+lua_zstd_compress_reset (lua_State *L)
+{
+       ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
+
+       if (ctx) {
+               ZSTD_CCtx_reset (ctx, ZSTD_reset_session_and_parameters);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 0;
+}
+
+static gint
+lua_zstd_compress_stream (lua_State *L)
+{
+       ZSTD_CStream *ctx = lua_check_zstd_compress_ctx (L, 1);
+       struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
+       int op = luaL_checkoption (L, 3, zstd_stream_op[0], zstd_stream_op);
+       int err = 0;
+       ZSTD_inBuffer inb;
+       ZSTD_outBuffer onb;
+
+       if (ctx && t) {
+               gsize dlen = 0;
+
+               inb.size = t->len;
+               inb.pos = 0;
+               inb.src = (const void*)t->start;
+
+               onb.pos = 0;
+               onb.size = ZSTD_CStreamInSize (); /* Initial guess */
+               onb.dst = NULL;
+
+               for (;;) {
+                       if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
+                               return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
+                       }
+
+                       dlen = onb.size;
+
+                       int res = ZSTD_compressStream2 (ctx, &onb, &inb, op);
+
+                       if (res == 0) {
+                               /* All done */
+                               break;
+                       }
+
+                       if ((err = ZSTD_getErrorCode (res))) {
+                               break;
+                       }
+
+                       onb.size *= 2;
+                       res += dlen; /* Hint returned by compression routine */
+
+                       /* Either double the buffer, or use the hint provided */
+                       if (onb.size < res) {
+                               onb.size = res;
+                       }
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       if (err) {
+               return lua_zstd_push_error (L, err);
+       }
+
+       lua_new_text (L, onb.dst, onb.pos, TRUE);
+
+       return 1;
+}
+
+static gint
+lua_zstd_decompress_dtor (lua_State *L)
+{
+       ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
+
+       if (ctx) {
+               ZSTD_freeDStream (ctx);
+       }
+
+       return 0;
+}
+
+
+static gint
+lua_zstd_decompress_ctx (lua_State *L)
+{
+       ZSTD_DStream *ctx, **pctx;
+
+       pctx = lua_newuserdata (L, sizeof (*pctx));
+       ctx = ZSTD_createDStream ();
+
+       if (!ctx) {
+               return luaL_error (L, "context create failed");
+       }
+
+       *pctx = ctx;
+       rspamd_lua_setclass (L, "rspamd{zstd_decompress}", -1);
+       return 1;
+}
+
+static gint
+lua_zstd_decompress_stream (lua_State *L)
+{
+       ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
+       struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
+       int err = 0;
+       ZSTD_inBuffer inb;
+       ZSTD_outBuffer onb;
+
+       if (ctx && t) {
+               gsize dlen = 0;
+
+               inb.size = t->len;
+               inb.pos = 0;
+               inb.src = (const void*)t->start;
+
+               onb.pos = 0;
+               onb.size = ZSTD_DStreamInSize (); /* Initial guess */
+               onb.dst = NULL;
+
+               for (;;) {
+                       if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
+                               return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
+                       }
+
+                       dlen = onb.size;
+
+                       int res = ZSTD_decompressStream (ctx, &onb, &inb);
+
+                       if (res == 0) {
+                               /* All done */
+                               break;
+                       }
+
+                       if ((err = ZSTD_getErrorCode (res))) {
+                               break;
+                       }
+
+                       onb.size *= 2;
+                       res += dlen; /* Hint returned by compression routine */
+
+                       /* Either double the buffer, or use the hint provided */
+                       if (onb.size < res) {
+                               onb.size = res;
+                       }
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       if (err) {
+               return lua_zstd_push_error (L, err);
+       }
+
+       lua_new_text (L, onb.dst, onb.pos, TRUE);
+
+       return 1;
+}
+
+static gint
+lua_load_zstd (lua_State * L)
+{
+       lua_newtable (L);
+       luaL_register (L, NULL, zstd_compress_lib_f);
+
+       return 1;
+}
+
+void
+luaopen_compress (lua_State *L)
+{
+       rspamd_lua_new_class (L, "rspamd{zstd_compress}", zstd_compress_lib_m);
+       rspamd_lua_new_class (L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
+       lua_pop (L, 2);
+
+       rspamd_lua_add_preload (L, "rspamd_zstd", lua_load_zstd);
 }
\ No newline at end of file
index 0d4ee13f563480f057db8495b3a6f476cb1748e8..7ac8d1a66eb67b0faa42985325cfee3714d8f3ba 100644 (file)
@@ -28,6 +28,8 @@ gint lua_compress_zstd_decompress (lua_State *L);
 gint lua_compress_zlib_compress (lua_State *L);
 gint lua_compress_zlib_decompress (lua_State *L, bool is_gzip);
 
+void luaopen_compress (lua_State *L);
+
 #ifdef  __cplusplus
 }
 #endif