aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/kann/kautodiff.h
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/kann/kautodiff.h')
-rw-r--r--contrib/kann/kautodiff.h246
1 files changed, 246 insertions, 0 deletions
diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h
new file mode 100644
index 000000000..a2c648835
--- /dev/null
+++ b/contrib/kann/kautodiff.h
@@ -0,0 +1,246 @@
+/*
+ The MIT License
+
+ Copyright (c) 2018-2019 Dana-Farber Cancer Institute
+ 2016-2018 Broad Institute
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+*/
+
+#ifndef KANN_AUTODIFF_H
+#define KANN_AUTODIFF_H
+
+#define KAD_VERSION "r544"
+
+#include <stdio.h>
+#include <stdint.h>
+
+#ifdef __STRICT_ANSI__ /* strict ISO mode requested; 'inline' may not be available as a keyword (e.g. C89) */
+#define inline /* expands to nothing, so "static inline" in this header degrades to plain "static" */
+#endif
+
+#define KAD_MAX_DIM 4 /* max dimension */
+#define KAD_MAX_OP 64 /* max number of operators */
+
+/* A computational graph is a directed acyclic graph. In the graph, an external
+ * node represents a variable, a constant or a feed; an internal node
+ * represents an operator; an edge from node v to w indicates v is an operand
+ * of w.
+ */
+
+#define KAD_VAR 0x1 /* flag bit tested by kad_is_back(); an external node carrying it is a variable */
+#define KAD_CONST 0x2 /* flag bit marking an external node as a constant (see kad_is_const()) */
+#define KAD_POOL 0x4 /* flag bit tested by kad_is_pivot() and kad_is_switch() */
+#define KAD_SHARE_RNG 0x10 /* when set, all time steps share the same RNG state after unrolling */
+/* Node predicates. The compound ones expand p more than once, so p must be free of side effects. */
+#define kad_is_back(p) ((p)->flag & KAD_VAR) /* nonzero iff KAD_VAR is set; presumably "takes part in back-propagation" - confirm */
+#define kad_is_ext(p) ((p)->n_child == 0) /* external node: has no operands */
+#define kad_is_var(p) (kad_is_ext(p) && kad_is_back(p)) /* external node with KAD_VAR set */
+#define kad_is_const(p) (kad_is_ext(p) && ((p)->flag & KAD_CONST)) /* external node with KAD_CONST set */
+#define kad_is_feed(p) (kad_is_ext(p) && !kad_is_back(p) && !((p)->flag & KAD_CONST)) /* external node with neither KAD_VAR nor KAD_CONST */
+#define kad_is_pivot(p) ((p)->n_child == 1 && ((p)->flag & KAD_POOL)) /* exactly one operand and KAD_POOL set */
+#define kad_is_switch(p) ((p)->op == 12 && !((p)->flag & KAD_POOL)) /* NOTE(review): 12 is a hard-coded kad_op_list index; confirm which operator it names */
+#define kad_use_rng(p) ((p)->op == 15 || (p)->op == 24) /* NOTE(review): 15 and 24 are hard-coded kad_op_list indices; presumably the RNG-driven operators - confirm */
+/* The two macros below only assign p->tmp; presumably this is the mark read by kad_eval_marked() - TODO confirm. */
+#define kad_eval_enable(p) ((p)->tmp = 1) /* sets tmp to 1 */
+#define kad_eval_disable(p) ((p)->tmp = -1) /* sets tmp to -1 */
+
+/* a node in the computational graph */
+typedef struct kad_node_t {
+ uint8_t n_d; /* number of dimensions in d[]; no larger than KAD_MAX_DIM */
+ uint8_t flag; /* bit flags tested by the kad_is_*() macros (KAD_VAR, KAD_CONST, KAD_POOL, ...) */
+ uint16_t op; /* operator index; kad_op_list[op] is the actual function, so op < KAD_MAX_OP */
+ int32_t n_child; /* number of operands/child nodes; 0 for an external node (see kad_is_ext()) */
+ int32_t tmp; /* temporary field; MUST BE zero before calling kad_compile(); kad_eval_enable()/kad_eval_disable() set it to 1/-1 */
+ int32_t ptr_size; /* size of the data ptr points to (unit not shown in this header; presumably bytes) */
+ int32_t d[KAD_MAX_DIM]; /* dimensions; kad_len() multiplies d[0..n_d-1] */
+ int32_t ext_label; /* label for external use (not modified by the kad_* APIs) */
+ uint32_t ext_flag; /* flags for external use (not modified by the kad_* APIs) */
+ float *x; /* value; allocated for internal nodes (kad_var()/kad_const() take x from the caller) */
+ float *g; /* gradient; allocated for internal nodes (kad_var() takes g from the caller) */
+ void *ptr; /* extra parameters for special operators that need them (e.g. conv2d); ptr_size gives its size */
+ void *gtmp; /* temporary data generated in the forward pass and consumed in the backward pass */
+ struct kad_node_t **child; /* operands/child nodes; n_child entries */
+ struct kad_node_t *pre; /* usually NULL; only used for RNN */
+} kad_node_t, *kad_node_p;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Compile/linearize a computational graph
+ *
+ * @param n_node number of nodes (out)
+ * @param n_roots number of nodes without predecessors
+ * @param roots list of nodes without predecessors
+ *
+ * @return list of nodes, of size *n_node
+ */
+kad_node_t **kad_compile_array(int *n_node, int n_roots, kad_node_t **roots);
+
+kad_node_t **kad_compile(int *n_node, int n_roots, ...); /* an alternative API to above */
+void kad_delete(int n, kad_node_t **a); /* deallocate a compiled/linearized graph */
+
+/**
+ * Compute the value at a node
+ *
+ * @param n number of nodes
+ * @param a list of nodes
+ * @param from compute the value at this node, 0<=from<n
+ *
+ * @return a pointer to the value (pointing to kad_node_t::x, so don't call
+ * free() on it!)
+ */
+const float *kad_eval_at(int n, kad_node_t **a, int from);
+
+void kad_eval_marked(int n, kad_node_t **a);
+int kad_sync_dim(int n, kad_node_t **v, int batch_size);
+
+/**
+ * Compute gradient
+ *
+ * @param n number of nodes
+ * @param a list of nodes
+ * @param from the function node; must be a scalar (compute \nabla a[from])
+ */
+void kad_grad(int n, kad_node_t **a, int from);
+
+/**
+ * Unroll a recurrent computation graph
+ *
+ * @param n_v number of nodes
+ * @param v list of nodes
+ * @param new_n number of nodes in the unrolled graph (out)
+ * @param len how many times to unroll, one for each pivot
+ *
+ * @return list of nodes in the unrolled graph
+ */
+kad_node_t **kad_unroll(int n_v, kad_node_t **v, int *new_n, int *len);
+int kad_n_pivots(int n_v, kad_node_t **v);
+
+kad_node_t **kad_clone(int n, kad_node_t **v, int batch_size);
+
+/* define a variable, a constant or a feed (placeholder in TensorFlow) */
+kad_node_t *kad_var(float *x, float *g, int n_d, ...); /* a variable; gradients to be computed; not unrolled */
+kad_node_t *kad_const(float *x, int n_d, ...); /* a constant; no gradients computed; not unrolled */
+kad_node_t *kad_feed(int n_d, ...); /* an input/output; no gradients computed; unrolled */
+
+/* operators taking two operands */
+kad_node_t *kad_add(kad_node_t *x, kad_node_t *y); /* f(x,y) = x + y (generalized element-wise addition; f[i*n+j]=x[i*n+j]+y[j], n=kad_len(y), 0<=j<n, 0<=i<kad_len(x)/n) */
+kad_node_t *kad_sub(kad_node_t *x, kad_node_t *y); /* f(x,y) = x - y (generalized element-wise subtraction) */
+kad_node_t *kad_mul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y (generalized element-wise product) */
+
+kad_node_t *kad_matmul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y (general matrix product) */
+kad_node_t *kad_cmul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y^T (column-wise matrix product; i.e. y is transposed) */
+
+/* loss functions; output scalar */
+kad_node_t *kad_mse(kad_node_t *x, kad_node_t *y); /* mean square error */
+kad_node_t *kad_ce_multi(kad_node_t *x, kad_node_t *y); /* multi-class cross-entropy; x is the prediction and y is the truth */
+kad_node_t *kad_ce_bin(kad_node_t *x, kad_node_t *y); /* binary cross-entropy for (0,1) */
+kad_node_t *kad_ce_bin_neg(kad_node_t *x, kad_node_t *y); /* binary cross-entropy for (-1,1) */
+kad_node_t *kad_ce_multi_weighted(kad_node_t *pred, kad_node_t *truth, kad_node_t *weight); /* presumably kad_ce_multi() with an extra weight operand - confirm in kautodiff.c */
+
+#define KAD_PAD_NONE 0 /* use the smallest zero-padding */
+#define KAD_PAD_SAME (-2) /* output to have the same dimension as input */
+
+kad_node_t *kad_conv2d(kad_node_t *x, kad_node_t *w, int r_stride, int c_stride, int r_pad, int c_pad); /* 2D convolution with weight matrix flipped */
+kad_node_t *kad_max2d(kad_node_t *x, int kernel_h, int kernel_w, int r_stride, int c_stride, int r_pad, int c_pad); /* 2D max pooling */
+kad_node_t *kad_conv1d(kad_node_t *x, kad_node_t *w, int stride, int pad); /* 1D convolution with weight flipped */
+kad_node_t *kad_max1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D max pooling */
+kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D average pooling */
+
+kad_node_t *kad_dropout(kad_node_t *x, kad_node_t *r); /* dropout at rate r */
+kad_node_t *kad_sample_normal(kad_node_t *x); /* f(x) = x * r, where r is drawn from a standard normal distribution */
+
+/* operators taking one operand */
+kad_node_t *kad_square(kad_node_t *x); /* f(x) = x^2 (element-wise square) */
+kad_node_t *kad_sigm(kad_node_t *x); /* f(x) = 1/(1+exp(-x)) (element-wise sigmoid) */
+kad_node_t *kad_tanh(kad_node_t *x); /* f(x) = (1-exp(-2x)) / (1+exp(-2x)) (element-wise tanh) */
+kad_node_t *kad_relu(kad_node_t *x); /* f(x) = max{0,x} (element-wise rectifier, aka ReLU) */
+kad_node_t *kad_softmax(kad_node_t *x);/* f_i(x_1,...,x_n) = exp(x_i) / \sum_j exp(x_j) (softmax: tf.nn.softmax(x,dim=-1)) */
+kad_node_t *kad_1minus(kad_node_t *x); /* f(x) = 1 - x */
+kad_node_t *kad_exp(kad_node_t *x); /* f(x) = exp(x) */
+kad_node_t *kad_log(kad_node_t *x); /* f(x) = log(x) */
+kad_node_t *kad_sin(kad_node_t *x); /* f(x) = sin(x) */
+
+kad_node_t *kad_stdnorm(kad_node_t *x); /* layer normalization; applied to the last dimension */
+
+/* operators taking an indefinite number of operands (e.g. pooling) */
+kad_node_t *kad_avg(int n, kad_node_t **x); /* f(x_1,...,x_n) = \sum_i x_i/n (mean pooling) */
+kad_node_t *kad_max(int n, kad_node_t **x); /* f(x_1,...,x_n) = max{x_1,...,x_n} (max pooling) */
+kad_node_t *kad_stack(int n, kad_node_t **x); /* f(x_1,...,x_n) = [x_1,...,x_n] (stack pooling) */
+kad_node_t *kad_select(int n, kad_node_t **x, int which); /* f(x_1,...,x_n;i) = x_i (select pooling; -1 for the last) */
+
+/* dimension reduction */
+kad_node_t *kad_reduce_sum(kad_node_t *x, int axis); /* tf.reduce_sum(x, axis) */
+kad_node_t *kad_reduce_mean(kad_node_t *x, int axis); /* tf.reduce_mean(x, axis) */
+
+/* special operators */
+kad_node_t *kad_slice(kad_node_t *x, int axis, int start, int end); /* take a slice on the axis-th dimension */
+kad_node_t *kad_concat(int axis, int n, ...); /* concatenate on the axis-th dimension */
+kad_node_t *kad_concat_array(int axis, int n, kad_node_t **p); /* the array version of concat */
+kad_node_t *kad_reshape(kad_node_t *x, int n_d, int *d); /* reshape; similar behavior to TensorFlow's reshape() */
+kad_node_t *kad_reverse(kad_node_t *x, int axis);
+kad_node_t *kad_switch(int n, kad_node_t **p); /* manually (as a hyperparameter) choose one input, default to 0 */
+
+/* miscellaneous operations on a compiled graph */
+int kad_size_var(int n, kad_node_t *const* v); /* total size of all variables */
+int kad_size_const(int n, kad_node_t *const* v); /* total size of all constants */
+
+/* graph I/O */
+int kad_save(FILE *fp, int n_node, kad_node_t **node);
+kad_node_t **kad_load(FILE *fp, int *_n_node);
+
+/* random number generator */
+void *kad_rng(void);
+void kad_srand(void *d, uint64_t seed);
+uint64_t kad_rand(void *d);
+double kad_drand(void *d);
+double kad_drand_normal(void *d);
+void kad_saxpy(int n, float a, const float *x, float *y);
+
+/* debugging routines */
+void kad_trap_fe(void); /* abort on divide-by-zero and NaN */
+void kad_print_graph(FILE *fp, int n, kad_node_t **v);
+void kad_check_grad(int n, kad_node_t **a, int from);
+
+#ifdef __cplusplus
+}
+#endif
+
+#define KAD_ALLOC 1 /* NOTE(review): the four KAD_* codes here look like the action values passed as the int argument of kad_op_f - confirm in kautodiff.c */
+#define KAD_FORWARD 2
+#define KAD_BACKWARD 3
+#define KAD_SYNC_DIM 4
+
+typedef int (*kad_op_f)(kad_node_t*, int); /* operator callback: takes a node and an int, returns an int */
+extern kad_op_f kad_op_list[KAD_MAX_OP]; /* operator table indexed by kad_node_t::op; defined outside this header */
+extern char *kad_op_name[KAD_MAX_OP]; /* one name slot per operator; presumably parallel to kad_op_list - confirm */
+
+static inline int kad_len(const kad_node_t *p) /* size (element count) of p->x: the product of p->d[0..n_d-1] */
+{
+ int n = 1, i; /* starts at 1, so a node with n_d == 0 has length 1 */
+ for (i = 0; i < p->n_d; ++i) n *= p->d[i]; /* plain int arithmetic; overflow is not checked */
+ return n;
+}
+
+#endif