]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Lua_tensor: Add transpose and mean methods
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Aug 2020 14:14:58 +0000 (15:14 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Aug 2020 14:42:12 +0000 (15:42 +0100)
src/lua/lua_tensor.c

index d14ec88313620ea22660937d426b9e7b0a416f80..09a10cabcee9d65baf9b76784f29db695ce60cfd 100644 (file)
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (tensor, newindex);
 LUA_FUNCTION_DEF (tensor, len);
 LUA_FUNCTION_DEF (tensor, eugen);
 LUA_FUNCTION_DEF (tensor, mean);
+LUA_FUNCTION_DEF (tensor, transpose);
 
 static luaL_reg rspamd_tensor_f[] = {
                LUA_INTERFACE_DEF (tensor, load),
@@ -57,6 +58,7 @@ static luaL_reg rspamd_tensor_m[] = {
                {"__len", lua_tensor_len},
                LUA_INTERFACE_DEF (tensor, eugen),
                LUA_INTERFACE_DEF (tensor, mean),
+               LUA_INTERFACE_DEF (tensor, transpose),
                {NULL, NULL},
 };
 
@@ -625,6 +627,92 @@ lua_tensor_eugen (lua_State *L)
        return 1;
 }
 
+static inline rspamd_tensor_num_t
+mean_vec (rspamd_tensor_num_t *x, int n)
+{
+       rspamd_tensor_num_t s = 0;
+       rspamd_tensor_num_t c = 0;
+
+       for (int i = 0; i < n; i ++) {
+               rspamd_tensor_num_t v = x[i];
+               rspamd_tensor_num_t y = v - c;
+               rspamd_tensor_num_t t = s + y;
+               c = (t - s) - y;
+               s = t;
+       }
+
+       return s / (rspamd_tensor_num_t)n;
+}
+
+static gint
+lua_tensor_mean (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+       if (t) {
+               if (t->ndims == 1) {
+                       /* Mean of all elements in a vector */
+                       lua_pushnumber (L, mean_vec (t->data, t->dim[0]));
+               }
+               else {
+                       /* Row-wise mean vector output */
+                       struct rspamd_lua_tensor *res;
+
+                       res = lua_newtensor (L, 1, &t->dim[0], false, true);
+
+                       for (int i = 0; i < t->dim[0]; i ++) {
+                               res->data[i] = mean_vec (&t->data[i * t->dim[1]], t->dim[1]);
+                       }
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
+static gint
+lua_tensor_transpose (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+       int dims[2];
+
+       if (t) {
+               if (t->ndims == 1) {
+                       /* Row to column */
+                       dims[0] = 1;
+                       dims[1] = t->dim[0];
+                       res = lua_newtensor (L, 2, dims, false, true);
+                       memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t));
+               }
+               else {
+                       /* Cache friendly algorithm */
+                       struct rspamd_lua_tensor *res;
+
+                       dims[0] = t->dim[1];
+                       dims[1] = t->dim[0];
+                       res = lua_newtensor (L, 2, dims, false, true);
+
+                       static const int block = 32;
+
+                       for (int i = 0; i < t->dim[0]; i += block) {
+                               for(int j = 0; j < t->dim[1]; ++j) {
+                                       for(int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) {
+                                               res->data[j * t->dim[0] + i + boff] =
+                                                               t->data[(i + boff) * t->dim[1] + j];
+                                       }
+                               }
+                       }
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
 static gint
 lua_load_tensor (lua_State * L)
 {