diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-06-30 09:40:44 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-06-30 09:40:44 +0100 |
commit | 95edae6494dac4acf6ab19714a45339e515b8c49 (patch) | |
tree | cd744272ce0518e0ce06f3d5bd7bbde7e7110ec3 /contrib/kann | |
parent | b0dc1504eb3f788c698001a36b320bdd41a5d287 (diff) | |
download | rspamd-95edae6494dac4acf6ab19714a45339e515b8c49.tar.gz rspamd-95edae6494dac4acf6ab19714a45339e515b8c49.zip |
[Project] Support callback function for train
Diffstat (limited to 'contrib/kann')
-rw-r--r-- | contrib/kann/kann.c | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c index 0af15fb2a..3fbf139cc 100644 --- a/contrib/kann/kann.c +++ b/contrib/kann/kann.c @@ -846,7 +846,10 @@ float kann_grad_clip(float thres, int n, float *g) *** @@XY: simpler API for network with a single input/output *** ****************************************************************/ -int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y) +int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, + int max_drop_streak, float frac_val, int n, + float **_x, float **_y, + kann_train_cb cb, void *ud) { int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0; float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c; @@ -907,7 +910,9 @@ int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max n_proc += ms; } if (n_val > 0) val_cost /= n_val; - if (kann_verbose >= 3) { + if (cb) { + cb(i + 1, train_cost, val_cost, ud); +#if 0 fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost); if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train); if (n_val > 0) { @@ -915,6 +920,7 @@ int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val); } fputc('\n', stderr); +#endif } if (i >= max_drop_streak && n_val > 0) { if (val_cost < min_val_cost) { |