aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/kann
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-01 13:30:09 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-01 13:30:09 +0100
commitc5ef059e0d0ea41b8a490c2f838a819e1363d0dd (patch)
treeb0fa00b6b23cac588d58e759341826473e6c615a /contrib/kann
parenta2a3df8b76cbe92bb56f82ae25362e4d0d440f19 (diff)
downloadrspamd-c5ef059e0d0ea41b8a490c2f838a819e1363d0dd.tar.gz
rspamd-c5ef059e0d0ea41b8a490c2f838a819e1363d0dd.zip
[Project] Add training support to kann
Diffstat (limited to 'contrib/kann')
-rw-r--r--contrib/kann/kann.c12
1 files changed, 10 insertions, 2 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c
index 3fbf139cc..43227bdc6 100644
--- a/contrib/kann/kann.c
+++ b/contrib/kann/kann.c
@@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return
kad_node_t *kann_layer_input(int n1)
{
kad_node_t *t;
- t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN;
+ t = kad_feed(2, 1, n1);
+ t->ext_flag |= KANN_F_IN;
return t;
}
@@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
t = kann_layer_dense(t, n_out);
truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
+
if (cost_type == KANN_C_MSE) {
cost = kad_mse(t, truth);
} else if (cost_type == KANN_C_CEB) {
@@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
t = kad_softmax(t);
cost = kad_ce_multi(t, truth);
}
- t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST;
+ else {
+ assert (0);
+ }
+
+ t->ext_flag |= KANN_F_OUT;
+ cost->ext_flag |= KANN_F_COST;
+
return cost;
}