diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lua/lua_tensor.c | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index e918188eb..06b7cdffe 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -39,12 +39,14 @@ LUA_FUNCTION_DEF (tensor, eigen); LUA_FUNCTION_DEF (tensor, mean); LUA_FUNCTION_DEF (tensor, transpose); LUA_FUNCTION_DEF (tensor, has_blas); +LUA_FUNCTION_DEF (tensor, scatter_matrix); static luaL_reg rspamd_tensor_f[] = { LUA_INTERFACE_DEF (tensor, load), LUA_INTERFACE_DEF (tensor, new), LUA_INTERFACE_DEF (tensor, fromtable), LUA_INTERFACE_DEF (tensor, has_blas), + LUA_INTERFACE_DEF (tensor, scatter_matrix), {NULL, NULL}, }; @@ -636,6 +638,7 @@ mean_vec (rspamd_tensor_num_t *x, int n) rspamd_tensor_num_t s = 0; rspamd_tensor_num_t c = 0; + /* https://en.wikipedia.org/wiki/Kahan_summation_algorithm */ for (int i = 0; i < n; i ++) { rspamd_tensor_num_t v = x[i]; rspamd_tensor_num_t y = v - c; @@ -729,6 +732,76 @@ lua_tensor_has_blas (lua_State *L) } static gint +lua_tensor_scatter_matrix (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res; + int dims[2]; + + if (t) { + if (t->ndims != 2) { + return luaL_error (L, "matrix required"); + } + + /* X * X square matrix */ + dims[0] = t->dim[1]; + dims[1] = t->dim[1]; + res = lua_newtensor (L, 2, dims, true, true); + + /* Auxiliary vars */ + rspamd_tensor_num_t *means, /* means vector */ + *tmp_row, /* temp row for Kahan's algorithm */ + *tmp_square /* temp matrix for multiplications */; + means = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]); + tmp_row = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]); + tmp_square = g_malloc (sizeof (rspamd_tensor_num_t) * t->dim[1] * t->dim[1]); + + /* + * Column based means + * means will have s, tmp_row will have c + */ + for (int i = 0; i < t->dim[0]; i ++) { + /* Cycle by rows */ + for (int j = 0; j < t->dim[1]; j ++) { + rspamd_tensor_num_t v = t->data[i * t->dim[1] + j]; + rspamd_tensor_num_t y = v - tmp_row[j]; + rspamd_tensor_num_t st = means[j] + y; + tmp_row[j] = (st - means[j]) - y; + means[j] = st; + } + } + + for (int j = 0; j < t->dim[1]; j ++) { + means[j] /= t->dim[0]; + } + + for (int i = 0; i < t->dim[0]; i ++) { + /* Update for each sample */ + for (int j = 0; j < t->dim[1]; j ++) { + tmp_row[j] = t->data[i * t->dim[1] + j] - means[j]; + } + + memset (tmp_square, 0, t->dim[1] * t->dim[1] * sizeof (rspamd_tensor_num_t)); + kad_sgemm_simple (1, 0, t->dim[1], t->dim[1], 1, + tmp_row, tmp_row, tmp_square); + + for (int j = 0; j < t->dim[1]; j ++) { + kad_saxpy (t->dim[1], 1.0, &tmp_square[j * t->dim[1]], + &res->data[j * t->dim[1]]); + } + } + + g_free (tmp_row); + g_free (means); + g_free (tmp_square); + } + else { + return luaL_error (L, "tensor required"); + } + + return 1; +} + +static gint lua_load_tensor (lua_State * L) { lua_newtable (L); |