]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Lua_tensor: Add eugen method
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Aug 2020 13:01:23 +0000 (14:01 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Aug 2020 14:42:12 +0000 (15:42 +0100)
src/lua/lua_tensor.c

index 4fcc7e2011a837e7edff7ff7c6c5583547696c2b..d14ec88313620ea22660937d426b9e7b0a416f80 100644 (file)
@@ -35,6 +35,8 @@ LUA_FUNCTION_DEF (tensor, tostring);
 LUA_FUNCTION_DEF (tensor, index);
 LUA_FUNCTION_DEF (tensor, newindex);
 LUA_FUNCTION_DEF (tensor, len);
+LUA_FUNCTION_DEF (tensor, eugen);
+LUA_FUNCTION_DEF (tensor, mean);
 
 static luaL_reg rspamd_tensor_f[] = {
                LUA_INTERFACE_DEF (tensor, load),
@@ -53,6 +55,8 @@ static luaL_reg rspamd_tensor_m[] = {
                {"__index", lua_tensor_index},
                {"__newindex", lua_tensor_newindex},
                {"__len", lua_tensor_len},
+               LUA_INTERFACE_DEF (tensor, eugen),
+               LUA_INTERFACE_DEF (tensor, mean),
                {NULL, NULL},
 };
 
@@ -465,10 +469,20 @@ lua_tensor_mul (lua_State *L)
                        return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
                                        dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
                }
+               else if (shadow_dims[0] == 0) {
+                       /* Row * Column -> matrix */
+                       shadow_dims[0] = 1;
+                       shadow_dims[1] = 1;
+               }
 
                if (dims[0] == 0) {
                        /* Column */
                        dims[0] = 1;
+
+                       if (dims[1] == 0) {
+                               /* Column * row -> number */
+                               dims[1] = 1;
+                       }
                        res = lua_newtensor (L, 2, dims, true, true);
                }
                else if (dims[1] == 0) {
@@ -586,6 +600,31 @@ lua_tensor_len (lua_State *L)
        return nret;
 }
 
+static gint
+lua_tensor_eugen (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *eugen;
+
+       if (t) {
+               if (t->ndims != 2 || t->dim[0] != t->dim[1]) {
+                       return luaL_error (L, "expected square matrix NxN but got %dx%d",
+                                       t->dim[0], t->dim[1]);
+               }
+
+               eugen = lua_newtensor (L, 1, &t->dim[0], true, true);
+
+               if (!kad_ssyev_simple (t->dim[0], t->data, eugen->data)) {
+                       lua_pop (L, 1);
+                       return luaL_error (L, "kad_ssyev_simple failed (no blas?)");
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
 static gint
 lua_load_tensor (lua_State * L)
 {