#include "tokenizers/tokenizers.h"
#include "libserver/url.h"
#include "unix-std.h"
+#include "contrib/zstd/zstd.h"
#include <math.h>
#include <glob.h>
*/
LUA_FUNCTION_DEF (util, random_hex);
+/***
+ * @function util.zstd_compress(data)
+ * Compresses input using zstd compression
+ *
+ * @param {string/rspamd_text} data input data
+ * @return {rspamd_text} compressed data
+ */
+LUA_FUNCTION_DEF (util, zstd_compress);
+
+/***
+ * @function util.zstd_decompress(data)
+ * Decompresses input using zstd algorithm
+ *
+ * @param {string/rspamd_text} data compressed data
+ * @return {error,rspamd_text} pair of error + decompressed text
+ */
+LUA_FUNCTION_DEF (util, zstd_decompress);
+
/***
* @function util.pack(fmt, ...)
*
LUA_INTERFACE_DEF (util, create_file),
LUA_INTERFACE_DEF (util, close_file),
LUA_INTERFACE_DEF (util, random_hex),
+ LUA_INTERFACE_DEF (util, zstd_compress),
+ LUA_INTERFACE_DEF (util, zstd_decompress),
LUA_INTERFACE_DEF (util, pack),
LUA_INTERFACE_DEF (util, unpack),
LUA_INTERFACE_DEF (util, packsize),
return 1;
}
+static gint
+lua_util_zstd_compress (lua_State *L)
+{
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize sz, r;
+
+ if (lua_type (L, 1) == LUA_TSTRING) {
+ t = g_alloca (sizeof (*t));
+ t->start = lua_tolstring (L, 1, &sz);
+ t->len = sz;
+ }
+ else {
+ t = lua_check_text (L, 1);
+ }
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ sz = ZSTD_compressBound (t->len);
+
+ if (ZSTD_isError (sz)) {
+ msg_err ("cannot compress data: %s", ZSTD_getErrorName (sz));
+ lua_pushnil (L);
+
+ return 1;
+ }
+
+ res = lua_newuserdata (L, sizeof (*res));
+ res->start = g_malloc (sz);
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass (L, "rspamd{text}", -1);
+ r = ZSTD_compress ((void *)res->start, sz, t->start, t->len, 1);
+
+ if (ZSTD_isError (r)) {
+ msg_err ("cannot compress data: %s", ZSTD_getErrorName (r));
+ lua_pop (L, 1); /* Text will be freed here */
+ lua_pushnil (L);
+
+ return 1;
+ }
+
+ t->len = r;
+
+ return 1;
+}
+
+static gint
+lua_util_zstd_decompress (lua_State *L)
+{
+ struct rspamd_lua_text *t = NULL, *res;
+ gsize outlen, sz, r;
+ ZSTD_DStream *zstream;
+ ZSTD_inBuffer zin;
+ ZSTD_outBuffer zout;
+ gchar *out;
+
+ if (lua_type (L, 1) == LUA_TSTRING) {
+ t = g_alloca (sizeof (*t));
+ t->start = lua_tolstring (L, 1, &sz);
+ t->len = sz;
+ }
+ else {
+ t = lua_check_text (L, 1);
+ }
+
+ if (t == NULL || t->start == NULL) {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ zstream = ZSTD_createDStream ();
+ ZSTD_initDStream (zstream);
+
+ zin.pos = 0;
+ zin.src = t->start;
+ zin.size = t->len;
+
+ if ((outlen = ZSTD_getDecompressedSize (zin.src, zin.size)) == 0) {
+ outlen = ZSTD_DStreamOutSize ();
+ }
+
+ out = g_malloc (outlen);
+
+ zout.dst = out;
+ zout.pos = 0;
+ zout.size = outlen;
+
+ while (zin.pos < zin.size) {
+ r = ZSTD_decompressStream (zstream, &zout, &zin);
+
+ if (ZSTD_isError (r)) {
+ msg_err ("cannot decompress data: %s", ZSTD_getErrorName (r));
+ ZSTD_freeDStream (zstream);
+ g_free (out);
+ lua_pushstring (L, ZSTD_getErrorName (r));
+ lua_pushnil (L);
+
+ return 2;
+ }
+
+ if (zout.pos == zout.size) {
+ /* We need to extend output buffer */
+ zout.size = zout.size * 1.5 + 1.0;
+ out = g_realloc (zout.dst, zout.size);
+ zout.dst = out;
+ }
+ }
+
+ ZSTD_freeDStream (zstream);
+ lua_pushnil (L); /* Error */
+ res = lua_newuserdata (L, sizeof (*res));
+ res->start = out;
+ res->flags = RSPAMD_TEXT_FLAG_OWN;
+ rspamd_lua_setclass (L, "rspamd{text}", -1);
+ res->len = zout.pos;
+
+ return 2;
+}
+
/* Backport from Lua 5.3 */
/******************************************************************************