source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Implement PCA on ANN forward
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Aug 2020 11:43:29 +0000 (12:43 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Aug 2020 11:43:29 +0000 (12:43 +0100)
src/lua/lua_kann.c

index db12e1f871f91151ea206d83e1ebe58e3727bacb..30bff538ad7f0671ec431033160c5b0acc3580fb 100644 (file)
@@ -1205,21 +1205,48 @@ static int
 lua_kann_apply1 (lua_State *L)
 {
        kann_t *k = lua_check_kann (L, 1);
+       struct rspamd_lua_tensor *pca = NULL;
 
        if (k) {
                if (lua_istable (L, 2)) {
                        gsize vec_len = rspamd_lua_table_size (L, 2);
-                       float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+                       float *vec = (float *) g_malloc (sizeof (float) * vec_len),
+                               *pca_out = NULL;
                        int i_out;
                        int n_in = kann_dim_in (k);
 
                        if (n_in <= 0) {
+                               g_free (vec);
                                return luaL_error (L, "invalid inputs count: %d", n_in);
                        }
 
-                       if (n_in != vec_len) {
-                               return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
-                                               (int) vec_len, n_in);
+                       if (lua_isuserdata (L, 3)) {
+                               pca = lua_check_tensor (L, 3);
+
+                               if (pca) {
+                                       if (pca->ndims != 2) {
+                                               g_free (vec);
+                                               return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+                                       }
+
+                                       if (pca->dim[0] != n_in) {
+                                               g_free (vec);
+                                               return luaL_error (L, "invalid pca tensor: "
+                                                                                         "matrix must have %d rows and it has %d rows instead",
+                                                               n_in, pca->dim[0]);
+                                       }
+                               }
+                               else {
+                                       g_free (vec);
+                                       return luaL_error (L, "invalid params: pca matrix expected");
+                               }
+                       }
+                       else {
+                               if (n_in != vec_len) {
+                                       g_free (vec);
+                                       return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+                                                       (int) vec_len, n_in);
+                               }
                        }
 
                        for (gsize i = 0; i < vec_len; i++) {
@@ -1237,7 +1264,19 @@ lua_kann_apply1 (lua_State *L)
                        }
 
                        kann_set_batch_size (k, 1);
-                       kann_feed_bind (k, KANN_F_IN, 0, &vec);
+                       if (pca) {
+                               pca_out = g_malloc (sizeof (float) * n_in);
+
+                               kad_sgemm_simple (0, 0, pca->dim[0], 1,
+                                               pca->dim[1], pca->data,
+                                               vec, pca_out);
+
+                               kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
+                       }
+                       else {
+                               kann_feed_bind (k, KANN_F_IN, 0, &vec);
+                       }
+
                        kad_eval_at (k->n, k->v, i_out);
 
                        gsize outlen = kad_len (k->v[i_out]);
@@ -1249,6 +1288,7 @@ lua_kann_apply1 (lua_State *L)
                        }
 
                        g_free (vec);
+                       g_free (pca_out);
                }
                else if (lua_isuserdata (L, 2)) {
                        struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);