diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-06-30 09:40:58 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-06-30 09:40:58 +0100 |
commit | 083e6ac5ce374e1e9759c7998dd04b9525333eb4 (patch) | |
tree | fd421968912f3627c501ed6d70ee1bdd4a39a9d2 | |
parent | 95edae6494dac4acf6ab19714a45339e515b8c49 (diff) | |
download | rspamd-083e6ac5ce374e1e9759c7998dd04b9525333eb4.tar.gz rspamd-083e6ac5ce374e1e9759c7998dd04b9525333eb4.zip |
[Project] Add simple forward propagation function
-rw-r--r-- | contrib/kann/kann.h | 5 | ||||
-rw-r--r-- | src/lua/lua_kann.c | 51 |
2 files changed, 48 insertions, 8 deletions
diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h index 7ec748561..af0de5fba 100644 --- a/contrib/kann/kann.h +++ b/contrib/kann/kann.h @@ -220,7 +220,10 @@ kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_n kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag); /* operations on network with a single input node and a single output node */ -int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y); +typedef void (*kann_train_cb)(int iter, float train_cost, float val_cost, void *ud); +int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, + int max_drop_streak, float frac_val, int n, + float **_x, float **_y, kann_train_cb cb, void *ud); float kann_cost_fnn1(kann_t *a, int n, float **x, float **y); const float *kann_apply1(kann_t *a, float *x); diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index 3d50cc587..a1b31014d 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -143,13 +143,13 @@ static luaL_reg rspamd_kann_new_f[] = { LUA_FUNCTION_DEF (kann, load); LUA_FUNCTION_DEF (kann, destroy); LUA_FUNCTION_DEF (kann, save); -LUA_FUNCTION_DEF (kann, train); -LUA_FUNCTION_DEF (kann, forward); +LUA_FUNCTION_DEF (kann, train1); +LUA_FUNCTION_DEF (kann, apply1); static luaL_reg rspamd_kann_m[] = { LUA_INTERFACE_DEF (kann, save), - LUA_INTERFACE_DEF (kann, train), - LUA_INTERFACE_DEF (kann, forward), + LUA_INTERFACE_DEF (kann, train1), + LUA_INTERFACE_DEF (kann, apply1), {"__gc", lua_kann_destroy}, {NULL, NULL}, }; @@ -985,7 +985,7 @@ lua_kann_load (lua_State *L) } static int -lua_kann_train (lua_State *L) +lua_kann_train1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); @@ -993,9 +993,46 @@ lua_kann_train (lua_State *L) } static int -lua_kann_forward (lua_State *L) +lua_kann_apply1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); - g_assert_not_reached (); /* TODO: implement */ + if (k && lua_istable (L, 2)) { + gsize vec_len = rspamd_lua_table_size (L, 2); + float *vec = (float *)g_malloc (sizeof (float) * vec_len); + int i_out; + + for (gsize i = 0; i < vec_len; i ++) { + lua_rawgeti (L, 2, i + 1); + vec[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + i_out = kann_find (k, KANN_F_OUT, 0); + + if (i_out <= 0) { + g_free (vec); + return luaL_error (L, "invalid ANN: output layer is missing or is " + "at the input pos"); + } + + kann_set_batch_size (k, 1); + kann_feed_bind (k, KANN_F_IN, 0, &vec); + kad_eval_at (k->n, k->v, i_out); + + gsize outlen = kad_len (k->v[i_out]); + lua_createtable (L, outlen, 0); + + for (gsize i = 0; i < outlen; i ++) { + lua_pushnumber (L, k->v[i_out]->x[i]); + lua_rawseti (L, -2, i + 1); + } + + g_free (vec); + } + else { + return luaL_error (L, "invalid arguments: rspamd{kann} expected"); + } + + return 1; }
\ No newline at end of file |