aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/lua/lua_tensor.c88
1 files changed, 88 insertions, 0 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index d14ec8831..09a10cabc 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -37,6 +37,7 @@ LUA_FUNCTION_DEF (tensor, newindex);
LUA_FUNCTION_DEF (tensor, len);
LUA_FUNCTION_DEF (tensor, eugen);
LUA_FUNCTION_DEF (tensor, mean);
+LUA_FUNCTION_DEF (tensor, transpose);
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
@@ -57,6 +58,7 @@ static luaL_reg rspamd_tensor_m[] = {
{"__len", lua_tensor_len},
LUA_INTERFACE_DEF (tensor, eugen),
LUA_INTERFACE_DEF (tensor, mean),
+ LUA_INTERFACE_DEF (tensor, transpose),
{NULL, NULL},
};
@@ -625,6 +627,92 @@ lua_tensor_eugen (lua_State *L)
return 1;
}
+static inline rspamd_tensor_num_t
+mean_vec (rspamd_tensor_num_t *x, int n)
+{
+ rspamd_tensor_num_t s = 0;
+ rspamd_tensor_num_t c = 0;
+
+ for (int i = 0; i < n; i ++) {
+ rspamd_tensor_num_t v = x[i];
+ rspamd_tensor_num_t y = v - c;
+ rspamd_tensor_num_t t = s + y;
+ c = (t - s) - y;
+ s = t;
+ }
+
+ return s / (rspamd_tensor_num_t)n;
+}
+
+static gint
+lua_tensor_mean (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+ if (t) {
+ if (t->ndims == 1) {
+ /* Mean of all elements in a vector */
+ lua_pushnumber (L, mean_vec (t->data, t->dim[0]));
+ }
+ else {
+ /* Row-wise mean vector output */
+ struct rspamd_lua_tensor *res;
+
+ res = lua_newtensor (L, 1, &t->dim[0], false, true);
+
+ for (int i = 0; i < t->dim[0]; i ++) {
+ res->data[i] = mean_vec (&t->data[i * t->dim[1]], t->dim[1]);
+ }
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ return 1;
+}
+
+static gint
+lua_tensor_transpose (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+ int dims[2];
+
+ if (t) {
+ if (t->ndims == 1) {
+ /* Row to column */
+ dims[0] = 1;
+ dims[1] = t->dim[0];
+ res = lua_newtensor (L, 2, dims, false, true);
+ memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t));
+ }
+ else {
+ /* Cache friendly algorithm */
+ struct rspamd_lua_tensor *res;
+
+ dims[0] = t->dim[1];
+ dims[1] = t->dim[0];
+ res = lua_newtensor (L, 2, dims, false, true);
+
+ static const int block = 32;
+
+ for (int i = 0; i < t->dim[0]; i += block) {
+ for(int j = 0; j < t->dim[1]; ++j) {
+ for(int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) {
+ res->data[j * t->dim[0] + i + boff] =
+ t->data[(i + boff) * t->dim[1] + j];
+ }
+ }
+ }
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ return 1;
+}
+
static gint
lua_load_tensor (lua_State * L)
{