diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 23:51:38 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 23:51:38 +0100 |
commit | 913ac147bbc4e706095003fb8c16d24e2187a77f (patch) | |
tree | 20d87d95ef07dcb0719e7f9abe2701b130069e37 /src/lua/lua_kann.c | |
parent | 68573e994099396d35181186f4e9d8e0cddbdd53 (diff) | |
download | rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.tar.gz rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.zip |
[Project] Neural: Fix PCA based learning
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r-- | src/lua/lua_kann.c | 88 |
1 files changed, 69 insertions, 19 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index 1827fe1ac..db12e1f87 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -1023,6 +1023,7 @@ static int lua_kann_train1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); + struct rspamd_lua_tensor *pca = NULL; /* Default train params */ double lr = 0.001; @@ -1055,8 +1056,8 @@ lua_kann_train1 (lua_State *L) if (!rspamd_lua_parse_table_arguments (L, 4, &err, RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING, - "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F", - &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) { + "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}", + &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) { n = luaL_error (L, "invalid params: %s", err ? err->message : "unknown error"); g_error_free (err); @@ -1065,36 +1066,83 @@ lua_kann_train1 (lua_State *L) } } - float **x, **y; + if (pca) { + /* Check pca matrix validity */ + if (pca->ndims != 2) { + return luaL_error (L, "invalid pca tensor: matrix expected, got a row"); + } + + if (pca->dim[0] != n_in) { + return luaL_error (L, "invalid pca tensor: " + "matrix must have %d rows and it has %d rows instead", + n_in, pca->dim[0]); + } + } - /* Fill vectors */ + float **x, **y, *tmp_row = NULL; + + /* Fill vectors row by row */ x = (float **)g_malloc0 (sizeof (float *) * n); y = (float **)g_malloc0 (sizeof (float *) * n); + if (pca) { + tmp_row = g_malloc (sizeof (float) * pca->dim[1]); + } + for (int s = 0; s < n; s ++) { /* Inputs */ lua_rawgeti (L, 2, s + 1); x[s] = (float *)g_malloc (sizeof (float) * n_in); - if (rspamd_lua_table_size (L, -1) != n_in) { - FREE_VEC (x, n); - FREE_VEC (y, n); + if (pca == NULL) { + if (rspamd_lua_table_size (L, -1) != n_in) { + FREE_VEC (x, n); + FREE_VEC (y, n); - lua_pop (L, 1); - n = luaL_error (L, "invalid params at pos %d: " - "bad input dimension %d; %d expected", - s + 1, - (int)rspamd_lua_table_size (L, -1), - n_in); + n = luaL_error (L, "invalid params at pos %d: " + "bad input dimension %d; %d expected", + s + 1, + (int) rspamd_lua_table_size (L, -1), + n_in); + lua_pop (L, 1); - return n; + return n; + } + + for (int i = 0; i < n_in; i++) { + lua_rawgeti (L, -1, i + 1); + x[s][i] = lua_tonumber (L, -1); + + lua_pop (L, 1); + } } + else { + if (rspamd_lua_table_size (L, -1) != pca->dim[1]) { + FREE_VEC (x, n); + FREE_VEC (y, n); + g_free (tmp_row); + + n = luaL_error (L, "(pca on) invalid params at pos %d: " + "bad input dimension %d; %d expected", + s + 1, + (int) rspamd_lua_table_size (L, -1), + pca->dim[1]); + lua_pop (L, 1); - for (int i = 0; i < n_in; i ++) { - lua_rawgeti (L, -1, i + 1); - x[s][i] = lua_tonumber (L, -1); + return n; + } - lua_pop (L, 1); + + for (int i = 0; i < pca->dim[1]; i++) { + lua_rawgeti (L, -1, i + 1); + tmp_row[i] = lua_tonumber (L, -1); + + lua_pop (L, 1); + } + + kad_sgemm_simple (0, 0, pca->dim[0], 1, + pca->dim[1], pca->data, + tmp_row, x[s]); } lua_pop (L, 1); @@ -1104,9 +1152,9 @@ lua_kann_train1 (lua_State *L) lua_rawgeti (L, 3, s + 1); if (rspamd_lua_table_size (L, -1) != n_out) { - lua_pop (L, 1); FREE_VEC (x, n); FREE_VEC (y, n); + g_free (tmp_row); n = luaL_error (L, "invalid params at pos %d: " "bad output dimension %d; " @@ -1114,6 +1162,7 @@ lua_kann_train1 (lua_State *L) s + 1, (int)rspamd_lua_table_size (L, -1), n_out); + lua_pop (L, 1); return n; } @@ -1142,6 +1191,7 @@ lua_kann_train1 (lua_State *L) FREE_VEC (x, n); FREE_VEC (y, n); + g_free (tmp_row); } else { return luaL_error (L, "invalid arguments: kann, inputs, outputs and" |