From 083e6ac5ce374e1e9759c7998dd04b9525333eb4 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sun, 30 Jun 2019 09:40:58 +0100 Subject: [Project] Add simple forward propagation function --- contrib/kann/kann.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'contrib/kann/kann.h') diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h index 7ec748561..af0de5fba 100644 --- a/contrib/kann/kann.h +++ b/contrib/kann/kann.h @@ -220,7 +220,10 @@ kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_n kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag); /* operations on network with a single input node and a single output node */ -int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y); +typedef void (*kann_train_cb)(int iter, float train_cost, float val_cost, void *ud); +int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, + int max_drop_streak, float frac_val, int n, + float **_x, float **_y, kann_train_cb cb, void *ud); float kann_cost_fnn1(kann_t *a, int n, float **x, float **y); const float *kann_apply1(kann_t *a, float *x); -- cgit v1.2.3