/*
   The MIT License

   Copyright (c) 2018-2019 Dana-Farber Cancer Institute
                 2016-2018 Broad Institute

   Permission is hereby granted, free of charge, to any person obtaining
   a copy of this software and associated documentation files (the
   "Software"), to deal in the Software without restriction, including
   without limitation the rights to use, copy, modify, merge, publish,
   distribute, sublicense, and/or sell copies of the Software, and to
   permit persons to whom the Software is furnished to do so, subject to
   the following conditions:

   The above copyright notice and this permission notice shall be
   included in all copies or substantial portions of the Software.

   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
   NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
   BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
   ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
   CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
   SOFTWARE.
*/

#ifndef KANN_AUTODIFF_H
#define KANN_AUTODIFF_H

#define KAD_VERSION "r544"

/*
 * NOTE(review): both #include directives had lost their header names.
 * <stdint.h> is required by the fixed-width integer types used in
 * kad_node_t below; <stdio.h> is presumed to be the other one -- confirm.
 */
#include <stdio.h>
#include <stdint.h>

/* In strict ANSI mode, turn `inline` into nothing so that the
 * `static inline` helper at the end of this header becomes plain `static`. */
#ifdef __STRICT_ANSI__
#define inline
#endif

#define KAD_MAX_DIM 4  /* max number of dimensions of a node */
#define KAD_MAX_OP  64 /* max number of operators */

/* A computational graph is a directed acyclic graph. In the graph, an external
 * node represents a variable, a constant or a feed; an internal node
 * represents an operator; an edge from node v to w indicates v is an operand
 * of w.
 */

/* Bits stored in kad_node_t::flag */
#define KAD_VAR        0x1
#define KAD_CONST      0x2
#define KAD_POOL       0x4
#define KAD_SHARE_RNG  0x10 /* with this flag on, different time steps share the same RNG state after unrolling */

/*
 * Node classification macros. Each takes a pointer to a kad_node_t and
 * evaluates to a nonzero value when the predicate holds. The argument is
 * expanded more than once in several of them, so it must be free of side
 * effects.
 */
#define kad_is_back(p)  ((p)->flag & KAD_VAR)                 /* KAD_VAR bit is set */
#define kad_is_ext(p)   ((p)->n_child == 0)                   /* external node: no operands */
#define kad_is_var(p)   (kad_is_ext(p) && kad_is_back(p))     /* external node with KAD_VAR set */
#define kad_is_const(p) (kad_is_ext(p) && ((p)->flag & KAD_CONST)) /* external node with KAD_CONST set */
#define kad_is_feed(p)  (kad_is_ext(p) && !kad_is_back(p) && !((p)->flag & KAD_CONST)) /* external node that is neither variable nor constant */
#define kad_is_pivot(p) ((p)->n_child == 1 && ((p)->flag & KAD_POOL)) /* exactly one operand and KAD_POOL set */

/* The two macros below compare against hard-coded operator indices (12, 15,
 * 24); these must stay in sync with the operator table kad_op_list[], which
 * is not defined in this header. */
#define kad_is_switch(p) ((p)->op == 12 && !((p)->flag & KAD_POOL))
#define kad_use_rng(p) ((p)->op == 15 || (p)->op == 24)

/* Mark a node through its `tmp` field: 1 to enable, -1 to disable. */
#define kad_eval_enable(p) ((p)->tmp = 1)
#define kad_eval_disable(p) ((p)->tmp = -1)

/* a node in the computational graph */
typedef struct kad_node_t {
	uint8_t     n_d;            /* number of dimensions; no larger than KAD_MAX_DIM */
	uint8_t     flag;           /* type of the node; see KAD_VAR, KAD_CONST, KAD_POOL and KAD_SHARE_RNG above */
	uint16_t    op;             /* operator; kad_op_list[op] is the actual function */
	int32_t     n_child;        /* number of operands/child nodes */
	int32_t     tmp;            /* temporary field; MUST BE zero before calling kad_compile() */
	int32_t     ptr_size;       /* size of ptr below */
	int32_t     d[KAD_MAX_DIM]; /* dimensions */
	int32_t     ext_label;      /* labels for external uses (not modified by the kad_* APIs) */
	uint32_t    ext_flag;       /* flags for external uses (not modified by the kad_* APIs) */
	float      *x;              /* value; allocated for internal nodes */
	float      *g;              /* gradient; allocated for internal nodes */
	void       *ptr;            /* for special operators that need additional parameters (e.g. conv2d) */
	void       *gtmp;           /* temporary data generated at the forward pass but used at the backward pass */
	struct kad_node_t **child;  /* operands/child nodes */
	struct kad_node_t  *pre;    /* usually NULL; only used for RNN */
} kad_node_t, *kad_node_p;

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Compile/linearize a computational graph
 *
 * @param n_node   number of nodes (out)
 * @param n_roots  number of nodes without predecessors
 * @param roots    list of nodes without predecessors
 *
 * @return list of nodes, of size *n_node
 */
kad_node_t **kad_compile_array(int *n_node, int n_roots, kad_node_t **roots);

/* An alternative API to kad_compile_array(): the roots are passed as
 * variadic arguments (presumably n_roots values of type kad_node_t* --
 * confirm against the implementation). */
kad_node_t **kad_compile(int *n_node, int n_roots, ...);

/* Deallocate a compiled/linearized graph: `a` is the node list, `n` its size. */
void kad_delete(int n, kad_node_t **a);

/*
 * Compute the value at a node
 *
 * @param n     number of nodes
 * @param a     list of nodes
 * @param from  compute the value at this node, 0<=from
 *
 * NOTE(review): the source text was lost from this point (after "0<=from")
 * up to the body of the helper below. The declaration documented by this
 * comment, and any declarations that followed it, are missing and must be
 * restored from a pristine copy of the header.
 */

/* NOTE(review): no closing brace for the extern "C" block survived; it is
 * closed here so the header remains well-formed for C++ translation units. */
#ifdef __cplusplus
}
#endif

/*
 * Return the product of the first p->n_d entries of p->d, i.e. the number
 * of elements implied by the node's shape (presumably the length of p->x).
 * Returns 1 when p->n_d is 0. The product is computed in `int` with no
 * overflow check.
 *
 * NOTE(review): only the body of this function survived; its name and
 * signature are reconstructed from the body -- confirm.
 */
static inline int kad_len(const kad_node_t *p)
{
	int n = 1, i;
	for (i = 0; i < p->n_d; ++i) n *= p->d[i];
	return n;
}

#endif