aboutsummaryrefslogtreecommitdiffstats
path: root/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'contrib')
-rw-r--r--contrib/kann/kann.c12
1 files changed, 10 insertions, 2 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c
index 3fbf139cc..43227bdc6 100644
--- a/contrib/kann/kann.c
+++ b/contrib/kann/kann.c
@@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return
kad_node_t *kann_layer_input(int n1)
{
kad_node_t *t;
- t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN;
+ t = kad_feed(2, 1, n1);
+ t->ext_flag |= KANN_F_IN;
return t;
}
@@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
t = kann_layer_dense(t, n_out);
truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
+
if (cost_type == KANN_C_MSE) {
cost = kad_mse(t, truth);
} else if (cost_type == KANN_C_CEB) {
@@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
t = kad_softmax(t);
cost = kad_ce_multi(t, truth);
}
- t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST;
+ else {
+ assert (0);
+ }
+
+ t->ext_flag |= KANN_F_OUT;
+ cost->ext_flag |= KANN_F_COST;
+
return cost;
}