diff options
-rw-r--r-- | src/lua/lua_tensor.c | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index d14ec8831..09a10cabc 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (tensor, newindex); LUA_FUNCTION_DEF (tensor, len); LUA_FUNCTION_DEF (tensor, eugen); LUA_FUNCTION_DEF (tensor, mean); +LUA_FUNCTION_DEF (tensor, transpose); static luaL_reg rspamd_tensor_f[] = { LUA_INTERFACE_DEF (tensor, load), @@ -57,6 +58,7 @@ static luaL_reg rspamd_tensor_m[] = { {"__len", lua_tensor_len}, LUA_INTERFACE_DEF (tensor, eugen), LUA_INTERFACE_DEF (tensor, mean), + LUA_INTERFACE_DEF (tensor, transpose), {NULL, NULL}, }; @@ -625,6 +627,92 @@ lua_tensor_eugen (lua_State *L) return 1; } +static inline rspamd_tensor_num_t +mean_vec (rspamd_tensor_num_t *x, int n) +{ + rspamd_tensor_num_t s = 0; + rspamd_tensor_num_t c = 0; + + for (int i = 0; i < n; i ++) { + rspamd_tensor_num_t v = x[i]; + rspamd_tensor_num_t y = v - c; + rspamd_tensor_num_t t = s + y; + c = (t - s) - y; + s = t; + } + + return s / (rspamd_tensor_num_t)n; +} + +static gint +lua_tensor_mean (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1); + + if (t) { + if (t->ndims == 1) { + /* Mean of all elements in a vector */ + lua_pushnumber (L, mean_vec (t->data, t->dim[0])); + } + else { + /* Row-wise mean vector output */ + struct rspamd_lua_tensor *res; + + res = lua_newtensor (L, 1, &t->dim[0], false, true); + + for (int i = 0; i < t->dim[0]; i ++) { + res->data[i] = mean_vec (&t->data[i * t->dim[1]], t->dim[1]); + } + } + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_tensor_transpose (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res; + int dims[2]; + + if (t) { + if (t->ndims == 1) { + /* Row to column */ + dims[0] = 1; + dims[1] = t->dim[0]; + res = lua_newtensor (L, 2, dims, false, true); + memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t)); + } + else { + /* Cache friendly algorithm */ + struct rspamd_lua_tensor *res; + + dims[0] = t->dim[1]; + dims[1] = t->dim[0]; + res = lua_newtensor (L, 2, dims, false, true); + + static const int block = 32; + + for (int i = 0; i < t->dim[0]; i += block) { + for(int j = 0; j < t->dim[1]; ++j) { + for(int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) { + res->data[j * t->dim[0] + i + boff] = + t->data[(i + boff) * t->dim[1] + j]; + } + } + } + } + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 1; +} + static gint lua_load_tensor (lua_State * L) { |