aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_kann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-27 23:51:38 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-27 23:51:38 +0100
commit913ac147bbc4e706095003fb8c16d24e2187a77f (patch)
tree20d87d95ef07dcb0719e7f9abe2701b130069e37 /src/lua/lua_kann.c
parent68573e994099396d35181186f4e9d8e0cddbdd53 (diff)
downloadrspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.tar.gz
rspamd-913ac147bbc4e706095003fb8c16d24e2187a77f.zip
[Project] Neural: Fix PCA based learning
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r--src/lua/lua_kann.c88
1 files changed, 69 insertions, 19 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 1827fe1ac..db12e1f87 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -1023,6 +1023,7 @@ static int
lua_kann_train1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
/* Default train params */
double lr = 0.001;
@@ -1055,8 +1056,8 @@ lua_kann_train1 (lua_State *L)
if (!rspamd_lua_parse_table_arguments (L, 4, &err,
RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
- "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
- &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+ "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
+ &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
n = luaL_error (L, "invalid params: %s",
err ? err->message : "unknown error");
g_error_free (err);
@@ -1065,36 +1066,83 @@ lua_kann_train1 (lua_State *L)
}
}
- float **x, **y;
+ if (pca) {
+ /* Check pca matrix validity */
+ if (pca->ndims != 2) {
+ return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ return luaL_error (L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
- /* Fill vectors */
+ float **x, **y, *tmp_row = NULL;
+
+ /* Fill vectors row by row */
x = (float **)g_malloc0 (sizeof (float *) * n);
y = (float **)g_malloc0 (sizeof (float *) * n);
+ if (pca) {
+ tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
+ }
+
for (int s = 0; s < n; s ++) {
/* Inputs */
lua_rawgeti (L, 2, s + 1);
x[s] = (float *)g_malloc (sizeof (float) * n_in);
- if (rspamd_lua_table_size (L, -1) != n_in) {
- FREE_VEC (x, n);
- FREE_VEC (y, n);
+ if (pca == NULL) {
+ if (rspamd_lua_table_size (L, -1) != n_in) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
- lua_pop (L, 1);
- n = luaL_error (L, "invalid params at pos %d: "
- "bad input dimension %d; %d expected",
- s + 1,
- (int)rspamd_lua_table_size (L, -1),
- n_in);
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size (L, -1),
+ n_in);
+ lua_pop (L, 1);
- return n;
+ return n;
+ }
+
+ for (int i = 0; i < n_in; i++) {
+ lua_rawgeti (L, -1, i + 1);
+ x[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
}
+ else {
+ if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+ g_free (tmp_row);
+
+ n = luaL_error (L, "(pca on) invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size (L, -1),
+ pca->dim[1]);
+ lua_pop (L, 1);
- for (int i = 0; i < n_in; i ++) {
- lua_rawgeti (L, -1, i + 1);
- x[s][i] = lua_tonumber (L, -1);
+ return n;
+ }
- lua_pop (L, 1);
+
+ for (int i = 0; i < pca->dim[1]; i++) {
+ lua_rawgeti (L, -1, i + 1);
+ tmp_row[i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ kad_sgemm_simple (0, 0, pca->dim[0], 1,
+ pca->dim[1], pca->data,
+ tmp_row, x[s]);
}
lua_pop (L, 1);
@@ -1104,9 +1152,9 @@ lua_kann_train1 (lua_State *L)
lua_rawgeti (L, 3, s + 1);
if (rspamd_lua_table_size (L, -1) != n_out) {
- lua_pop (L, 1);
FREE_VEC (x, n);
FREE_VEC (y, n);
+ g_free (tmp_row);
n = luaL_error (L, "invalid params at pos %d: "
"bad output dimension %d; "
@@ -1114,6 +1162,7 @@ lua_kann_train1 (lua_State *L)
s + 1,
(int)rspamd_lua_table_size (L, -1),
n_out);
+ lua_pop (L, 1);
return n;
}
@@ -1142,6 +1191,7 @@ lua_kann_train1 (lua_State *L)
FREE_VEC (x, n);
FREE_VEC (y, n);
+ g_free (tmp_row);
}
else {
return luaL_error (L, "invalid arguments: kann, inputs, outputs and"