You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

kautodiff.h 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. /*
  2. The MIT License
  3. Copyright (c) 2018-2019 Dana-Farber Cancer Institute
  4. 2016-2018 Broad Institute
  5. Permission is hereby granted, free of charge, to any person obtaining
  6. a copy of this software and associated documentation files (the
  7. "Software"), to deal in the Software without restriction, including
  8. without limitation the rights to use, copy, modify, merge, publish,
  9. distribute, sublicense, and/or sell copies of the Software, and to
  10. permit persons to whom the Software is furnished to do so, subject to
  11. the following conditions:
  12. The above copyright notice and this permission notice shall be
  13. included in all copies or substantial portions of the Software.
  14. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  15. EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  16. MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  17. NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  18. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  19. ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. SOFTWARE.
  22. */
  23. #ifndef KANN_AUTODIFF_H
  24. #define KANN_AUTODIFF_H
  25. #define KAD_VERSION "r544"
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
  28. #ifdef __STRICT_ANSI__
  29. #define inline
  30. #endif
  31. #define KAD_MAX_DIM 4 /* max dimension */
  32. #define KAD_MAX_OP 64 /* max number of operators */
  33. /* A computational graph is a directed acyclic graph. In the graph, an external
  34. * node represents a variable, a constant or a feed; an internal node
  35. * represents an operator; an edge from node v to w indicates v is an operand
  36. * of w.
  37. */
  38. #define KAD_VAR 0x1
  39. #define KAD_CONST 0x2
  40. #define KAD_POOL 0x4
  41. #define KAD_SHARE_RNG 0x10 /* with this flag on, different time step shares the same RNG status after unroll */
  42. #define kad_is_back(p) ((p)->flag & KAD_VAR)
  43. #define kad_is_ext(p) ((p)->n_child == 0)
  44. #define kad_is_var(p) (kad_is_ext(p) && kad_is_back(p))
  45. #define kad_is_const(p) (kad_is_ext(p) && ((p)->flag & KAD_CONST))
  46. #define kad_is_feed(p) (kad_is_ext(p) && !kad_is_back(p) && !((p)->flag & KAD_CONST))
  47. #define kad_is_pivot(p) ((p)->n_child == 1 && ((p)->flag & KAD_POOL))
  48. #define kad_is_switch(p) ((p)->op == 12 && !((p)->flag & KAD_POOL))
  49. #define kad_use_rng(p) ((p)->op == 15 || (p)->op == 24)
  50. #define kad_eval_enable(p) ((p)->tmp = 1)
  51. #define kad_eval_disable(p) ((p)->tmp = -1)
  52. /* a node in the computational graph */
  53. typedef struct kad_node_t {
  54. uint8_t n_d; /* number of dimensions; no larger than KAD_MAX_DIM */
  55. uint8_t flag; /* type of the node; see KAD_F_* for valid flags */
  56. uint16_t op; /* operator; kad_op_list[op] is the actual function */
  57. int32_t n_child; /* number of operands/child nodes */
  58. int32_t tmp; /* temporary field; MUST BE zero before calling kad_compile() */
  59. int32_t ptr_size; /* size of ptr below */
  60. int32_t d[KAD_MAX_DIM]; /* dimensions */
  61. int32_t ext_label; /* labels for external uses (not modified by the kad_* APIs) */
  62. uint32_t ext_flag; /* flags for external uses (not modified by the kad_* APIs) */
  63. float *x; /* value; allocated for internal nodes */
  64. float *g; /* gradient; allocated for internal nodes */
  65. void *ptr; /* for special operators that need additional parameters (e.g. conv2d) */
  66. void *gtmp; /* temporary data generated at the forward pass but used at the backward pass */
  67. struct kad_node_t **child; /* operands/child nodes */
  68. struct kad_node_t *pre; /* usually NULL; only used for RNN */
  69. } kad_node_t, *kad_node_p;
  70. #ifdef __cplusplus
  71. extern "C" {
  72. #endif
  73. /**
  74. * Compile/linearize a computational graph
  75. *
  76. * @param n_node number of nodes (out)
  77. * @param n_roots number of nodes without predecessors
  78. * @param roots list of nodes without predecessors
  79. *
  80. * @return list of nodes, of size *n_node
  81. */
  82. kad_node_t **kad_compile_array(int *n_node, int n_roots, kad_node_t **roots);
  83. kad_node_t **kad_compile(int *n_node, int n_roots, ...); /* an alternative API to above */
  84. void kad_delete(int n, kad_node_t **a); /* deallocate a compiled/linearized graph */
  85. /**
  86. * Compute the value at a node
  87. *
  88. * @param n number of nodes
  89. * @param a list of nodes
  90. * @param from compute the value at this node, 0<=from<n
  91. *
  92. * @return a pointer to the value (pointing to kad_node_t::x, so don't call
  93. * free() on it!)
  94. */
  95. const float *kad_eval_at(int n, kad_node_t **a, int from);
  96. void kad_eval_marked(int n, kad_node_t **a);
  97. int kad_sync_dim(int n, kad_node_t **v, int batch_size);
  98. /**
  99. * Compute gradient
  100. *
  101. * @param n number of nodes
  102. * @param a list of nodes
  103. * @param from the function node; must be a scalar (compute \nabla a[from])
  104. */
  105. void kad_grad(int n, kad_node_t **a, int from);
  106. /**
  107. * Unroll a recurrent computation graph
  108. *
  109. * @param n_v number of nodes
  110. * @param v list of nodes
  111. * @param new_n number of nodes in the unrolled graph (out)
  112. * @param len how many times to unroll, one for each pivot
  113. *
  114. * @return list of nodes in the unrolled graph
  115. */
  116. kad_node_t **kad_unroll(int n_v, kad_node_t **v, int *new_n, int *len);
  117. int kad_n_pivots(int n_v, kad_node_t **v);
  118. kad_node_t **kad_clone(int n, kad_node_t **v, int batch_size);
  119. /* define a variable, a constant or a feed (placeholder in TensorFlow) */
  120. kad_node_t *kad_var(float *x, float *g, int n_d, ...); /* a variable; gradients to be computed; not unrolled */
  121. kad_node_t *kad_const(float *x, int n_d, ...); /* a constant; no gradients computed; not unrolled */
  122. kad_node_t *kad_feed(int n_d, ...); /* an input/output; no gradients computed; unrolled */
  123. /* operators taking two operands */
  124. kad_node_t *kad_add(kad_node_t *x, kad_node_t *y); /* f(x,y) = x + y (generalized element-wise addition; f[i*n+j]=x[i*n+j]+y[j], n=kad_len(y), 0<j<n, 0<i<kad_len(x)/n) */
  125. kad_node_t *kad_sub(kad_node_t *x, kad_node_t *y); /* f(x,y) = x - y (generalized element-wise subtraction) */
  126. kad_node_t *kad_mul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y (generalized element-wise product) */
  127. kad_node_t *kad_matmul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y (general matrix product) */
  128. kad_node_t *kad_cmul(kad_node_t *x, kad_node_t *y); /* f(x,y) = x * y^T (column-wise matrix product; i.e. y is transposed) */
  129. /* loss functions; output scalar */
  130. kad_node_t *kad_mse(kad_node_t *x, kad_node_t *y); /* mean square error */
  131. kad_node_t *kad_ce_multi(kad_node_t *x, kad_node_t *y); /* multi-class cross-entropy; x is the preidction and y is the truth */
  132. kad_node_t *kad_ce_bin(kad_node_t *x, kad_node_t *y); /* binary cross-entropy for (0,1) */
  133. kad_node_t *kad_ce_bin_neg(kad_node_t *x, kad_node_t *y); /* binary cross-entropy for (-1,1) */
  134. kad_node_t *kad_ce_multi_weighted(kad_node_t *pred, kad_node_t *truth, kad_node_t *weight);
  135. #define KAD_PAD_NONE 0 /* use the smallest zero-padding */
  136. #define KAD_PAD_SAME (-2) /* output to have the same dimension as input */
  137. kad_node_t *kad_conv2d(kad_node_t *x, kad_node_t *w, int r_stride, int c_stride, int r_pad, int c_pad); /* 2D convolution with weight matrix flipped */
  138. kad_node_t *kad_max2d(kad_node_t *x, int kernel_h, int kernel_w, int r_stride, int c_stride, int r_pad, int c_pad); /* 2D max pooling */
  139. kad_node_t *kad_conv1d(kad_node_t *x, kad_node_t *w, int stride, int pad); /* 1D convolution with weight flipped */
  140. kad_node_t *kad_max1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D max pooling */
  141. kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D average pooling */
  142. kad_node_t *kad_dropout(kad_node_t *x, kad_node_t *r); /* dropout at rate r */
  143. kad_node_t *kad_sample_normal(kad_node_t *x); /* f(x) = x * r, where r is drawn from a standard normal distribution */
  144. /* operators taking one operand */
  145. kad_node_t *kad_square(kad_node_t *x); /* f(x) = x^2 (element-wise square) */
  146. kad_node_t *kad_sigm(kad_node_t *x); /* f(x) = 1/(1+exp(-x)) (element-wise sigmoid) */
  147. kad_node_t *kad_tanh(kad_node_t *x); /* f(x) = (1-exp(-2x)) / (1+exp(-2x)) (element-wise tanh) */
  148. kad_node_t *kad_relu(kad_node_t *x); /* f(x) = max{0,x} (element-wise rectifier, aka ReLU) */
  149. kad_node_t *kad_softmax(kad_node_t *x);/* f_i(x_1,...,x_n) = exp(x_i) / \sum_j exp(x_j) (softmax: tf.nn.softmax(x,dim=-1)) */
  150. kad_node_t *kad_1minus(kad_node_t *x); /* f(x) = 1 - x */
  151. kad_node_t *kad_exp(kad_node_t *x); /* f(x) = exp(x) */
  152. kad_node_t *kad_log(kad_node_t *x); /* f(x) = log(x) */
  153. kad_node_t *kad_sin(kad_node_t *x); /* f(x) = sin(x) */
  154. kad_node_t *kad_stdnorm(kad_node_t *x); /* layer normalization; applied to the last dimension */
  155. /* operators taking an indefinite number of operands (e.g. pooling) */
  156. kad_node_t *kad_avg(int n, kad_node_t **x); /* f(x_1,...,x_n) = \sum_i x_i/n (mean pooling) */
  157. kad_node_t *kad_max(int n, kad_node_t **x); /* f(x_1,...,x_n) = max{x_1,...,x_n} (max pooling) */
  158. kad_node_t *kad_stack(int n, kad_node_t **x); /* f(x_1,...,x_n) = [x_1,...,x_n] (stack pooling) */
  159. kad_node_t *kad_select(int n, kad_node_t **x, int which); /* f(x_1,...,x_n;i) = x_i (select pooling; -1 for the last) */
  160. /* dimension reduction */
  161. kad_node_t *kad_reduce_sum(kad_node_t *x, int axis); /* tf.reduce_sum(x, axis) */
  162. kad_node_t *kad_reduce_mean(kad_node_t *x, int axis); /* tf.reduce_mean(x, axis) */
  163. /* special operators */
  164. kad_node_t *kad_slice(kad_node_t *x, int axis, int start, int end); /* take a slice on the axis-th dimension */
  165. kad_node_t *kad_concat(int axis, int n, ...); /* concatenate on the axis-th dimension */
  166. kad_node_t *kad_concat_array(int axis, int n, kad_node_t **p); /* the array version of concat */
  167. kad_node_t *kad_reshape(kad_node_t *x, int n_d, int *d); /* reshape; similar behavior to TensorFlow's reshape() */
  168. kad_node_t *kad_reverse(kad_node_t *x, int axis);
  169. kad_node_t *kad_switch(int n, kad_node_t **p); /* manually (as a hyperparameter) choose one input, default to 0 */
  170. /* miscellaneous operations on a compiled graph */
  171. int kad_size_var(int n, kad_node_t *const* v); /* total size of all variables */
  172. int kad_size_const(int n, kad_node_t *const* v); /* total size of all constants */
  173. /* graph I/O */
  174. int kad_save(FILE *fp, int n_node, kad_node_t **node);
  175. kad_node_t **kad_load(FILE *fp, int *_n_node);
  176. /* random number generator */
  177. void *kad_rng(void);
  178. void kad_srand(void *d, uint64_t seed);
  179. uint64_t kad_rand(void *d);
  180. double kad_drand(void *d);
  181. double kad_drand_normal(void *d);
  182. void kad_saxpy(int n, float a, const float *x, float *y);
  183. /* debugging routines */
  184. void kad_trap_fe(void); /* abort on divide-by-zero and NaN */
  185. void kad_print_graph(FILE *fp, int n, kad_node_t **v);
  186. void kad_check_grad(int n, kad_node_t **a, int from);
  187. #ifdef __cplusplus
  188. }
  189. #endif
  190. #define KAD_ALLOC 1
  191. #define KAD_FORWARD 2
  192. #define KAD_BACKWARD 3
  193. #define KAD_SYNC_DIM 4
  194. typedef int (*kad_op_f)(kad_node_t*, int);
  195. extern kad_op_f kad_op_list[KAD_MAX_OP];
  196. extern char *kad_op_name[KAD_MAX_OP];
  197. static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */
  198. {
  199. int n = 1, i;
  200. for (i = 0; i < p->n_d; ++i) n *= p->d[i];
  201. return n;
  202. }
  203. /* Additions by Rspamd */
  204. void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
  205. /**
  206. * Calculate eigenvectors and eigenvalues
  207. * @param N dimensions of A (must be NxN)
  208. * @param A input matrix (part of it will be destroyed, so copy if needed), on finish the first `nwork` columns will have eigenvectors
  209. * @param eigenvals eigenvalues, must be N elements vector
  210. */
  211. bool kad_ssyev_simple (int N, float *A, float *eigenvals);
  212. #endif