diff options
Diffstat (limited to 'contrib')
-rw-r--r-- | contrib/kann/kann.c | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c index 3fbf139cc..43227bdc6 100644 --- a/contrib/kann/kann.c +++ b/contrib/kann/kann.c @@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kad_node_t *kann_layer_input(int n1) { kad_node_t *t; - t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN; + t = kad_feed(2, 1, n1); + t->ext_flag |= KANN_F_IN; return t; } @@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE); t = kann_layer_dense(t, n_out); truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH; + if (cost_type == KANN_C_MSE) { cost = kad_mse(t, truth); } else if (cost_type == KANN_C_CEB) { @@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) t = kad_softmax(t); cost = kad_ce_multi(t, truth); } - t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST; + else { + assert (0); + } + + t->ext_flag |= KANN_F_OUT; + cost->ext_flag |= KANN_F_COST; + return cost; } |