[Project] Tensor: Move scatter matrix calculation to C
author    Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 31 Aug 2020 14:34:44 +0000 (15:34 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 31 Aug 2020 14:34:44 +0000 (15:34 +0100)
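
The new scatter_matrix method computes the scatter matrix of a samples matrix X
(rows are samples, columns are features): S = sum_i (x_i - mu) (x_i - mu)^T,
where mu is the vector of per-column means. A minimal naive C sketch of that
definition, for reference only (the names, float type and row-major layout are
assumptions, not rspamd API):

#include <stdlib.h>

/*
 * Naive reference for the scatter matrix of an n x d row-major matrix x.
 * s must point to d * d zero-initialised floats.
 * Illustrative sketch only, not the rspamd implementation.
 */
static void
scatter_matrix_ref (const float *x, int n, int d, float *s)
{
	float *mu = calloc (d, sizeof (float));

	/* Per-column means */
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < d; j++) {
			mu[j] += x[i * d + j];
		}
	}
	for (int j = 0; j < d; j++) {
		mu[j] /= n;
	}

	/* Accumulate (x_i - mu)(x_i - mu)^T over all rows */
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < d; j++) {
			for (int k = 0; k < d; k++) {
				s[j * d + k] += (x[i * d + j] - mu[j]) * (x[i * d + k] - mu[k]);
			}
		}
	}

	free (mu);
}

The committed code below computes the same quantity, but uses Kahan summation
for the column means and BLAS-style kernels (kad_sgemm_simple/kad_saxpy) for
the per-sample outer products.
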
src/lua/lua_tensor.c

index e918188ebecca727d12a2d4b8f363820338ef8d6..06b7cdffe5b1c5fbbb908bab26fb8b7fb10387ff 100644
@@ -39,12 +39,14 @@ LUA_FUNCTION_DEF (tensor, eigen);
 LUA_FUNCTION_DEF (tensor, mean);
 LUA_FUNCTION_DEF (tensor, transpose);
 LUA_FUNCTION_DEF (tensor, has_blas);
+LUA_FUNCTION_DEF (tensor, scatter_matrix);
 
 static luaL_reg rspamd_tensor_f[] = {
                LUA_INTERFACE_DEF (tensor, load),
                LUA_INTERFACE_DEF (tensor, new),
                LUA_INTERFACE_DEF (tensor, fromtable),
                LUA_INTERFACE_DEF (tensor, has_blas),
+               LUA_INTERFACE_DEF (tensor, scatter_matrix),
                {NULL, NULL},
 };
 
@@ -636,6 +638,7 @@ mean_vec (rspamd_tensor_num_t *x, int n)
        rspamd_tensor_num_t s = 0;
        rspamd_tensor_num_t c = 0;
 
+       /* https://en.wikipedia.org/wiki/Kahan_summation_algorithm */
        for (int i = 0; i < n; i ++) {
                rspamd_tensor_num_t v = x[i];
                rspamd_tensor_num_t y = v - c;
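
For readers unfamiliar with the algorithm the comment links to, a minimal
standalone sketch of compensated (Kahan) summation in plain C (illustrative
only, not rspamd code):

/*
 * The running compensation c recovers the low-order bits that are lost
 * when a small value is added to a much larger running sum.
 */
static float
kahan_sum (const float *x, int n)
{
	float s = 0.0f, c = 0.0f;

	for (int i = 0; i < n; i++) {
		float y = x[i] - c;  /* apply the stored compensation */
		float t = s + y;     /* low-order bits of y may be lost here */
		c = (t - s) - y;     /* recover the lost part for the next step */
		s = t;
	}

	return s;
}
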
@@ -728,6 +731,76 @@ lua_tensor_has_blas (lua_State *L)
        return 1;
 }
 
+static gint
+lua_tensor_scatter_matrix (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
+       int dims[2];
+
+       if (t) {
+               if (t->ndims != 2) {
+                       return luaL_error (L, "matrix required");
+               }
+
+               /* Result is a dim[1] x dim[1] (features x features) square matrix */
+               dims[0] = t->dim[1];
+               dims[1] = t->dim[1];
+               res = lua_newtensor (L, 2, dims, true, true);
+
+               /* Auxiliary vars */
+               rspamd_tensor_num_t *means, /* means vector */
+                       *tmp_row, /* temp row for Kahan's algorithm */
+                       *tmp_square /* temp matrix for multiplications */;
+               means = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+               tmp_row = g_malloc0 (sizeof (rspamd_tensor_num_t) * t->dim[1]);
+               tmp_square = g_malloc (sizeof (rspamd_tensor_num_t) * t->dim[1] * t->dim[1]);
+
+               /*
+                * Column-wise means computed with Kahan summation:
+                * means[] accumulates the running sums (s), tmp_row[] holds
+                * the compensation terms (c)
+                */
+               for (int i = 0; i < t->dim[0]; i ++) {
+                       /* Cycle by rows */
+                       for (int j = 0; j < t->dim[1]; j ++) {
+                               rspamd_tensor_num_t v = t->data[i * t->dim[1] + j];
+                               rspamd_tensor_num_t y = v - tmp_row[j];
+                               rspamd_tensor_num_t st = means[j] + y;
+                               tmp_row[j] = (st - means[j]) - y;
+                               means[j] = st;
+                       }
+               }
+
+               for (int j = 0; j < t->dim[1]; j ++) {
+                       means[j] /= t->dim[0];
+               }
+
+               for (int i = 0; i < t->dim[0]; i ++) {
+                       /* Update for each sample */
+                       for (int j = 0; j < t->dim[1]; j ++) {
+                               tmp_row[j] = t->data[i * t->dim[1] + j] - means[j];
+                       }
+
+                       memset (tmp_square, 0, t->dim[1] * t->dim[1] * sizeof (rspamd_tensor_num_t));
+                       kad_sgemm_simple (1, 0, t->dim[1], t->dim[1], 1,
+                                       tmp_row, tmp_row, tmp_square);
+
+                       for (int j = 0; j < t->dim[1]; j ++) {
+                               kad_saxpy (t->dim[1], 1.0, &tmp_square[j * t->dim[1]],
+                                               &res->data[j * t->dim[1]]);
+                       }
+               }
+
+               g_free (tmp_row);
+               g_free (means);
+               g_free (tmp_square);
+       }
+       else {
+               return luaL_error (L, "tensor required");
+       }
+
+       return 1;
+}
+
 static gint
 lua_load_tensor (lua_State * L)
 {
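
In the function added above, kad_sgemm_simple is invoked with trans_A = 1,
trans_B = 0 and K = 1, so each call reduces to the outer product of the centred
row with itself; the kad_saxpy loop then adds that dim[1] x dim[1] block into
the result. A plain-C sketch of the equivalent per-sample update (illustrative
only; the helper name is made up):

/*
 * Equivalent of one kad_sgemm_simple(1, 0, d, d, 1, row, row, out) call
 * followed by the kad_saxpy accumulation: res += row^T * row, where row is
 * the centred sample of length d. Not the rspamd implementation.
 */
static void
accumulate_outer_product (const float *row, int d, float *res)
{
	for (int j = 0; j < d; j++) {
		for (int k = 0; k < d; k++) {
			res[j * d + k] += row[j] * row[k];
		}
	}
}
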