From c9cd7aa05126042b80ea651699413b11deb9de8f Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 25 Aug 2020 14:01:23 +0100 Subject: [PATCH] [Minor] Lua_tensor: Add eugen method --- src/lua/lua_tensor.c | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index 4fcc7e201..d14ec8831 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -35,6 +35,8 @@ LUA_FUNCTION_DEF (tensor, tostring); LUA_FUNCTION_DEF (tensor, index); LUA_FUNCTION_DEF (tensor, newindex); LUA_FUNCTION_DEF (tensor, len); +LUA_FUNCTION_DEF (tensor, eugen); +LUA_FUNCTION_DEF (tensor, mean); static luaL_reg rspamd_tensor_f[] = { LUA_INTERFACE_DEF (tensor, load), @@ -53,6 +55,8 @@ static luaL_reg rspamd_tensor_m[] = { {"__index", lua_tensor_index}, {"__newindex", lua_tensor_newindex}, {"__len", lua_tensor_len}, + LUA_INTERFACE_DEF (tensor, eugen), + LUA_INTERFACE_DEF (tensor, mean), {NULL, NULL}, }; @@ -465,10 +469,20 @@ lua_tensor_mul (lua_State *L) return luaL_error (L, "incompatible dimensions %d x %d * %d x %d", dims[0], shadow_dims[1], shadow_dims[0], dims[1]); } + else if (shadow_dims[0] == 0) { + /* Row * Column -> matrix */ + shadow_dims[0] = 1; + shadow_dims[1] = 1; + } if (dims[0] == 0) { /* Column */ dims[0] = 1; + + if (dims[1] == 0) { + /* Column * row -> number */ + dims[1] = 1; + } res = lua_newtensor (L, 2, dims, true, true); } else if (dims[1] == 0) { @@ -586,6 +600,31 @@ lua_tensor_len (lua_State *L) return nret; } +static gint +lua_tensor_eugen (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *eugen; + + if (t) { + if (t->ndims != 2 || t->dim[0] != t->dim[1]) { + return luaL_error (L, "expected square matrix NxN but got %dx%d", + t->dim[0], t->dim[1]); + } + + eugen = lua_newtensor (L, 1, &t->dim[0], true, true); + + if (!kad_ssyev_simple (t->dim[0], t->data, eugen->data)) { + lua_pop (L, 1); + return luaL_error (L, "kad_ssyev_simple failed (no blas?)"); + } + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + static gint lua_load_tensor (lua_State * L) { -- 2.39.5