diff options
-rw-r--r-- | contrib/kann/kann.c | 12 | ||||
-rw-r--r-- | src/lua/lua_kann.c | 170 |
2 files changed, 178 insertions, 4 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c index 3fbf139cc..43227bdc6 100644 --- a/contrib/kann/kann.c +++ b/contrib/kann/kann.c @@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kad_node_t *kann_layer_input(int n1) { kad_node_t *t; - t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN; + t = kad_feed(2, 1, n1); + t->ext_flag |= KANN_F_IN; return t; } @@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE); t = kann_layer_dense(t, n_out); truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH; + if (cost_type == KANN_C_MSE) { cost = kad_mse(t, truth); } else if (cost_type == KANN_C_CEB) { @@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) t = kad_softmax(t); cost = kad_ce_multi(t, truth); } - t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST; + else { + assert (0); + } + + t->ext_flag |= KANN_F_OUT; + cost->ext_flag |= KANN_F_COST; + return cost; } diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index a1b31014d..609f05539 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -295,7 +295,7 @@ void luaopen_kann (lua_State *L) int fl = 0; \ if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \ else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \ - (n)->ext_flag = fl; \ + (n)->ext_flag |= fl; \ }while(0) /*** @@ -984,12 +984,168 @@ lua_kann_load (lua_State *L) return 1; } +struct rspamd_kann_train_cbdata { + lua_State *L; + kann_t *k; + gint cbref; +}; + +static void +lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud) +{ + struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud; + + if (cbd->cbref != -1) { + gint err_idx; + lua_State *L = cbd->L; + + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + + lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref); + lua_pushinteger (L, iter); + lua_pushnumber (L, train_cost); + lua_pushnumber (L, val_cost); + + if (lua_pcall (L, 3, 0, err_idx) != 0) { + msg_err ("cannot run lua train callback: %s", + lua_tostring (L, -1)); + } + + lua_settop (L, err_idx - 1); + } +} + +#define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0) + static int lua_kann_train1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); - g_assert_not_reached (); /* TODO: implement */ + /* Default train params */ + double lr = 0.001; + gint64 mini_size = 64; + gint64 max_epoch = 25; + gint64 max_drop_streak = 10; + double frac_val = 0.1; + gint cbref = -1; + + if (k && lua_istable (L, 2) && lua_istable (L, 3)) { + int n = rspamd_lua_table_size (L, 2); + int n_in = kann_dim_in (k); + int n_out = kann_dim_out (k); + + if (n_in <= 0) { + return luaL_error (L, "invalid inputs count: %d", n_in); + } + + if (n_out <= 0) { + return luaL_error (L, "invalid outputs count: %d", n_in); + } + + if (n != rspamd_lua_table_size (L, 3) || n == 0) { + return luaL_error (L, "invalid dimensions: outputs size must be " + "equal to inputs and non zero"); + } + + if (lua_istable (L, 4)) { + GError *err = NULL; + + if (!rspamd_lua_parse_table_arguments (L, 4, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING, + "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F", + &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) { + n = luaL_error (L, "invalid params: %s", + err ? err->message : "unknown error"); + g_error_free (err); + + return n; + } + } + + float **x, **y; + + /* Fill vectors */ + x = (float **)g_malloc (sizeof (float *) * n); + y = (float **)g_malloc (sizeof (float *) * n); + + for (int s = 0; s < n; s ++) { + /* Inputs */ + lua_rawgeti (L, 2, s + 1); + x[s] = (float *)g_malloc (sizeof (float) * n_in); + + if (rspamd_lua_table_size (L, -1) != n_in) { + FREE_VEC (x, n); + FREE_VEC (y, n); + + n = luaL_error (L, "invalid params at pos %d: " + "bad input dimension %d; %d expected", + s + 1, + (int)rspamd_lua_table_size (L, -1), + n_in); + + return n; + } + + for (int i = 0; i < n_in; i ++) { + lua_rawgeti (L, -1, i + 1); + x[s][i] = lua_tonumber (L, -1); + + lua_pop (L, 1); + } + + lua_pop (L, 1); + + /* Outputs */ + y[s] = (float *)g_malloc (sizeof (float) * n_out); + lua_rawgeti (L, 3, s + 1); + + if (rspamd_lua_table_size (L, -1) != n_out) { + FREE_VEC (x, n); + FREE_VEC (y, n); + + n = luaL_error (L, "invalid params at pos %d: " + "bad output dimension %d; " + "%d expected", + s + 1, + (int)rspamd_lua_table_size (L, -1), + n_out); + + return n; + } + + for (int i = 0; i < n_out; i ++) { + lua_rawgeti (L, -1, i + 1); + y[s][i] = lua_tonumber (L, -1); + + lua_pop (L, 1); + } + + lua_pop (L, 1); + } + + struct rspamd_kann_train_cbdata cbd; + + cbd.cbref = cbref; + cbd.k = k; + cbd.L = L; + + int niters = kann_train_fnn1 (k, lr, + mini_size, max_epoch, max_drop_streak, + frac_val, n, x, y, lua_kann_train_cb, &cbd); + + lua_pushinteger (L, niters); + + FREE_VEC (x, n); + FREE_VEC (y, n); + } + else { + return luaL_error (L, "invalid arguments: kann, inputs, outputs and" + " optional params are expected"); + } + + return 1; } static int @@ -1001,6 +1157,16 @@ lua_kann_apply1 (lua_State *L) gsize vec_len = rspamd_lua_table_size (L, 2); float *vec = (float *)g_malloc (sizeof (float) * vec_len); int i_out; + int n_in = kann_dim_in (k); + + if (n_in <= 0) { + return luaL_error (L, "invalid inputs count: %d", n_in); + } + + if (n_in != vec_len) { + return luaL_error (L, "invalid params: bad input dimension %d; %d expected", + (int)vec_len, n_in); + } for (gsize i = 0; i < vec_len; i ++) { lua_rawgeti (L, 2, i + 1); |