author    Vsevolod Stakhov <vsevolod@highsecure.ru>    2019-06-27 15:38:34 +0100
committer Vsevolod Stakhov <vsevolod@highsecure.ru>    2019-06-27 15:38:34 +0100
commit    44e393f9fe9a86bd99ebc5cfcddfe8eb50c8813e (patch)
tree      acbea1a0149b064c8de14502e11777192a012b44 /contrib/kann/kann.h
parent    b066f62bfafd0e3dc0ebb181e0990fede4f32d47 (diff)
[Project] Add kann library to start torch removal
Diffstat (limited to 'contrib/kann/kann.h')
-rw-r--r--  contrib/kann/kann.h  235
1 file changed, 235 insertions(+), 0 deletions(-)
diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h
new file mode 100644
index 000000000..1605e5ea5
--- /dev/null
+++ b/contrib/kann/kann.h
@@ -0,0 +1,235 @@
+/*
+ The MIT License
+
+ Copyright (c) 2018-2019 Dana-Farber Cancer Institute
+ 2016-2018 Broad Institute
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+*/
+
+#ifndef KANN_H
+#define KANN_H
+
+#define KANN_VERSION "r536"
+
+#define KANN_F_IN 0x1 /* input */
+#define KANN_F_OUT 0x2 /* output */
+#define KANN_F_TRUTH 0x4 /* truth output */
+#define KANN_F_COST 0x8 /* final cost */
+
+#define KANN_C_CEB 1 /* binary cross-entropy cost, used with sigmoid */
+#define KANN_C_CEM 2 /* multi-class cross-entropy cost, used with softmax */
+#define KANN_C_CEB_NEG 3 /* binary cross-entropy-like cost, used with tanh */
+#define KANN_C_MSE 4 /* mean square error */
+
+#define KANN_RNN_VAR_H0 0x1 /* take the initial hidden values as variables */
+#define KANN_RNN_NORM 0x2 /* apply layer normalization */
+
+#include "kautodiff.h"
+
+typedef struct {
+ int n; /* number of nodes in the computational graph */
+ kad_node_t **v; /* list of nodes */
+ float *x, *g, *c; /* collated variable values, gradients and constant values */
+ void *mt; /* auxiliary data for multi-threading; NULL if multi-threading disabled */
+} kann_t;
+
+extern int kann_verbose;
+
+#define kann_size_var(a) kad_size_var((a)->n, (a)->v)
+#define kann_size_const(a) kad_size_const((a)->n, (a)->v)
+#define kann_dim_in(a) kann_feed_dim((a), KANN_F_IN, 0)
+#define kann_dim_out(a) kann_feed_dim((a), KANN_F_TRUTH, 0)
+#define kann_srand(seed) kad_srand(0, (seed))
+#define kann_drand() kad_drand(0)
+#define kann_set_batch_size(ann, B) kad_sync_dim((ann)->n, (ann)->v, (B))
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Generate a network from a computational graph
+ *
+ * A network must have at least one scalar cost node (i.e. whose n_d==0). It
+ * may optionally contain other cost nodes or output nodes not leading to the
+ * primary cost node.
+ *
+ * @param cost cost node (must be a scalar, i.e. cost->n_d==0)
+ * @param n_rest number of other nodes without predecessors
+ * @param ... other nodes (of type kad_node_t*) without predecessors
+ *
+ * @return network on success, or NULL otherwise
+ */
+kann_t *kann_new(kad_node_t *cost, int n_rest, ...);
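+
+/*
+ * Minimal usage sketch (illustrative, not part of the upstream docs):
+ * build a one-hidden-layer classifier. Assumes kad_relu() from
+ * kautodiff.h; the dimensions are arbitrary.
+ *
+ *   kad_node_t *t = kann_layer_input(784);    // 784-d input feed
+ *   t = kad_relu(kann_layer_dense(t, 64));    // dense layer + ReLU
+ *   t = kann_layer_cost(t, 10, KANN_C_CEM);   // softmax + multi-class CE
+ *   kann_t *ann = kann_new(t, 0);             // cost node, no extra roots
+ */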
+
+/**
+ * Unroll an RNN
+ *
+ * @param a network
+ * @param ... unroll length(s), one int per pivot node
+ *
+ * @return an unrolled network, or NULL if the network is not an RNN
+ */
+kann_t *kann_unroll(kann_t *a, ...);
+
+kann_t *kann_unroll_array(kann_t *a, int *len);
+kann_t *kann_clone(kann_t *a, int batch_size);
+void kann_delete(kann_t *a); /* delete a network generated by kann_new() or kann_layer_cost() */
+void kann_delete_unrolled(kann_t *a); /* delete a network generated by kann_unroll() */
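+
+/*
+ * Sketch: unroll a recurrent network to 25 time steps before training
+ * (one unroll length per pivot node; a plain RNN/LSTM/GRU stack has one).
+ *
+ *   kann_t *ua = kann_unroll(ann, 25);
+ *   ...  // train on ua
+ *   kann_delete_unrolled(ua);  // frees the unrolled copy, keeps ann
+ */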
+
+/**
+ * Enable/disable multi-threading (requires pthread)
+ *
+ * KANN splits a mini-batch into n_threads smaller batches and assigns each
+ * of them to one thread. So far, only kann_cost() takes advantage of
+ * multi-threading.
+ *
+ * @param ann network
+ * @param n_threads number of threads; <=1 to completely disable multi-threading
+ * @param max_batch_size max mini-batch size; must be no smaller than n_threads
+ */
+void kann_mt(kann_t *ann, int n_threads, int max_batch_size);
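+
+/*
+ * Sketch: use up to 4 threads for mini-batches of up to 64 samples.
+ *
+ *   kann_mt(ann, 4, 64);
+ */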
+
+/**
+ * Bind float arrays to feed nodes
+ *
+ * @param a network
+ * @param ext_flag required external flags
+ * @param ext_label required external label
+ * @param x array of float pointers, one per matching feed node
+ *
+ * @return number of matching feed nodes
+ */
+int kann_feed_bind(kann_t *a, uint32_t ext_flag, int32_t ext_label, float **x);
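+
+/*
+ * Sketch: bind one input and one truth array; this example assumes each
+ * array holds batch_size rows of the node's per-sample size, laid out
+ * contiguously.
+ *
+ *   float *x1 = ..., *y1 = ...;
+ *   kann_feed_bind(ann, KANN_F_IN,    0, &x1);
+ *   kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);
+ */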
+
+/**
+ * Compute the cost and optionally gradients
+ *
+ * @param a network
+ * @param cost_label required external label
+ * @param cal_grad whether to compute gradients
+ *
+ * @return cost
+ */
+float kann_cost(kann_t *a, int cost_label, int cal_grad);
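+
+/*
+ * Sketch of one training step, assuming input and truth arrays have been
+ * bound as above and r was calloc()-ed to kann_size_var(ann) floats.
+ *
+ *   kann_set_batch_size(ann, B);
+ *   float cost = kann_cost(ann, 0, 1);  // forward + backward pass
+ *   kann_RMSprop(kann_size_var(ann), 0.001f, 0, 0.9f, ann->g, ann->x, r);
+ */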
+
+int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label);
+int kann_eval_out(kann_t *a);
+int kann_class_error(const kann_t *ann, int *base);
+
+/**
+ * Find a node
+ *
+ * @param a network
+ * @param ext_flag required external flags; set to 0 to match all flags
+ * @param ext_label required external label
+ *
+ * @return node index (>=0) if found; -1 if not found; -2 if multiple nodes match
+ */
+int kann_find(const kann_t *a, uint32_t ext_flag, int32_t ext_label);
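+
+/*
+ * Sketch: locate the cost node with label 0 and read its scalar value
+ * after a forward pass.
+ *
+ *   int i_cost = kann_find(ann, KANN_F_COST, 0);
+ *   if (i_cost >= 0) printf("cost: %g\n", ann->v[i_cost]->x[0]);
+ */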
+
+/**
+ * Get the size of a feed node, assuming mini-batch size 1
+ *
+ * @param a network
+ * @param ext_flag required external flags
+ * @param ext_label required external label
+ *
+ * @return size (>=0) on success; -1 if not found; -2 if multiple nodes match
+ */
+int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label);
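+
+/*
+ * Sketch: size an input buffer from the network itself (kann_dim_in() is
+ * the convenience macro defined above).
+ *
+ *   int n_in = kann_feed_dim(ann, KANN_F_IN, 0);  // same as kann_dim_in(ann)
+ *   float *x1 = (float*)calloc(n_in, sizeof(float));
+ */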
+
+/**
+ * Get an RNN ready for continuous feeding
+ *
+ * @param a network
+ */
+void kann_rnn_start(kann_t *a);
+
+void kann_rnn_end(kann_t *a);
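+
+/*
+ * Sketch: feed an RNN one step at a time, carrying hidden state across
+ * calls; kann_apply1() is declared further below.
+ *
+ *   kann_rnn_start(ann);
+ *   for (i = 0; i < n_steps; ++i) {
+ *       const float *y = kann_apply1(ann, x[i]);
+ *       ...  // consume y
+ *   }
+ *   kann_rnn_end(ann);
+ */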
+
+/**
+ * Switch between training and prediction networks (effective only when there are switch nodes)
+ *
+ * @param a network
+ * @param is_train 0 for the prediction network, non-zero for the training network
+ */
+void kann_switch(kann_t *a, int is_train);
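+
+/*
+ * Sketch: dropout (and other switch nodes) behave differently in the two
+ * modes.
+ *
+ *   kann_switch(ann, 1);  // training: dropout active
+ *   ...
+ *   kann_switch(ann, 0);  // prediction: dropout bypassed
+ */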
+
+/**
+ * RMSprop update
+ *
+ * @param n number of variables
+ * @param h0 learning rate
+ * @param h per-variable learning rate; NULL if not applicable
+ * @param decay RMSprop decay; use 0.9 if unsure
+ * @param g gradient, of size n
+ * @param t variables to change
+ * @param r per-variable moving average of squared gradients (state), of size n
+ */
+void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r);
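+
+/*
+ * Sketch: clip gradients before the update (kann_grad_clip() is declared
+ * below); n_var is kann_size_var(ann) and r is the RMSprop state array.
+ *
+ *   kann_grad_clip(10.0f, n_var, ann->g);
+ *   kann_RMSprop(n_var, 0.001f, 0, 0.9f, ann->g, ann->x, r);
+ */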
+
+void kann_shuffle(int n, int *s);
+float kann_grad_clip(float thres, int n, float *g);
+
+/* common layers; a usage sketch follows these declarations */
+kad_node_t *kann_layer_input(int n1);
+kad_node_t *kann_layer_dense(kad_node_t *in, int n1);
+kad_node_t *kann_layer_dropout(kad_node_t *t, float r);
+kad_node_t *kann_layer_layernorm(kad_node_t *in);
+kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag);
+kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag);
+kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag);
+kad_node_t *kann_layer_conv2d(kad_node_t *in, int n_flt, int k_rows, int k_cols, int stride_r, int stride_c, int pad_r, int pad_c);
+kad_node_t *kann_layer_conv1d(kad_node_t *in, int n_flt, int k_size, int stride, int pad);
+kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type);
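+
+/*
+ * Sketch: a tiny convolutional stack. conv2d expects 4-d input
+ * (batch, channel, height, width), so the feed is created directly with
+ * kad_feed() from kautodiff.h rather than with kann_layer_input().
+ *
+ *   kad_node_t *t = kad_feed(4, 1, 1, 28, 28);
+ *   t->ext_flag |= KANN_F_IN;
+ *   t = kad_relu(kann_layer_conv2d(t, 32, 3, 3, 1, 1, 0, 0));  // 32 3x3 filters
+ *   t = kann_layer_cost(t, 10, KANN_C_CEM);
+ *   kann_t *ann = kann_new(t, 0);
+ */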
+
+kad_node_t *kann_new_leaf(uint8_t flag, float x0_01, int n_d, ...); /* flag can be KAD_CONST or KAD_VAR */
+kad_node_t *kann_new_scalar(uint8_t flag, float x);
+kad_node_t *kann_new_weight(int n_row, int n_col);
+kad_node_t *kann_new_bias(int n);
+kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col);
+kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len);
+
+kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...);
+kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1);
+kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r);
+kad_node_t *kann_layer_layernorm2(int *offset, kad_node_t **par, kad_node_t *in);
+kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag);
+kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag);
+
+/* operations on a network with a single input node and a single output node; sketch below */
+int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y);
+float kann_cost_fnn1(kann_t *a, int n, float **x, float **y);
+const float *kann_apply1(kann_t *a, float *x);
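+
+/*
+ * Sketch: train on n samples held as arrays of row pointers, then run a
+ * single sample through the trained net. The hyper-parameters here are
+ * arbitrary.
+ *
+ *   kann_train_fnn1(ann, 0.001f, 64, 50, 10, 0.1f, n, x, y);
+ *   const float *y1 = kann_apply1(ann, x[0]);
+ */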
+
+/* model I/O */
+void kann_save_fp(FILE *fp, kann_t *ann);
+void kann_save(const char *fn, kann_t *ann);
+kann_t *kann_load_fp(FILE *fp);
+kann_t *kann_load(const char *fn);
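+
+/*
+ * Sketch: round-trip a model through a file ("model.kan" is a
+ * hypothetical filename).
+ *
+ *   kann_save("model.kan", ann);
+ *   kann_t *ann2 = kann_load("model.kan");
+ */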
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif