From 414c7b4ff70f4bbe934166709d29fc37389e20be Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 5 Aug 2020 16:05:40 +0100 Subject: [PATCH] [Minor] Add printing and fix multiplication --- src/lua/lua_tensor.c | 71 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index e8aebd180..21bdf9673 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -31,6 +31,7 @@ LUA_FUNCTION_DEF (tensor, new); LUA_FUNCTION_DEF (tensor, fromtable); LUA_FUNCTION_DEF (tensor, destroy); LUA_FUNCTION_DEF (tensor, mul); +LUA_FUNCTION_DEF (tensor, tostring); static luaL_reg rspamd_tensor_f[] = { LUA_INTERFACE_DEF (tensor, load), @@ -44,6 +45,8 @@ static luaL_reg rspamd_tensor_m[] = { {"__gc", lua_tensor_destroy}, {"__mul", lua_tensor_mul}, {"mul", lua_tensor_mul}, + {"__tostring", lua_tensor_tostring}, + {"tostring", lua_tensor_tostring}, {NULL, NULL}, }; @@ -114,12 +117,14 @@ lua_tensor_fromtable (lua_State *L) if (lua_isnumber (L, -1)) { lua_pop (L, 1); /* Input vector */ - gint dim = rspamd_lua_table_size (L, 1); + gint dims[2]; + dims[0] = 1; + dims[1] = rspamd_lua_table_size (L, 1); - struct rspamd_lua_tensor *res = lua_newtensor (L, 1, - &dim, false); + struct rspamd_lua_tensor *res = lua_newtensor (L, 2, + dims, false); - for (guint i = 0; i < dim; i ++) { + for (guint i = 0; i < dims[1]; i ++) { lua_rawgeti (L, 1, i + 1); res->data[i] = lua_tonumber (L, -1); lua_pop (L, 1); @@ -168,8 +173,8 @@ lua_tensor_fromtable (lua_State *L) } gint dims[2]; - dims[0] = ncols; - dims[1] = nrows; + dims[0] = nrows; + dims[1] = ncols; struct rspamd_lua_tensor *res = lua_newtensor (L, 2, dims, false); @@ -238,6 +243,47 @@ lua_tensor_save (lua_State *L) return 1; } +static gint +lua_tensor_tostring (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1); + + if (t) { + GString *out = g_string_sized_new (128); + + if (t->ndims == 1) { + /* Print as a vector */ + for (gint i = 0; i < t->dim[0]; i ++) { + rspamd_printf_gstring (out, "%.4f ", t->data[i]); + } + /* Trim last space */ + out->len --; + } + else { + for (gint i = 0; i < t->dim[0]; i ++) { + for (gint j = 0; j < t->dim[1]; j ++) { + rspamd_printf_gstring (out, "%.4f ", + t->data[i * t->dim[1] + j]); + } + /* Trim last space */ + out->len --; + rspamd_printf_gstring (out, "\n"); + } + /* Trim last ; */ + out->len --; + } + + lua_pushlstring (L, out->str, out->len); + + g_string_free (out, TRUE); + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + /*** * @method tensor:mul(other, [transA, [transB]]) * Multiply two tensors (optionally transposed) and return a new tensor @@ -259,12 +305,19 @@ lua_tensor_mul (lua_State *L) } if (t1 && t2) { - gint dims[2]; + gint dims[2], shadow_dims[2]; dims[0] = transA ? t1->dim[1] : t1->dim[0]; + shadow_dims[0] = transB ? t2->dim[1] : t2->dim[0]; dims[1] = transB ? t2->dim[0] : t2->dim[1]; + shadow_dims[1] = transA ? t1->dim[0] : t1->dim[1]; + + if (shadow_dims[0] != shadow_dims[1]) { + return luaL_error (L, "incompatible dimensions %d x %d * %d x %d", + dims[0], shadow_dims[1], shadow_dims[0], dims[1]); + } - res = lua_newtensor (L, 2, dims, false); - kad_sgemm_simple (transA, transB, t1->dim[1], t2->dim[0], t1->dim[0], + res = lua_newtensor (L, 2, dims, true); + kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0], t1->data, t2->data, res->data); } else { -- 2.39.5