aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_kann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-28 12:43:29 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-28 12:43:29 +0100
commit6d9c4bed090e852b871d74443c8d34c4fa87a56e (patch)
tree78f810bd8fa22d1227b6f5e80fe0dbf86e040eaa /src/lua/lua_kann.c
parent913ac147bbc4e706095003fb8c16d24e2187a77f (diff)
downloadrspamd-6d9c4bed090e852b871d74443c8d34c4fa87a56e.tar.gz
rspamd-6d9c4bed090e852b871d74443c8d34c4fa87a56e.zip
[Project] Neural: Implement PCA on ANN forward
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r--src/lua/lua_kann.c50
1 files changed, 45 insertions, 5 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index db12e1f87..30bff538a 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -1205,21 +1205,48 @@ static int
lua_kann_apply1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
if (k) {
if (lua_istable (L, 2)) {
gsize vec_len = rspamd_lua_table_size (L, 2);
- float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+ float *vec = (float *) g_malloc (sizeof (float) * vec_len),
+ *pca_out = NULL;
int i_out;
int n_in = kann_dim_in (k);
if (n_in <= 0) {
+ g_free (vec);
return luaL_error (L, "invalid inputs count: %d", n_in);
}
- if (n_in != vec_len) {
- return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
- (int) vec_len, n_in);
+ if (lua_isuserdata (L, 3)) {
+ pca = lua_check_tensor (L, 3);
+
+ if (pca) {
+ if (pca->ndims != 2) {
+ g_free (vec);
+ return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ g_free (vec);
+ return luaL_error (L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
+ else {
+ g_free (vec);
+ return luaL_error (L, "invalid params: pca matrix expected");
+ }
+ }
+ else {
+ if (n_in != vec_len) {
+ g_free (vec);
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int) vec_len, n_in);
+ }
}
for (gsize i = 0; i < vec_len; i++) {
@@ -1237,7 +1264,19 @@ lua_kann_apply1 (lua_State *L)
}
kann_set_batch_size (k, 1);
- kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ if (pca) {
+ pca_out = g_malloc (sizeof (float) * n_in);
+
+ kad_sgemm_simple (0, 0, pca->dim[0], 1,
+ pca->dim[1], pca->data,
+ vec, pca_out);
+
+ kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
+ }
+ else {
+ kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ }
+
kad_eval_at (k->n, k->v, i_out);
gsize outlen = kad_len (k->v[i_out]);
@@ -1249,6 +1288,7 @@ lua_kann_apply1 (lua_State *L)
}
g_free (vec);
+ g_free (pca_out);
}
else if (lua_isuserdata (L, 2)) {
struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);