]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add ZSTD compression to Lua API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 10:27:49 +0000 (11:27 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 10:27:49 +0000 (11:27 +0100)
src/lua/lua_util.c

index 4ce7cbe64c61ce64b00e3eb53fe68a2523c22039..daa7cb57d957fe3c96886b0ab395bc02d4336f4b 100644 (file)
@@ -21,6 +21,7 @@
 #include "tokenizers/tokenizers.h"
 #include "libserver/url.h"
 #include "unix-std.h"
+#include "contrib/zstd/zstd.h"
 #include <math.h>
 #include <glob.h>
 
@@ -333,6 +334,24 @@ LUA_FUNCTION_DEF (util, close_file);
  */
 LUA_FUNCTION_DEF (util, random_hex);
 
+/***
+ * @function util.zstd_compress(data)
+ * Compresses input using zstd compression
+ *
+ * @param {string/rspamd_text} data input data
+ * @return {rspamd_text} compressed data
+ */
+LUA_FUNCTION_DEF (util, zstd_compress);
+
+/***
+ * @function util.zstd_decompress(data)
+ * Decompresses input using zstd algorithm
+ *
+ * @param {string/rspamd_text} data compressed data
+ * @return {error,rspamd_text} pair of error + decompressed text
+ */
+LUA_FUNCTION_DEF (util, zstd_decompress);
+
 /***
  * @function util.pack(fmt, ...)
  *
@@ -442,6 +461,8 @@ static const struct luaL_reg utillib_f[] = {
        LUA_INTERFACE_DEF (util, create_file),
        LUA_INTERFACE_DEF (util, close_file),
        LUA_INTERFACE_DEF (util, random_hex),
+       LUA_INTERFACE_DEF (util, zstd_compress),
+       LUA_INTERFACE_DEF (util, zstd_decompress),
        LUA_INTERFACE_DEF (util, pack),
        LUA_INTERFACE_DEF (util, unpack),
        LUA_INTERFACE_DEF (util, packsize),
@@ -1584,6 +1605,125 @@ lua_util_random_hex (lua_State *L)
        return 1;
 }
 
+static gint
+lua_util_zstd_compress (lua_State *L)
+{
+       struct rspamd_lua_text *t = NULL, *res;
+       gsize sz, r;
+
+       if (lua_type (L, 1) == LUA_TSTRING) {
+               t = g_alloca (sizeof (*t));
+               t->start = lua_tolstring (L, 1, &sz);
+               t->len = sz;
+       }
+       else {
+               t = lua_check_text (L, 1);
+       }
+
+       if (t == NULL || t->start == NULL) {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       sz = ZSTD_compressBound (t->len);
+
+       if (ZSTD_isError (sz)) {
+               msg_err ("cannot compress data: %s", ZSTD_getErrorName (sz));
+               lua_pushnil (L);
+
+               return 1;
+       }
+
+       res = lua_newuserdata (L, sizeof (*res));
+       res->start = g_malloc (sz);
+       res->flags = RSPAMD_TEXT_FLAG_OWN;
+       rspamd_lua_setclass (L, "rspamd{text}", -1);
+       r = ZSTD_compress ((void *)res->start, sz, t->start, t->len, 1);
+
+       if (ZSTD_isError (r)) {
+               msg_err ("cannot compress data: %s", ZSTD_getErrorName (r));
+               lua_pop (L, 1); /* Text will be freed here */
+               lua_pushnil (L);
+
+               return 1;
+       }
+
+       t->len = r;
+
+       return 1;
+}
+
+static gint
+lua_util_zstd_decompress (lua_State *L)
+{
+       struct rspamd_lua_text *t = NULL, *res;
+       gsize outlen, sz, r;
+       ZSTD_DStream *zstream;
+       ZSTD_inBuffer zin;
+       ZSTD_outBuffer zout;
+       gchar *out;
+
+       if (lua_type (L, 1) == LUA_TSTRING) {
+               t = g_alloca (sizeof (*t));
+               t->start = lua_tolstring (L, 1, &sz);
+               t->len = sz;
+       }
+       else {
+               t = lua_check_text (L, 1);
+       }
+
+       if (t == NULL || t->start == NULL) {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       zstream = ZSTD_createDStream ();
+       ZSTD_initDStream (zstream);
+
+       zin.pos = 0;
+       zin.src = t->start;
+       zin.size = t->len;
+
+       if ((outlen = ZSTD_getDecompressedSize (zin.src, zin.size)) == 0) {
+               outlen = ZSTD_DStreamOutSize ();
+       }
+
+       out = g_malloc (outlen);
+
+       zout.dst = out;
+       zout.pos = 0;
+       zout.size = outlen;
+
+       while (zin.pos < zin.size) {
+               r = ZSTD_decompressStream (zstream, &zout, &zin);
+
+               if (ZSTD_isError (r)) {
+                       msg_err ("cannot decompress data: %s", ZSTD_getErrorName (r));
+                       ZSTD_freeDStream (zstream);
+                       g_free (out);
+                       lua_pushstring (L, ZSTD_getErrorName (r));
+                       lua_pushnil (L);
+
+                       return 2;
+               }
+
+               if (zout.pos == zout.size) {
+                       /* We need to extend output buffer */
+                       zout.size = zout.size * 1.5 + 1.0;
+                       out = g_realloc (zout.dst, zout.size);
+                       zout.dst = out;
+               }
+       }
+
+       ZSTD_freeDStream (zstream);
+       lua_pushnil (L); /* Error */
+       res = lua_newuserdata (L, sizeof (*res));
+       res->start = out;
+       res->flags = RSPAMD_TEXT_FLAG_OWN;
+       rspamd_lua_setclass (L, "rspamd{text}", -1);
+       res->len = zout.pos;
+
+       return 2;
+}
+
 /* Backport from Lua 5.3 */
 
 /******************************************************************************