diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-01 13:30:09 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-01 13:30:09 +0100 |
commit | c5ef059e0d0ea41b8a490c2f838a819e1363d0dd (patch) | |
tree | b0fa00b6b23cac588d58e759341826473e6c615a /contrib/kann | |
parent | a2a3df8b76cbe92bb56f82ae25362e4d0d440f19 (diff) | |
download | rspamd-c5ef059e0d0ea41b8a490c2f838a819e1363d0dd.tar.gz rspamd-c5ef059e0d0ea41b8a490c2f838a819e1363d0dd.zip |
[Project] Add training support to kann
Diffstat (limited to 'contrib/kann')
-rw-r--r-- | contrib/kann/kann.c | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c index 3fbf139cc..43227bdc6 100644 --- a/contrib/kann/kann.c +++ b/contrib/kann/kann.c @@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kad_node_t *kann_layer_input(int n1) { kad_node_t *t; - t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN; + t = kad_feed(2, 1, n1); + t->ext_flag |= KANN_F_IN; return t; } @@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE); t = kann_layer_dense(t, n_out); truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH; + if (cost_type == KANN_C_MSE) { cost = kad_mse(t, truth); } else if (cost_type == KANN_C_CEB) { @@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type) t = kad_softmax(t); cost = kad_ce_multi(t, truth); } - t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST; + else { + assert (0); + } + + t->ext_flag |= KANN_F_OUT; + cost->ext_flag |= KANN_F_COST; + return cost; } |