diff options
-rw-r--r-- | src/lua/lua_tensor.c | 56 |
1 files changed, 54 insertions, 2 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index 9b85779d7..91fcd763e 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -396,10 +396,62 @@ lua_tensor_mul (lua_State *L) static gint lua_tensor_load (lua_State *L) { - struct rspamd_lua_tensor *t = lua_check_tensor (L, 1); + const guchar *data; + gsize sz; - if (t) { + if (lua_type (L, 1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text (L, 1); + + if (!t) { + return luaL_error (L, "invalid argument"); + } + + data = (const guchar *)t->start; + sz = t->len; + } + else { + data = (const guchar *)lua_tolstring (L, 1, &sz); + } + + if (sz >= sizeof (gint) * 4) { + int ndims, nelts, dims[2]; + memcpy (&ndims, data, sizeof (int)); + memcpy (&nelts, data + sizeof (int), sizeof (int)); + memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2); + + if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) { + if (ndims == 1) { + if (nelts == dims[0]) { + struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); + memcpy (t->data, data + sizeof (int) * 4, nelts * + sizeof (rspamd_tensor_num_t)); + } + else { + return luaL_error (L, "invalid argument: bad dims: %d x %d != %d", + dims[0], 1, nelts); + } + } + else if (ndims == 2) { + if (nelts == dims[0] * dims[1]) { + struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); + memcpy (t->data, data + sizeof (int) * 4, nelts * + sizeof (rspamd_tensor_num_t)); + } + else { + return luaL_error (L, "invalid argument: bad dims: %d x %d != %d", + dims[0], dims[1], nelts); + } + } + else { + return luaL_error (L, "invalid argument: bad ndims: %d", ndims); + } + } + else { + return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz, + (int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4), + nelts); + } } else { return luaL_error (L, "invalid arguments"); |