aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_tensor.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-25 14:01:23 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-25 15:42:12 +0100
commitc9cd7aa05126042b80ea651699413b11deb9de8f (patch)
treea7f7b634b6cb40c8b288b957e699cf578219641c /src/lua/lua_tensor.c
parentbdb91fe0c6d7cba2fcf4e3a022c60c29b0082061 (diff)
downloadrspamd-c9cd7aa05126042b80ea651699413b11deb9de8f.tar.gz
rspamd-c9cd7aa05126042b80ea651699413b11deb9de8f.zip
[Minor] Lua_tensor: Add eugen method
Diffstat (limited to 'src/lua/lua_tensor.c')
-rw-r--r--src/lua/lua_tensor.c39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 4fcc7e201..d14ec8831 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -35,6 +35,8 @@ LUA_FUNCTION_DEF (tensor, tostring);
LUA_FUNCTION_DEF (tensor, index);
LUA_FUNCTION_DEF (tensor, newindex);
LUA_FUNCTION_DEF (tensor, len);
+LUA_FUNCTION_DEF (tensor, eugen);
+LUA_FUNCTION_DEF (tensor, mean);
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
@@ -53,6 +55,8 @@ static luaL_reg rspamd_tensor_m[] = {
{"__index", lua_tensor_index},
{"__newindex", lua_tensor_newindex},
{"__len", lua_tensor_len},
+ LUA_INTERFACE_DEF (tensor, eugen),
+ LUA_INTERFACE_DEF (tensor, mean),
{NULL, NULL},
};
@@ -465,10 +469,20 @@ lua_tensor_mul (lua_State *L)
return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
}
+ else if (shadow_dims[0] == 0) {
+ /* Row * Column -> matrix */
+ shadow_dims[0] = 1;
+ shadow_dims[1] = 1;
+ }
if (dims[0] == 0) {
/* Column */
dims[0] = 1;
+
+ if (dims[1] == 0) {
+ /* Column * row -> number */
+ dims[1] = 1;
+ }
res = lua_newtensor (L, 2, dims, true, true);
}
else if (dims[1] == 0) {
@@ -587,6 +601,31 @@ lua_tensor_len (lua_State *L)
}
static gint
+lua_tensor_eugen (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *eugen;
+
+ if (t) {
+ if (t->ndims != 2 || t->dim[0] != t->dim[1]) {
+ return luaL_error (L, "expected square matrix NxN but got %dx%d",
+ t->dim[0], t->dim[1]);
+ }
+
+ eugen = lua_newtensor (L, 1, &t->dim[0], true, true);
+
+ if (!kad_ssyev_simple (t->dim[0], t->data, eugen->data)) {
+ lua_pop (L, 1);
+ return luaL_error (L, "kad_ssyev_simple failed (no blas?)");
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ return 1;
+}
+
+static gint
lua_load_tensor (lua_State * L)
{
lua_newtable (L);