From 6d9c4bed090e852b871d74443c8d34c4fa87a56e Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 28 Aug 2020 12:43:29 +0100 Subject: [PATCH] [Project] Neural: Implement PCA on ANN forward --- src/lua/lua_kann.c | 50 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index db12e1f87..30bff538a 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -1205,21 +1205,48 @@ static int lua_kann_apply1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); + struct rspamd_lua_tensor *pca = NULL; if (k) { if (lua_istable (L, 2)) { gsize vec_len = rspamd_lua_table_size (L, 2); - float *vec = (float *) g_malloc (sizeof (float) * vec_len); + float *vec = (float *) g_malloc (sizeof (float) * vec_len), + *pca_out = NULL; int i_out; int n_in = kann_dim_in (k); if (n_in <= 0) { + g_free (vec); return luaL_error (L, "invalid inputs count: %d", n_in); } - if (n_in != vec_len) { - return luaL_error (L, "invalid params: bad input dimension %d; %d expected", - (int) vec_len, n_in); + if (lua_isuserdata (L, 3)) { + pca = lua_check_tensor (L, 3); + + if (pca) { + if (pca->ndims != 2) { + g_free (vec); + return luaL_error (L, "invalid pca tensor: matrix expected, got a row"); + } + + if (pca->dim[0] != n_in) { + g_free (vec); + return luaL_error (L, "invalid pca tensor: " + "matrix must have %d rows and it has %d rows instead", + n_in, pca->dim[0]); + } + } + else { + g_free (vec); + return luaL_error (L, "invalid params: pca matrix expected"); + } + } + else { + if (n_in != vec_len) { + g_free (vec); + return luaL_error (L, "invalid params: bad input dimension %d; %d expected", + (int) vec_len, n_in); + } } for (gsize i = 0; i < vec_len; i++) { @@ -1237,7 +1264,19 @@ lua_kann_apply1 (lua_State *L) } kann_set_batch_size (k, 1); - kann_feed_bind (k, KANN_F_IN, 0, &vec); + if (pca) { + pca_out = g_malloc (sizeof (float) * n_in); + + kad_sgemm_simple (0, 0, pca->dim[0], 1, + pca->dim[1], pca->data, + vec, pca_out); + + kann_feed_bind (k, KANN_F_IN, 0, &pca_out); + } + else { + kann_feed_bind (k, KANN_F_IN, 0, &vec); + } + kad_eval_at (k->n, k->v, i_out); gsize outlen = kad_len (k->v[i_out]); @@ -1249,6 +1288,7 @@ lua_kann_apply1 (lua_State *L) } g_free (vec); + g_free (pca_out); } else if (lua_isuserdata (L, 2)) { struct rspamd_lua_tensor *t = lua_check_tensor (L, 2); -- 2.39.5