aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lua/lua_tensor.c73
1 files changed, 73 insertions, 0 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index e918188eb..06b7cdffe 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -39,12 +39,14 @@ LUA_FUNCTION_DEF (tensor, eigen);
LUA_FUNCTION_DEF (tensor, mean);
LUA_FUNCTION_DEF (tensor, transpose);
LUA_FUNCTION_DEF (tensor, has_blas);
+LUA_FUNCTION_DEF (tensor, scatter_matrix);
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
LUA_INTERFACE_DEF (tensor, new),
LUA_INTERFACE_DEF (tensor, fromtable),
LUA_INTERFACE_DEF (tensor, has_blas),
+ LUA_INTERFACE_DEF (tensor, scatter_matrix),
{NULL, NULL},
};
@@ -636,6 +638,7 @@ mean_vec (rspamd_tensor_num_t *x, int n)
rspamd_tensor_num_t s = 0;
rspamd_tensor_num_t c = 0;
+ /* https://en.wikipedia.org/wiki/Kahan_summation_algorithm */
for (int i = 0; i < n; i ++) {
rspamd_tensor_num_t v = x[i];
rspamd_tensor_num_t y = v - c;
@@ -729,6 +732,76 @@ lua_tensor_has_blas (lua_State *L)
}
static gint
+lua_tensor_scatter_matrix (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+ int dims[2];
+
+ if (t) {
+ if (t->ndims != 2) {
+ return luaL_error (L, "matrix required");
+ }
+
+ /* X * X square matrix */
+ dims[0] = t->dim[1];
+ dims[1] = t->dim[1];
+ res = lua_newtensor (L, 2, dims, true, true);
+
+ /* Auxiliary vars */
+ rspamd_tensor_num_t *means, /* means vector */
+ *tmp_row, /* temp row for Kahan's algorithm */
+ *tmp_square /* temp matrix for multiplications */;
+ means = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+ tmp_row = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+ tmp_square = g_malloc (sizeof (rspamd_tensor_num_t) * t->dim[1] * t->dim[1]);
+
+ /*
+ * Column based means
+ * means will have s, tmp_row will have c
+ */
+ for (int i = 0; i < t->dim[0]; i ++) {
+ /* Cycle by rows */
+ for (int j = 0; j < t->dim[1]; j ++) {
+ rspamd_tensor_num_t v = t->data[i * t->dim[1] + j];
+ rspamd_tensor_num_t y = v - tmp_row[j];
+ rspamd_tensor_num_t st = means[j] + y;
+ tmp_row[j] = (st - means[j]) - y;
+ means[j] = st;
+ }
+ }
+
+ for (int j = 0; j < t->dim[1]; j ++) {
+ means[j] /= t->dim[0];
+ }
+
+ for (int i = 0; i < t->dim[0]; i ++) {
+ /* Update for each sample */
+ for (int j = 0; j < t->dim[1]; j ++) {
+ tmp_row[j] = t->data[i * t->dim[1] + j] - means[j];
+ }
+
+ memset (tmp_square, 0, t->dim[1] * t->dim[1] * sizeof (rspamd_tensor_num_t));
+ kad_sgemm_simple (1, 0, t->dim[1], t->dim[1], 1,
+ tmp_row, tmp_row, tmp_square);
+
+ for (int j = 0; j < t->dim[1]; j ++) {
+ kad_saxpy (t->dim[1], 1.0, &tmp_square[j * t->dim[1]],
+ &res->data[j * t->dim[1]]);
+ }
+ }
+
+ g_free (tmp_row);
+ g_free (means);
+ g_free (tmp_square);
+ }
+ else {
+ return luaL_error (L, "tensor required");
+ }
+
+ return 1;
+}
+
+static gint
lua_load_tensor (lua_State * L)
{
lua_newtable (L);