|
|
@@ -846,7 +846,10 @@ float kann_grad_clip(float thres, int n, float *g) |
|
|
|
*** @@XY: simpler API for network with a single input/output *** |
|
|
|
****************************************************************/ |
|
|
|
|
|
|
|
int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y) |
|
|
|
int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, |
|
|
|
int max_drop_streak, float frac_val, int n, |
|
|
|
float **_x, float **_y, |
|
|
|
kann_train_cb cb, void *ud) |
|
|
|
{ |
|
|
|
int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0; |
|
|
|
float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c; |
|
|
@@ -907,7 +910,9 @@ int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max |
|
|
|
n_proc += ms; |
|
|
|
} |
|
|
|
if (n_val > 0) val_cost /= n_val; |
|
|
|
if (kann_verbose >= 3) { |
|
|
|
if (cb) { |
|
|
|
cb(i + 1, train_cost, val_cost, ud); |
|
|
|
#if 0 |
|
|
|
fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost); |
|
|
|
if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train); |
|
|
|
if (n_val > 0) { |
|
|
@@ -915,6 +920,7 @@ int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max |
|
|
|
if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val); |
|
|
|
} |
|
|
|
fputc('\n', stderr); |
|
|
|
#endif |
|
|
|
} |
|
|
|
if (i >= max_drop_streak && n_val > 0) { |
|
|
|
if (val_cost < min_val_cost) { |