
kann.c

#include "config.h"
#include <math.h>
#include <float.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include <stdarg.h>
#include "kann.h"

int kann_verbose = 3;

/******************************************
 *** @@BASIC: fundamental KANN routines ***
 ******************************************/
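
/* Gather the values of all variable and constant nodes into flat arrays
 * (*_x for variable values, *_g for their gradients, *_c for constants) and
 * re-point each node's ->x/->g into these shared buffers. */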
static void kad_ext_collate(int n, kad_node_t **a, float **_x, float **_g, float **_c)
{
    int i, j, k, l, n_var;
    float *x, *g, *c;
    n_var = kad_size_var(n, a);
    x = *_x = (float*)realloc(*_x, n_var * sizeof(float));
    g = *_g = (float*)realloc(*_g, n_var * sizeof(float));
    c = *_c = (float*)realloc(*_c, kad_size_const(n, a) * sizeof(float));
    memset(g, 0, n_var * sizeof(float));
    for (i = j = k = 0; i < n; ++i) {
        kad_node_t *v = a[i];
        if (kad_is_var(v)) {
            l = kad_len(v);
            memcpy(&x[j], v->x, l * sizeof(float));
            free(v->x);
            v->x = &x[j];
            v->g = &g[j];
            j += l;
        } else if (kad_is_const(v)) {
            l = kad_len(v);
            memcpy(&c[k], v->x, l * sizeof(float));
            free(v->x);
            v->x = &c[k];
            k += l;
        }
    }
}

static void kad_ext_sync(int n, kad_node_t **a, float *x, float *g, float *c)
{
    int i, j, k;
    for (i = j = k = 0; i < n; ++i) {
        kad_node_t *v = a[i];
        if (kad_is_var(v)) {
            v->x = &x[j];
            v->g = &g[j];
            j += kad_len(v);
        } else if (kad_is_const(v)) {
            v->x = &c[k];
            k += kad_len(v);
        }
    }
}
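
/* Build a network from a cost node plus n_rest extra output roots. If the
 * graph is recurrent but has no pivot node, the cost is wrapped in kad_avg()
 * and the graph recompiled so that it can be unrolled later. */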
kann_t *kann_new(kad_node_t *cost, int n_rest, ...)
{
    kann_t *a;
    int i, n_roots = 1 + n_rest, has_pivot = 0, has_recur = 0;
    kad_node_t **roots;
    va_list ap;
    if (cost->n_d != 0) return 0;
    va_start(ap, n_rest);
    roots = (kad_node_t**)malloc((n_roots + 1) * sizeof(kad_node_t*));
    for (i = 0; i < n_rest; ++i)
        roots[i] = va_arg(ap, kad_node_t*);
    roots[i++] = cost;
    va_end(ap);
    cost->ext_flag |= KANN_F_COST;
    a = (kann_t*)calloc(1, sizeof(kann_t));
    a->v = kad_compile_array(&a->n, n_roots, roots);
    for (i = 0; i < a->n; ++i) {
        if (a->v[i]->pre) has_recur = 1;
        if (kad_is_pivot(a->v[i])) has_pivot = 1;
    }
    if (has_recur && !has_pivot) { /* an RNN that doesn't have a pivot; then add a pivot on top of cost and recompile */
        cost->ext_flag &= ~KANN_F_COST;
        roots[n_roots-1] = cost = kad_avg(1, &cost), cost->ext_flag |= KANN_F_COST;
        free(a->v);
        a->v = kad_compile_array(&a->n, n_roots, roots);
    }
    kad_ext_collate(a->n, a->v, &a->x, &a->g, &a->c);
    free(roots);
    return a;
}

kann_t *kann_clone(kann_t *a, int batch_size)
{
    kann_t *b;
    b = (kann_t*)calloc(1, sizeof(kann_t));
    b->n = a->n;
    b->v = kad_clone(a->n, a->v, batch_size);
    kad_ext_collate(b->n, b->v, &b->x, &b->g, &b->c);
    return b;
}

kann_t *kann_unroll_array(kann_t *a, int *len)
{
    kann_t *b;
    b = (kann_t*)calloc(1, sizeof(kann_t));
    b->x = a->x, b->g = a->g, b->c = a->c; /* these arrays are shared */
    b->v = kad_unroll(a->n, a->v, &b->n, len);
    return b;
}

kann_t *kann_unroll(kann_t *a, ...)
{
    kann_t *b;
    va_list ap;
    int i, n_pivots, *len;
    n_pivots = kad_n_pivots(a->n, a->v);
    len = (int*)calloc(n_pivots, sizeof(int));
    va_start(ap, a);
    for (i = 0; i < n_pivots; ++i) len[i] = va_arg(ap, int);
    va_end(ap);
    b = kann_unroll_array(a, len);
    free(len);
    return b;
}

void kann_delete_unrolled(kann_t *a)
{
    if (a && a->mt) kann_mt(a, 0, 0);
    if (a && a->v) kad_delete(a->n, a->v);
    free(a);
}

void kann_delete(kann_t *a)
{
    if (a == 0) return;
    free(a->x); free(a->g); free(a->c);
    kann_delete_unrolled(a);
}

static void kann_switch_core(kann_t *a, int is_train)
{
    int i;
    for (i = 0; i < a->n; ++i)
        if (a->v[i]->op == 12 && a->v[i]->n_child == 2)
            *(int32_t*)a->v[i]->ptr = !!is_train;
}

#define chk_flg(flag, mask) ((mask) == 0 || ((flag) & (mask)))
#define chk_lbl(label, query) ((query) == 0 || (label) == (query))

int kann_find(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
{
    int i, k, r = -1;
    for (i = k = 0; i < a->n; ++i)
        if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
            ++k, r = i;
    return k == 1? r : k == 0? -1 : -2;
}

int kann_feed_bind(kann_t *a, uint32_t ext_flag, int32_t ext_label, float **x)
{
    int i, k;
    if (x == 0) return 0;
    for (i = k = 0; i < a->n; ++i)
        if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
            a->v[i]->x = x[k++];
    return k;
}

int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
{
    int i, k, n = 0;
    for (i = k = 0; i < a->n; ++i)
        if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
            ++k, n = a->v[i]->n_d > 1? kad_len(a->v[i]) / a->v[i]->d[0] : a->v[i]->n_d == 1? a->v[i]->d[0] : 1;
    return k == 1? n : k == 0? -1 : -2;
}

static float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
{
    int i_cost;
    float cost;
    i_cost = kann_find(a, KANN_F_COST, cost_label);
    assert(i_cost >= 0);
    cost = *kad_eval_at(a->n, a->v, i_cost);
    if (cal_grad) kad_grad(a->n, a->v, i_cost);
    return cost;
}

int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
{
    int i, k;
    for (i = k = 0; i < a->n; ++i)
        if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
            ++k, a->v[i]->tmp = 1;
    kad_eval_marked(a->n, a->v);
    return k;
}

void kann_rnn_start(kann_t *a)
{
    int i;
    kann_set_batch_size(a, 1);
    for (i = 0; i < a->n; ++i) {
        kad_node_t *p = a->v[i];
        if (p->pre) { /* NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size() */
            kad_node_t *q = p->pre;
            if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float));
            else memset(p->x, 0, kad_len(p) * sizeof(float));
            if (q->n_child > 0) free(q->x);
            q->x = p->x;
        }
    }
}

void kann_rnn_end(kann_t *a)
{
    int i;
    kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
    for (i = 0; i < a->n; ++i)
        if (a->v[i]->pre && a->v[i]->pre->n_child > 0)
            a->v[i]->pre->x = (float*)calloc(kad_len(a->v[i]->pre), sizeof(float));
}

static int kann_class_error_core(const kann_t *ann, int *base)
{
    int i, j, k, m, n, off, n_err = 0;
    for (i = 0, *base = 0; i < ann->n; ++i) {
        kad_node_t *p = ann->v[i];
        if (((p->op == 13 && (p->n_child == 2 || p->n_child == 3)) || (p->op == 22 && p->n_child == 2)) && p->n_d == 0) { /* ce_bin or ce_multi */
            kad_node_t *x = p->child[0], *t = p->child[1];
            n = t->d[t->n_d - 1], m = kad_len(t) / n;
            for (j = off = 0; j < m; ++j, off += n) {
                float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
                int x_max_k = -1, t_max_k = -1;
                for (k = 0; k < n; ++k) {
                    float xk = x->x[off+k], tk = t->x[off+k];
                    t_sum += tk;
                    t_min = t_min < tk? t_min : tk;
                    x_min = x_min < xk? x_min : xk;
                    if (t_max < tk) t_max = tk, t_max_k = k;
                    if (x_max < xk) x_max = xk, x_max_k = k;
                }
                if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f) {
                    ++(*base);
                    n_err += (x_max_k != t_max_k);
                }
            }
        }
    }
    return n_err;
}

/*************************
 * @@MT: multi-threading *
 *************************/

#ifdef HAVE_PTHREAD
#include <pthread.h>

struct mtaux_t;

typedef struct { /* per-worker data */
    kann_t *a;
    float cost;
    int action;
    pthread_t tid;
    struct mtaux_t *g;
} mtaux1_t;

typedef struct mtaux_t { /* cross-worker data */
    int n_threads, max_batch_size;
    int cal_grad, cost_label, eval_out;
    volatile int n_idle; /* we will be busy waiting on this, so volatile necessary */
    pthread_mutex_t mtx;
    pthread_cond_t cv;
    mtaux1_t *mt;
} mtaux_t;

static void *mt_worker(void *data) /* pthread worker */
{
    mtaux1_t *mt1 = (mtaux1_t*)data;
    mtaux_t *mt = mt1->g;
    for (;;) {
        int action;
        pthread_mutex_lock(&mt->mtx);
        mt1->action = 0;
        ++mt->n_idle;
        while (mt1->action == 0)
            pthread_cond_wait(&mt->cv, &mt->mtx);
        action = mt1->action;
        pthread_mutex_unlock(&mt->mtx);
        if (action == -1) break;
        if (mt->eval_out) kann_eval(mt1->a, KANN_F_OUT, 0);
        else mt1->cost = kann_cost_core(mt1->a, mt->cost_label, mt->cal_grad);
    }
    pthread_exit(0);
}

static void mt_destroy(mtaux_t *mt) /* de-allocate an entire mtaux_t struct */
{
    int i;
    pthread_mutex_lock(&mt->mtx);
    mt->n_idle = 0;
    for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1;
    pthread_cond_broadcast(&mt->cv);
    pthread_mutex_unlock(&mt->mtx);
    for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0);
    for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a);
    free(mt->mt);
    pthread_cond_destroy(&mt->cv);
    pthread_mutex_destroy(&mt->mtx);
    free(mt);
}
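
/* Enable multi-threaded training with n_threads workers, each holding a clone
 * of the model sized for its share of max_batch_size. Calling with
 * n_threads <= 1 tears down any existing worker pool and disables MT. */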
void kann_mt(kann_t *ann, int n_threads, int max_batch_size)
{
    mtaux_t *mt;
    int i, k;
    if (n_threads <= 1) {
        if (ann->mt) mt_destroy((mtaux_t*)ann->mt);
        ann->mt = 0;
        return;
    }
    if (n_threads > max_batch_size) n_threads = max_batch_size;
    if (n_threads <= 1) return;
    mt = (mtaux_t*)calloc(1, sizeof(mtaux_t));
    mt->n_threads = n_threads, mt->max_batch_size = max_batch_size;
    pthread_mutex_init(&mt->mtx, 0);
    pthread_cond_init(&mt->cv, 0);
    mt->mt = (mtaux1_t*)calloc(n_threads, sizeof(mtaux1_t));
    for (i = k = 0; i < n_threads; ++i) {
        int size = (max_batch_size - k) / (n_threads - i);
        mt->mt[i].a = kann_clone(ann, size);
        mt->mt[i].g = mt;
        k += size;
    }
    for (i = 1; i < n_threads; ++i)
        pthread_create(&mt->mt[i].tid, 0, mt_worker, &mt->mt[i]);
    while (mt->n_idle < n_threads - 1); /* busy waiting until all threads in sync */
    ann->mt = mt;
}

static void mt_kickoff(kann_t *a, int cost_label, int cal_grad, int eval_out)
{
    mtaux_t *mt = (mtaux_t*)a->mt;
    int i, j, k, B, n_var;
    B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
    assert(B <= mt->max_batch_size); /* TODO: can be relaxed */
    n_var = kann_size_var(a);
    pthread_mutex_lock(&mt->mtx);
    mt->cost_label = cost_label, mt->cal_grad = cal_grad, mt->eval_out = eval_out;
    for (i = k = 0; i < mt->n_threads; ++i) {
        int size = (B - k) / (mt->n_threads - i);
        for (j = 0; j < a->n; ++j)
            if (kad_is_feed(a->v[j]))
                mt->mt[i].a->v[j]->x = &a->v[j]->x[k * kad_len(a->v[j]) / a->v[j]->d[0]];
        kad_sync_dim(mt->mt[i].a->n, mt->mt[i].a->v, size); /* TODO: we can point ->x to internal nodes, too */
        k += size;
        memcpy(mt->mt[i].a->x, a->x, n_var * sizeof(float));
        mt->mt[i].action = 1;
    }
    mt->n_idle = 0;
    pthread_cond_broadcast(&mt->cv);
    pthread_mutex_unlock(&mt->mtx);
}
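
/* Compute the cost over the current batch and, if cal_grad is set, the
 * gradients as well. With multi-threading enabled, the batch is split across
 * the worker clones and per-worker costs and gradients are combined, weighted
 * by each worker's share of the batch. */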
float kann_cost(kann_t *a, int cost_label, int cal_grad)
{
    mtaux_t *mt = (mtaux_t*)a->mt;
    int i, j, B, k, n_var;
    float cost;
    if (mt == 0) return kann_cost_core(a, cost_label, cal_grad);
    B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
    n_var = kann_size_var(a);
    mt_kickoff(a, cost_label, cal_grad, 0);
    mt->mt[0].cost = kann_cost_core(mt->mt[0].a, cost_label, cal_grad);
    while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
    memset(a->g, 0, n_var * sizeof(float)); /* TODO: check if this is necessary when cal_grad is false */
    for (i = k = 0, cost = 0.0f; i < mt->n_threads; ++i) {
        int size = (B - k) / (mt->n_threads - i);
        cost += mt->mt[i].cost * size / B;
        kad_saxpy(n_var, (float)size / B, mt->mt[i].a->g, a->g);
        k += size;
    }
    for (j = 0; j < a->n; ++j) { /* copy values back at recurrent nodes (needed by textgen; TODO: temporary solution) */
        kad_node_t *p = a->v[j];
        if (p->pre && p->n_d >= 2 && p->d[0] == B) {
            for (i = k = 0; i < mt->n_threads; ++i) {
                kad_node_t *q = mt->mt[i].a->v[j];
                memcpy(&p->x[k], q->x, kad_len(q) * sizeof(float));
                k += kad_len(q);
            }
        }
    }
    return cost;
}

int kann_eval_out(kann_t *a)
{
    mtaux_t *mt = (mtaux_t*)a->mt;
    int j, B, n_eval;
    if (mt == 0) return kann_eval(a, KANN_F_OUT, 0);
    B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
    mt_kickoff(a, 0, 0, 1);
    n_eval = kann_eval(mt->mt[0].a, KANN_F_OUT, 0);
    while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
    for (j = 0; j < a->n; ++j) { /* copy output values back */
        kad_node_t *p = a->v[j];
        if (p->ext_flag & KANN_F_OUT) {
            int i, t, k, d0 = p->d[0] / B, d1 = 1; /* for RNN, p->d[0] may equal unroll_len * batch_size */
            assert(p->d[0] % B == 0);
            for (i = 1; i < p->n_d; ++i) d1 *= p->d[i];
            for (i = 0; i < d0; ++i) {
                for (t = k = 0; t < mt->n_threads; ++t) { /* similar to the forward pass of kad_op_concat() */
                    kad_node_t *q = mt->mt[t].a->v[j];
                    int size = q->d[0] / d0;
                    memcpy(&p->x[(i * B + k) * d1], &q->x[i * size * d1], size * d1 * sizeof(float));
                    k += size;
                }
            }
        }
    }
    return n_eval;
}

int kann_class_error(const kann_t *ann, int *base)
{
    mtaux_t *mt = (mtaux_t*)ann->mt;
    int i, n_err = 0, b = 0;
    if (mt == 0) return kann_class_error_core(ann, base);
    for (i = 0; i < mt->n_threads; ++i) {
        n_err += kann_class_error_core(mt->mt[i].a, &b);
        *base += b;
    }
    return n_err;
}

void kann_switch(kann_t *ann, int is_train)
{
    mtaux_t *mt = (mtaux_t*)ann->mt;
    int i;
    if (mt == 0) {
        kann_switch_core(ann, is_train);
        return;
    }
    for (i = 0; i < mt->n_threads; ++i)
        kann_switch_core(mt->mt[i].a, is_train);
}
#else
void kann_mt(kann_t *ann, int n_threads, int max_batch_size) {}
float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
int kann_eval_out(kann_t *a) { return kann_eval(a, KANN_F_OUT, 0); }
int kann_class_error(const kann_t *a, int *base) { return kann_class_error_core(a, base); }
void kann_switch(kann_t *ann, int is_train) { kann_switch_core(ann, is_train); } /* no value returned from a void function */
#endif

/***********************
 *** @@IO: model I/O ***
 ***********************/

#define KANN_MAGIC "KAN\1"

void kann_save_fp(FILE *fp, kann_t *ann)
{
    kann_set_batch_size(ann, 1);
    fwrite(KANN_MAGIC, 1, 4, fp);
    kad_save(fp, ann->n, ann->v);
    fwrite(ann->x, sizeof(float), kann_size_var(ann), fp);
    fwrite(ann->c, sizeof(float), kann_size_const(ann), fp);
}

void kann_save(const char *fn, kann_t *ann)
{
    FILE *fp;
    fp = fn && strcmp(fn, "-")? fopen(fn, "wb") : stdout;
    kann_save_fp(fp, ann);
    fclose(fp);
}

kann_t *kann_load_fp(FILE *fp)
{
    char magic[4];
    kann_t *ann;
    int n_var, n_const;
    (void) !fread(magic, 1, 4, fp);
    if (strncmp(magic, KANN_MAGIC, 4) != 0) {
        return 0;
    }
    ann = (kann_t*)calloc(1, sizeof(kann_t));
    ann->v = kad_load(fp, &ann->n);
    n_var = kad_size_var(ann->n, ann->v);
    n_const = kad_size_const(ann->n, ann->v);
    ann->x = (float*)malloc(n_var * sizeof(float));
    ann->g = (float*)calloc(n_var, sizeof(float));
    ann->c = (float*)malloc(n_const * sizeof(float));
    (void) !fread(ann->x, sizeof(float), n_var, fp);
    (void) !fread(ann->c, sizeof(float), n_const, fp);
    kad_ext_sync(ann->n, ann->v, ann->x, ann->g, ann->c);
    return ann;
}

kann_t *kann_load(const char *fn)
{
    FILE *fp;
    kann_t *ann;
    fp = fn && strcmp(fn, "-")? fopen(fn, "rb") : stdin;
    ann = kann_load_fp(fp);
    fclose(fp);
    return ann;
}

/**********************************************
 *** @@LAYER: layers and model generation ***
 **********************************************/

/********** General but more complex APIs **********/

kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM])
{
    int i, len, off = offset && par? *offset : -1;
    kad_node_t *p;
    if (off >= 0 && par[off]) return par[(*offset)++];
    p = (kad_node_t*)calloc(1, sizeof(kad_node_t));
    p->n_d = n_d, p->flag = flag;
    memcpy(p->d, d, n_d * sizeof(int32_t));
    len = kad_len(p);
    p->x = (float*)calloc(len, sizeof(float));
    if (p->n_d <= 1) {
        for (i = 0; i < len; ++i)
            p->x[i] = x0_01;
    } else {
        double sdev_inv;
        sdev_inv = 1.0 / sqrt((double)len / p->d[0]);
        for (i = 0; i < len; ++i)
            p->x[i] = (float)(kad_drand_normal(0) * sdev_inv);
    }
    if (off >= 0) par[off] = p, ++(*offset);
    return p;
}

kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...)
{
    int32_t i, d[KAD_MAX_DIM];
    va_list ap;
    va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
    return kann_new_leaf_array(offset, par, flag, x0_01, n_d, d);
}

kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1)
{
    int n0;
    kad_node_t *w, *b;
    n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
    w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
    b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
    return kad_add(kad_cmul(in, w), b);
}

kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r)
{
    kad_node_t *x[2], *cr;
    cr = kann_new_leaf2(offset, par, KAD_CONST, r, 0);
    x[0] = t, x[1] = kad_dropout(t, cr);
    return kad_switch(2, x);
}

kad_node_t *kann_layer_layernorm2(int *offset, kad_node_t **par, kad_node_t *in)
{
    int n0;
    kad_node_t *alpha, *beta;
    n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
    alpha = kann_new_leaf2(offset, par, KAD_VAR, 1.0f, 1, n0);
    beta = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n0);
    return kad_add(kad_mul(kad_stdnorm(in), alpha), beta);
}

static inline kad_node_t *cmul_norm2(int *offset, kad_node_t **par, kad_node_t *x, kad_node_t *w, int use_norm)
{
    return use_norm? kann_layer_layernorm2(offset, par, kad_cmul(x, w)) : kad_cmul(x, w);
}

kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
    int n0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
    kad_node_t *t, *w, *u, *b, *out;
    u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
    b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
    t = cmul_norm2(offset, par, h0, u, use_norm);
    if (in) {
        n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
        w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
        t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
    }
    out = kad_tanh(kad_add(t, b));
    out->pre = h0;
    return out;
}

kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
    int n0 = 0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
    kad_node_t *t, *r, *z, *w, *u, *b, *s, *out;
    if (in) n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
    /* z = sigm(x_t * W_z + h_{t-1} * U_z + b_z) */
    u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
    b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
    t = cmul_norm2(offset, par, h0, u, use_norm);
    if (in) {
        w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
        t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
    }
    z = kad_sigm(kad_add(t, b));
    /* r = sigm(x_t * W_r + h_{t-1} * U_r + b_r) */
    u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
    b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
    t = cmul_norm2(offset, par, h0, u, use_norm);
    if (in) {
        w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
        t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
    }
    r = kad_sigm(kad_add(t, b));
    /* s = tanh(x_t * W_s + (h_{t-1} # r) * U_s + b_s) */
    u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
    b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
    t = cmul_norm2(offset, par, kad_mul(r, h0), u, use_norm);
    if (in) {
        w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
        t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
    }
    s = kad_tanh(kad_add(t, b));
    /* h_t = z # h_{t-1} + (1 - z) # s */
    out = kad_add(kad_mul(kad_1minus(z), s), kad_mul(z, h0));
    out->pre = h0;
    return out;
}

/********** APIs without offset & par **********/

kad_node_t *kann_new_leaf(uint8_t flag, float x0_01, int n_d, ...)
{
    int32_t i, d[KAD_MAX_DIM];
    va_list ap;
    va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
    return kann_new_leaf_array(0, 0, flag, x0_01, n_d, d);
}

kad_node_t *kann_new_scalar(uint8_t flag, float x) { return kann_new_leaf(flag, x, 0); }
kad_node_t *kann_new_weight(int n_row, int n_col) { return kann_new_leaf(KAD_VAR, 0.0f, 2, n_row, n_col); }
kad_node_t *kann_new_vec(int n, float x) { return kann_new_leaf(KAD_VAR, x, 1, n); }
kad_node_t *kann_new_bias(int n) { return kann_new_vec(n, 0.0f); }
kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col) { return kann_new_leaf(KAD_VAR, 0.0f, 4, n_out, n_in, k_row, k_col); }
kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kann_new_leaf(KAD_VAR, 0.0f, 3, n_out, n_in, kernel_len); }

kad_node_t *kann_layer_input(int n1)
{
    kad_node_t *t;
    t = kad_feed(2, 1, n1);
    t->ext_flag |= KANN_F_IN;
    return t;
}

kad_node_t *kann_layer_dense(kad_node_t *in, int n1) { return kann_layer_dense2(0, 0, in, n1); }
kad_node_t *kann_layer_dropout(kad_node_t *t, float r) { return kann_layer_dropout2(0, 0, t, r); }
kad_node_t *kann_layer_layernorm(kad_node_t *in) { return kann_layer_layernorm2(0, 0, in); }

kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag)
{
    kad_node_t *h0;
    h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
    h0->x = (float*)calloc(n1, sizeof(float));
    return kann_layer_rnn2(0, 0, in, h0, rnn_flag);
}

kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag)
{
    kad_node_t *h0;
    h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
    h0->x = (float*)calloc(n1, sizeof(float));
    return kann_layer_gru2(0, 0, in, h0, rnn_flag);
}

static kad_node_t *kann_cmul_norm(kad_node_t *x, kad_node_t *w)
{
    return kann_layer_layernorm(kad_cmul(x, w));
}

kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag)
{
    int n0;
    kad_node_t *i, *f, *o, *g, *w, *u, *b, *h0, *c0, *c, *out;
    kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;
    n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
    h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
    h0->x = (float*)calloc(n1, sizeof(float));
    c0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
    c0->x = (float*)calloc(n1, sizeof(float));
    /* i = sigm(x_t * W_i + h_{t-1} * U_i + b_i) */
    w = kann_new_weight(n1, n0);
    u = kann_new_weight(n1, n1);
    b = kann_new_bias(n1);
    i = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
    /* f = sigm(x_t * W_f + h_{t-1} * U_f + b_f) */
    w = kann_new_weight(n1, n0);
    u = kann_new_weight(n1, n1);
    b = kann_new_vec(n1, 1.0f); /* see Jozefowicz et al on using a large bias */
    f = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
    /* o = sigm(x_t * W_o + h_{t-1} * U_o + b_o) */
    w = kann_new_weight(n1, n0);
    u = kann_new_weight(n1, n1);
    b = kann_new_bias(n1);
    o = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
    /* g = tanh(x_t * W_g + h_{t-1} * U_g + b_g) */
    w = kann_new_weight(n1, n0);
    u = kann_new_weight(n1, n1);
    b = kann_new_bias(n1);
    g = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
    /* c_t = c_{t-1} # f + g # i */
    c = kad_add(kad_mul(f, c0), kad_mul(g, i)); /* can't be kad_mul(c0, f)!!! */
    c->pre = c0;
    /* h_t = tanh(c_t) # o */
    if (rnn_flag & KANN_RNN_NORM) c = kann_layer_layernorm(c); /* see Ba et al (2016) about how to apply layer normalization to LSTM */
    out = kad_mul(kad_tanh(c), o);
    out->pre = h0;
    return out;
}

kad_node_t *kann_layer_conv2d(kad_node_t *in, int n_flt, int k_rows, int k_cols, int stride_r, int stride_c, int pad_r, int pad_c)
{
    kad_node_t *w;
    w = kann_new_weight_conv2d(n_flt, in->d[1], k_rows, k_cols);
    return kad_conv2d(in, w, stride_r, stride_c, pad_r, pad_c);
}

kad_node_t *kann_layer_conv1d(kad_node_t *in, int n_flt, int k_size, int stride, int pad)
{
    kad_node_t *w;
    w = kann_new_weight_conv1d(n_flt, in->d[1], k_size);
    return kad_conv1d(in, w, stride, pad);
}

kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
{
    kad_node_t *cost = 0, *truth = 0;
    assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
    t = kann_layer_dense(t, n_out);
    truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
    if (cost_type == KANN_C_MSE) {
        cost = kad_mse(t, truth);
    } else if (cost_type == KANN_C_CEB) {
        t = kad_sigm(t);
        cost = kad_ce_bin(t, truth);
    } else if (cost_type == KANN_C_CEB_NEG) {
        t = kad_tanh(t);
        cost = kad_ce_bin_neg(t, truth);
    } else if (cost_type == KANN_C_CEM) {
        t = kad_softmax(t);
        cost = kad_ce_multi(t, truth);
    } else {
        assert(0);
    }
    t->ext_flag |= KANN_F_OUT;
    cost->ext_flag |= KANN_F_COST;
    return cost;
}

void kann_shuffle(int n, int *s)
{
    int i, j, t;
    for (i = 0; i < n; ++i) s[i] = i;
    for (i = n; i > 0; --i) {
        j = (int)(i * kad_drand(0));
        t = s[j], s[j] = s[i-1], s[i-1] = t;
    }
}

/***************************
 *** @@MIN: minimization ***
 ***************************/
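
/* RMSprop update, applied element-wise:
 *   r[i] = decay * r[i] + (1 - decay) * g[i]^2
 *   t[i] -= lr * g[i] / sqrt(r[i] + 1e-6)
 * where lr is h[i] if a per-parameter learning-rate array h is given, or h0
 * otherwise. The SSE branch vectorizes this four floats at a time. */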
#ifdef __SSE__
#include <xmmintrin.h>

void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
{
    int i, n4 = n>>2<<2;
    __m128 vh, vg, vr, vt, vd, vd1, tmp, vtiny;
    vh = _mm_set1_ps(h0);
    vd = _mm_set1_ps(decay);
    vd1 = _mm_set1_ps(1.0f - decay);
    vtiny = _mm_set1_ps(1e-6f);
    for (i = 0; i < n4; i += 4) {
        vt = _mm_loadu_ps(&t[i]);
        vr = _mm_loadu_ps(&r[i]);
        vg = _mm_loadu_ps(&g[i]);
        if (h) vh = _mm_loadu_ps(&h[i]);
        vr = _mm_add_ps(_mm_mul_ps(vd1, _mm_mul_ps(vg, vg)), _mm_mul_ps(vd, vr));
        _mm_storeu_ps(&r[i], vr);
        tmp = _mm_sub_ps(vt, _mm_mul_ps(_mm_mul_ps(vh, _mm_rsqrt_ps(_mm_add_ps(vtiny, vr))), vg));
        _mm_storeu_ps(&t[i], tmp);
    }
    for (; i < n; ++i) {
        r[i] = (1. - decay) * g[i] * g[i] + decay * r[i];
        t[i] -= (h? h[i] : h0) / sqrtf(1e-6f + r[i]) * g[i];
    }
}
#else
void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
{
    int i;
    for (i = 0; i < n; ++i) {
        float lr = h? h[i] : h0;
        r[i] = (1.0f - decay) * g[i] * g[i] + decay * r[i];
        t[i] -= lr / sqrtf(1e-6f + r[i]) * g[i];
    }
}
#endif
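
/* Gradient clipping by L2 norm: if the norm of g exceeds thres, g is scaled
 * down by 1/norm. */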
float kann_grad_clip(float thres, int n, float *g)
{
    int i;
    double s2 = 0.0;
    for (i = 0; i < n; ++i)
        s2 += g[i] * g[i];
    s2 = sqrt(s2);
    if (s2 > thres)
        for (i = 0, s2 = 1.0 / s2; i < n; ++i)
            g[i] *= (float)s2;
    return (float)s2 / thres;
}

/****************************************************************
 *** @@XY: simpler API for network with a single input/output ***
 ****************************************************************/
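
/* Train a feed-forward network with one input and one output using RMSprop.
 * A fraction frac_val of the n samples is held out for validation; training
 * stops after max_drop_streak epochs without improvement in validation cost,
 * and the best parameters seen on the validation set are restored. */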
int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch,
                    int max_drop_streak, float frac_val, int n,
                    float **_x, float **_y,
                    kann_train_cb cb, void *ud)
{
    int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0;
    float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c;
    n_in = kann_dim_in(ann);
    n_out = kann_dim_out(ann);
    if (n_in < 0 || n_out < 0) return -1;
    n_var = kann_size_var(ann);
    n_const = kann_size_const(ann);
    r = (float*)calloc(n_var, sizeof(float));
    shuf = (int*)malloc(n * sizeof(int));
    x = (float**)malloc(n * sizeof(float*));
    y = (float**)malloc(n * sizeof(float*));
    kann_shuffle(n, shuf);
    for (j = 0; j < n; ++j)
        x[j] = _x[shuf[j]], y[j] = _y[shuf[j]];
    n_val = (int)(n * frac_val);
    n_train = n - n_val;
    min_x = (float*)malloc(n_var * sizeof(float));
    min_c = (float*)malloc(n_const * sizeof(float));
    x1 = (float*)malloc(n_in * mini_size * sizeof(float));
    y1 = (float*)malloc(n_out * mini_size * sizeof(float));
    kann_feed_bind(ann, KANN_F_IN, 0, &x1);
    kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);
    for (i = 0; i < max_epoch; ++i) {
        int n_proc = 0, n_train_err = 0, n_val_err = 0, n_train_base = 0, n_val_base = 0;
        double train_cost = 0.0, val_cost = 0.0;
        kann_shuffle(n_train, shuf);
        kann_switch(ann, 1);
        while (n_proc < n_train) {
            int b, c, ms = n_train - n_proc < mini_size? n_train - n_proc : mini_size;
            for (b = 0; b < ms; ++b) {
                memcpy(&x1[b*n_in], x[shuf[n_proc+b]], n_in * sizeof(float));
                memcpy(&y1[b*n_out], y[shuf[n_proc+b]], n_out * sizeof(float));
            }
            kann_set_batch_size(ann, ms);
            train_cost += kann_cost(ann, 0, 1) * ms;
            c = kann_class_error(ann, &b);
            n_train_err += c, n_train_base += b;
            kann_RMSprop(n_var, lr, 0, 0.9f, ann->g, ann->x, r);
            n_proc += ms;
        }
        train_cost /= n_train;
        kann_switch(ann, 0);
        n_proc = 0;
        while (n_proc < n_val) {
            int b, c, ms = n_val - n_proc < mini_size? n_val - n_proc : mini_size;
            for (b = 0; b < ms; ++b) {
                memcpy(&x1[b*n_in], x[n_train+n_proc+b], n_in * sizeof(float));
                memcpy(&y1[b*n_out], y[n_train+n_proc+b], n_out * sizeof(float));
            }
            kann_set_batch_size(ann, ms);
            val_cost += kann_cost(ann, 0, 0) * ms;
            c = kann_class_error(ann, &b);
            n_val_err += c, n_val_base += b;
            n_proc += ms;
        }
        if (n_val > 0) val_cost /= n_val;
        if (cb) {
            cb(i + 1, train_cost, val_cost, ud);
#if 0
            fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost);
            if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train);
            if (n_val > 0) {
                fprintf(stderr, "; validation cost: %g", val_cost);
                if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val);
            }
            fputc('\n', stderr);
#endif
        }
        if (i >= max_drop_streak && n_val > 0) {
            if (val_cost < min_val_cost) {
                min_set = 1;
                memcpy(min_x, ann->x, n_var * sizeof(float));
                memcpy(min_c, ann->c, n_const * sizeof(float));
                drop_streak = 0;
                min_val_cost = (float)val_cost;
            } else if (++drop_streak >= max_drop_streak)
                break;
        }
    }
    if (min_set) {
        memcpy(ann->x, min_x, n_var * sizeof(float));
        memcpy(ann->c, min_c, n_const * sizeof(float));
    }
    free(min_c); free(min_x); free(y1); free(x1); free(y); free(x); free(shuf); free(r);
    return i;
}

float kann_cost_fnn1(kann_t *ann, int n, float **x, float **y)
{
    int n_in, n_out, n_proc = 0, mini_size = 64 < n? 64 : n;
    float *x1, *y1;
    double cost = 0.0;
    n_in = kann_dim_in(ann);
    n_out = kann_dim_out(ann);
    if (n <= 0 || n_in < 0 || n_out < 0) return 0.0;
    x1 = (float*)malloc(n_in * mini_size * sizeof(float));
    y1 = (float*)malloc(n_out * mini_size * sizeof(float));
    kann_feed_bind(ann, KANN_F_IN, 0, &x1);
    kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);
    kann_switch(ann, 0);
    while (n_proc < n) {
        int b, ms = n - n_proc < mini_size? n - n_proc : mini_size;
        for (b = 0; b < ms; ++b) {
            memcpy(&x1[b*n_in], x[n_proc+b], n_in * sizeof(float));
            memcpy(&y1[b*n_out], y[n_proc+b], n_out * sizeof(float));
        }
        kann_set_batch_size(ann, ms);
        cost += kann_cost(ann, 0, 0) * ms;
        n_proc += ms;
    }
    free(y1); free(x1);
    return (float)(cost / n);
}
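
/* Run the network on a single input vector x and return a pointer to the
 * output node's internal value buffer. */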
const float *kann_apply1(kann_t *a, float *x)
{
    int i_out;
    i_out = kann_find(a, KANN_F_OUT, 0);
    if (i_out < 0) return 0;
    kann_set_batch_size(a, 1);
    kann_feed_bind(a, KANN_F_IN, 0, &x);
    kad_eval_at(a->n, a->v, i_out);
    return a->v[i_out]->x;
}
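
/* A minimal usage sketch (not part of the original kann.c), showing how the
 * layer and training APIs above fit together: build a small feed-forward
 * classifier, train it with kann_train_fnn1(), and free it. The data arrays
 * x/y and the sample count n are placeholders that a real program would fill
 * in; compile with -DKANN_EXAMPLE_MAIN to include this block. */
#ifdef KANN_EXAMPLE_MAIN
int main(void)
{
    kann_t *ann;
    kad_node_t *t;
    float **x = 0, **y = 0; /* placeholder: n rows of n_in inputs and n_out one-hot truths */
    int n = 0, n_in = 4, n_out = 2;
    t = kann_layer_input(n_in);                               /* input feed node */
    t = kad_tanh(kann_layer_dense(t, 16));                    /* one hidden dense layer */
    ann = kann_new(kann_layer_cost(t, n_out, KANN_C_CEM), 0); /* softmax + cross-entropy cost */
    kann_train_fnn1(ann, 0.001f, 64, 20, 10, 0.1f, n, x, y, 0, 0);
    kann_delete(ann);
    return 0;
}
#endif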