You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_kann.c 31KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336
  1. /*-
  2. * Copyright 2019 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "lua_tensor.h"
  18. #include "contrib/kann/kann.h"
  19. /***
  20. * @module rspamd_kann
  21. * `rspamd_kann` is a Lua interface to kann library
  22. */
  23. #define KANN_NODE_CLASS "rspamd{kann_node}"
  24. #define KANN_NETWORK_CLASS "rspamd{kann}"
  25. /* Simple macros to define behaviour */
  26. #define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
  27. #define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_ ## name}
  28. #define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_ ## name (lua_State *L)
  29. #define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_ ## name}
  30. #define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
  31. #define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
  32. #define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
  33. #define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
  34. /*
  35. * Forwarded declarations
  36. */
  37. static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
  38. /* Layers */
  39. KANN_LAYER_DEF(input);
  40. KANN_LAYER_DEF(dense);
  41. KANN_LAYER_DEF(layernorm);
  42. KANN_LAYER_DEF(rnn);
  43. KANN_LAYER_DEF(lstm);
  44. KANN_LAYER_DEF(gru);
  45. KANN_LAYER_DEF(conv2d);
  46. KANN_LAYER_DEF(conv1d);
  47. KANN_LAYER_DEF(cost);
  48. static luaL_reg rspamd_kann_layers_f[] = {
  49. KANN_LAYER_INTERFACE(input),
  50. KANN_LAYER_INTERFACE(dense),
  51. KANN_LAYER_INTERFACE(layernorm),
  52. KANN_LAYER_INTERFACE(rnn),
  53. KANN_LAYER_INTERFACE(lstm),
  54. KANN_LAYER_INTERFACE(gru),
  55. KANN_LAYER_INTERFACE(conv2d),
  56. KANN_LAYER_INTERFACE(conv1d),
  57. KANN_LAYER_INTERFACE(cost),
  58. {NULL, NULL},
  59. };
  60. /* Transition and composition functions */
  61. /* General transform */
  62. KANN_TRANSFORM_DEF (add);
  63. KANN_TRANSFORM_DEF (sub);
  64. KANN_TRANSFORM_DEF (mul);
  65. KANN_TRANSFORM_DEF (cmul);
  66. KANN_TRANSFORM_DEF (matmul);
  67. KANN_TRANSFORM_DEF (square);
  68. KANN_TRANSFORM_DEF (sigm);
  69. KANN_TRANSFORM_DEF (tanh);
  70. KANN_TRANSFORM_DEF (relu);
  71. KANN_TRANSFORM_DEF (softmax);
  72. KANN_TRANSFORM_DEF (1minus);
  73. KANN_TRANSFORM_DEF (exp);
  74. KANN_TRANSFORM_DEF (log);
  75. KANN_TRANSFORM_DEF (sin);
  76. static luaL_reg rspamd_kann_transform_f[] = {
  77. KANN_TRANSFORM_INTERFACE (add),
  78. KANN_TRANSFORM_INTERFACE (sub),
  79. KANN_TRANSFORM_INTERFACE (mul),
  80. KANN_TRANSFORM_INTERFACE (cmul),
  81. KANN_TRANSFORM_INTERFACE (matmul),
  82. KANN_TRANSFORM_INTERFACE (square),
  83. KANN_TRANSFORM_INTERFACE (sigm),
  84. KANN_TRANSFORM_INTERFACE (tanh),
  85. KANN_TRANSFORM_INTERFACE (relu),
  86. KANN_TRANSFORM_INTERFACE (softmax),
  87. KANN_TRANSFORM_INTERFACE (1minus),
  88. KANN_TRANSFORM_INTERFACE (exp),
  89. KANN_TRANSFORM_INTERFACE (log),
  90. KANN_TRANSFORM_INTERFACE (sin),
  91. {NULL, NULL},
  92. };
  93. /* Loss functions */
  94. KANN_LOSS_DEF (mse);
  95. KANN_LOSS_DEF (ce_multi);
  96. KANN_LOSS_DEF (ce_bin);
  97. KANN_LOSS_DEF (ce_bin_neg);
  98. KANN_LOSS_DEF (ce_multi_weighted);
  99. static luaL_reg rspamd_kann_loss_f[] = {
  100. KANN_LOSS_INTERFACE (mse),
  101. KANN_LOSS_INTERFACE (ce_multi),
  102. KANN_LOSS_INTERFACE (ce_bin),
  103. KANN_LOSS_INTERFACE (ce_bin_neg),
  104. KANN_LOSS_INTERFACE (ce_multi_weighted),
  105. {NULL, NULL},
  106. };
  107. /* Creation functions */
  108. KANN_NEW_DEF (leaf);
  109. KANN_NEW_DEF (scalar);
  110. KANN_NEW_DEF (weight);
  111. KANN_NEW_DEF (bias);
  112. KANN_NEW_DEF (weight_conv2d);
  113. KANN_NEW_DEF (weight_conv1d);
  114. KANN_NEW_DEF (kann);
  115. static luaL_reg rspamd_kann_new_f[] = {
  116. KANN_NEW_INTERFACE (leaf),
  117. KANN_NEW_INTERFACE (scalar),
  118. KANN_NEW_INTERFACE (weight),
  119. KANN_NEW_INTERFACE (bias),
  120. KANN_NEW_INTERFACE (weight_conv2d),
  121. KANN_NEW_INTERFACE (weight_conv1d),
  122. KANN_NEW_INTERFACE (kann),
  123. {NULL, NULL},
  124. };
  125. LUA_FUNCTION_DEF (kann, load);
  126. LUA_FUNCTION_DEF (kann, destroy);
  127. LUA_FUNCTION_DEF (kann, save);
  128. LUA_FUNCTION_DEF (kann, train1);
  129. LUA_FUNCTION_DEF (kann, apply1);
  130. static luaL_reg rspamd_kann_m[] = {
  131. LUA_INTERFACE_DEF (kann, save),
  132. LUA_INTERFACE_DEF (kann, train1),
  133. LUA_INTERFACE_DEF (kann, apply1),
  134. {"__gc", lua_kann_destroy},
  135. {NULL, NULL},
  136. };
  137. static int
  138. rspamd_kann_table_to_flags (lua_State *L, int table_pos)
  139. {
  140. int result = 0;
  141. lua_pushvalue (L, table_pos);
  142. for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
  143. int fl = lua_tointeger (L, -1);
  144. result |= fl;
  145. }
  146. lua_pop (L, 1);
  147. return result;
  148. }
  149. static gint
  150. lua_load_kann (lua_State * L)
  151. {
  152. lua_newtable (L);
  153. /* Flags */
  154. lua_pushstring (L, "flag");
  155. lua_newtable (L);
  156. lua_pushinteger (L, KANN_F_IN);
  157. lua_setfield (L, -2, "in");
  158. lua_pushinteger (L, KANN_F_COST);
  159. lua_setfield (L, -2, "cost");
  160. lua_pushinteger (L, KANN_F_OUT);
  161. lua_setfield (L, -2, "out");
  162. lua_pushinteger (L, KANN_F_TRUTH);
  163. lua_setfield (L, -2, "truth");
  164. lua_settable (L, -3);
  165. /* Cost type */
  166. lua_pushstring (L, "cost");
  167. lua_newtable (L);
  168. /* binary cross-entropy cost, used with sigmoid */
  169. lua_pushinteger (L, KANN_C_CEB);
  170. lua_setfield (L, -2, "ceb");
  171. /* multi-class cross-entropy cost, used with softmax */
  172. lua_pushinteger (L, KANN_C_CEM);
  173. lua_setfield (L, -2, "cem");
  174. /* binary cross-entropy-like cost, used with tanh */
  175. lua_pushinteger (L, KANN_C_CEB_NEG);
  176. lua_setfield (L, -2, "ceb_neg");
  177. lua_pushinteger (L, KANN_C_MSE);
  178. lua_setfield (L, -2, "mse");
  179. lua_settable (L, -3);
  180. /* RNN flag */
  181. lua_pushstring (L, "rnn");
  182. lua_newtable (L);
  183. /* apply layer normalization */
  184. lua_pushinteger (L, KANN_RNN_NORM);
  185. lua_setfield (L, -2, "norm");
  186. /* take the initial hidden values as variables */
  187. lua_pushinteger (L, KANN_RNN_VAR_H0);
  188. lua_setfield (L, -2, "var_h0");
  189. lua_settable (L, -3);
  190. /* Layers */
  191. lua_pushstring (L, "layer");
  192. lua_newtable (L);
  193. luaL_register (L, NULL, rspamd_kann_layers_f);
  194. lua_settable (L, -3);
  195. /* Transforms */
  196. lua_pushstring (L, "transform");
  197. lua_newtable (L);
  198. luaL_register (L, NULL, rspamd_kann_transform_f);
  199. lua_settable (L, -3);
  200. /* Cost */
  201. lua_pushstring (L, "loss");
  202. lua_newtable (L);
  203. luaL_register (L, NULL, rspamd_kann_loss_f);
  204. lua_settable (L, -3);
  205. /* Create functions */
  206. lua_pushstring (L, "new");
  207. lua_newtable (L);
  208. luaL_register (L, NULL, rspamd_kann_new_f);
  209. lua_settable (L, -3);
  210. /* Load ann from memory or file */
  211. lua_pushstring (L, "load");
  212. lua_pushcfunction (L, lua_kann_load);
  213. lua_settable (L, -3);
  214. return 1;
  215. }
  216. static kad_node_t *
  217. lua_check_kann_node (lua_State *L, int pos)
  218. {
  219. void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
  220. luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
  221. return ud ? *((kad_node_t **)ud) : NULL;
  222. }
  223. static kann_t *
  224. lua_check_kann (lua_State *L, int pos)
  225. {
  226. void *ud = rspamd_lua_check_udata (L, pos, KANN_NETWORK_CLASS);
  227. luaL_argcheck (L, ud != NULL, pos, "'kann' expected");
  228. return ud ? *((kann_t **)ud) : NULL;
  229. }
  230. void luaopen_kann (lua_State *L)
  231. {
  232. /* Metatables */
  233. rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
  234. lua_pop (L, 1); /* No need in metatable... */
  235. rspamd_lua_new_class (L, KANN_NETWORK_CLASS, rspamd_kann_m);
  236. lua_pop (L, 1); /* No need in metatable... */
  237. rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
  238. lua_settop (L, 0);
  239. }
  240. /* Layers implementation */
  241. #define PUSH_KAD_NODE(n) do { \
  242. kad_node_t **pt; \
  243. pt = lua_newuserdata (L, sizeof (kad_node_t *)); \
  244. *pt = (n); \
  245. rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
  246. } while(0)
  247. #define PUSH_KAN_NETWORK(n) do { \
  248. kann_t **pn; \
  249. pn = lua_newuserdata (L, sizeof (kann_t *)); \
  250. *pn = (n); \
  251. rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
  252. } while(0)
  253. #define PROCESS_KAD_FLAGS(n, pos) do { \
  254. int fl = 0; \
  255. if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
  256. else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
  257. (n)->ext_flag |= fl; \
  258. }while(0)
  259. /***
  260. * @function kann.layer.input(ninputs[, flags])
  261. * Creates an input layer for ANN
  262. * @param {int} ninputs number of inputs
  263. * @param {table|int} flags optional flags
  264. * @return {kann_node} kann node object (should be used to combine ANN)
  265. */
  266. static int
  267. lua_kann_layer_input (lua_State *L)
  268. {
  269. gint nnodes = luaL_checkinteger (L, 1);
  270. if (nnodes > 0) {
  271. kad_node_t *t;
  272. t = kann_layer_input (nnodes);
  273. PROCESS_KAD_FLAGS (t, 2);
  274. PUSH_KAD_NODE (t);
  275. }
  276. else {
  277. return luaL_error (L, "invalid arguments, nnodes required");
  278. }
  279. return 1;
  280. }
  281. /***
  282. * @function kann.layer.dense(in, ninputs[, flags])
  283. * Creates a dense layer (e.g. for hidden layer)
  284. * @param {kann_node} in kann node
  285. * @param {int} ninputs number of dense nodes
  286. * @param {table|int} flags optional flags
  287. * @return {kann_node} kann node object (should be used to combine ANN)
  288. */
  289. static int
  290. lua_kann_layer_dense (lua_State *L)
  291. {
  292. kad_node_t *in = lua_check_kann_node (L, 1);
  293. gint nnodes = luaL_checkinteger (L, 2);
  294. if (in != NULL && nnodes > 0) {
  295. kad_node_t *t;
  296. t = kann_layer_dense (in, nnodes);
  297. PROCESS_KAD_FLAGS (t, 3);
  298. PUSH_KAD_NODE (t);
  299. }
  300. else {
  301. return luaL_error (L, "invalid arguments, input + nnodes required");
  302. }
  303. return 1;
  304. }
  305. /***
  306. * @function kann.layer.dropout(in, ratio[, flags])
  307. * Creates a dropout layer
  308. * @param {kann_node} in kann node
  309. * @param {float} ratio drop ratio
  310. * @param {table|int} flags optional flags
  311. * @return {kann_node} kann node object (should be used to combine ANN)
  312. */
  313. static int
  314. lua_kann_layer_layerdropout (lua_State *L)
  315. {
  316. kad_node_t *in = lua_check_kann_node (L, 1);
  317. double r = luaL_checknumber (L, 2);
  318. if (in != NULL) {
  319. kad_node_t *t;
  320. t = kann_layer_dropout (in, r);
  321. PROCESS_KAD_FLAGS (t, 3);
  322. PUSH_KAD_NODE (t);
  323. }
  324. else {
  325. return luaL_error (L, "invalid arguments, input + rate required");
  326. }
  327. return 1;
  328. }
  329. /***
  330. * @function kann.layer.dropout(in [, flags])
  331. * Creates a normalisation layer
  332. * @param {kann_node} in kann node
  333. * @param {table|int} flags optional flags
  334. * @return {kann_node} kann node object (should be used to combine ANN)
  335. */
  336. static int
  337. lua_kann_layer_layernorm (lua_State *L)
  338. {
  339. kad_node_t *in = lua_check_kann_node (L, 1);
  340. if (in != NULL) {
  341. kad_node_t *t;
  342. t = kann_layer_layernorm (in);
  343. PROCESS_KAD_FLAGS (t, 2);
  344. PUSH_KAD_NODE (t);
  345. }
  346. else {
  347. return luaL_error (L, "invalid arguments, input required");
  348. }
  349. return 1;
  350. }
  351. /***
  352. * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
  353. * Creates a recursive NN layer
  354. * @param {kann_node} in kann node
  355. * @param {int} nnodes number of cells
  356. * @param {int} rnnflags rnn flags
  357. * @param {table|int} flags optional flags
  358. * @return {kann_node} kann node object (should be used to combine ANN)
  359. */
  360. static int
  361. lua_kann_layer_rnn (lua_State *L)
  362. {
  363. kad_node_t *in = lua_check_kann_node (L, 1);
  364. gint nnodes = luaL_checkinteger (L, 2);
  365. gint rnnflags = 0;
  366. if (in != NULL && nnodes > 0) {
  367. kad_node_t *t;
  368. if (lua_type (L, 3) == LUA_TNUMBER) {
  369. rnnflags = lua_tointeger (L, 3);
  370. }
  371. t = kann_layer_rnn (in, nnodes, rnnflags);
  372. PROCESS_KAD_FLAGS (t, 4);
  373. PUSH_KAD_NODE (t);
  374. }
  375. else {
  376. return luaL_error (L, "invalid arguments, input + nnodes required");
  377. }
  378. return 1;
  379. }
  380. /***
  381. * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]])
  382. * Creates a recursive NN layer using LSTM cells
  383. * @param {kann_node} in kann node
  384. * @param {int} nnodes number of cells
  385. * @param {int} rnnflags rnn flags
  386. * @param {table|int} flags optional flags
  387. * @return {kann_node} kann node object (should be used to combine ANN)
  388. */
  389. static int
  390. lua_kann_layer_lstm (lua_State *L)
  391. {
  392. kad_node_t *in = lua_check_kann_node (L, 1);
  393. gint nnodes = luaL_checkinteger (L, 2);
  394. gint rnnflags = 0;
  395. if (in != NULL && nnodes > 0) {
  396. kad_node_t *t;
  397. if (lua_type (L, 3) == LUA_TNUMBER) {
  398. rnnflags = lua_tointeger (L, 3);
  399. }
  400. t = kann_layer_lstm (in, nnodes, rnnflags);
  401. PROCESS_KAD_FLAGS (t, 4);
  402. PUSH_KAD_NODE (t);
  403. }
  404. else {
  405. return luaL_error (L, "invalid arguments, input + nnodes required");
  406. }
  407. return 1;
  408. }
  409. /***
  410. * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
  411. * Creates a recursive NN layer using GRU cells
  412. * @param {kann_node} in kann node
  413. * @param {int} nnodes number of cells
  414. * @param {int} rnnflags rnn flags
  415. * @param {table|int} flags optional flags
  416. * @return {kann_node} kann node object (should be used to combine ANN)
  417. */
  418. static int
  419. lua_kann_layer_gru (lua_State *L)
  420. {
  421. kad_node_t *in = lua_check_kann_node (L, 1);
  422. gint nnodes = luaL_checkinteger (L, 2);
  423. gint rnnflags = 0;
  424. if (in != NULL && nnodes > 0) {
  425. kad_node_t *t;
  426. if (lua_type (L, 3) == LUA_TNUMBER) {
  427. rnnflags = lua_tointeger (L, 3);
  428. }
  429. t = kann_layer_gru (in, nnodes, rnnflags);
  430. PROCESS_KAD_FLAGS (t, 4);
  431. PUSH_KAD_NODE (t);
  432. }
  433. else {
  434. return luaL_error (L, "invalid arguments, input + nnodes required");
  435. }
  436. return 1;
  437. }
  438. /***
  439. * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
  440. * Creates a 2D convolution layer
  441. * @param {kann_node} in kann node
  442. * @param {int} n_flt number of filters
  443. * @param {int} k_rows kernel rows
  444. * @param {int} k_cols kernel columns
  445. * @param {int} stride_rows stride rows
  446. * @param {int} stride_cols stride columns
  447. * @param {int} pad_rows padding rows
  448. * @param {int} pad_columns padding columns
  449. * @param {table|int} flags optional flags
  450. * @return {kann_node} kann node object (should be used to combine ANN)
  451. */
  452. static int
  453. lua_kann_layer_conv2d (lua_State *L)
  454. {
  455. kad_node_t *in = lua_check_kann_node (L, 1);
  456. int n_flt = luaL_checkinteger (L, 2);
  457. int k_rows = luaL_checkinteger (L, 3);
  458. int k_cols = luaL_checkinteger (L, 4);
  459. int stride_r = luaL_checkinteger (L, 5);
  460. int stride_c = luaL_checkinteger (L, 6);
  461. int pad_r = luaL_checkinteger (L, 7);
  462. int pad_c = luaL_checkinteger (L, 8);
  463. if (in != NULL) {
  464. kad_node_t *t;
  465. t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
  466. pad_r, pad_c);
  467. PROCESS_KAD_FLAGS (t, 9);
  468. PUSH_KAD_NODE (t);
  469. }
  470. else {
  471. return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
  472. }
  473. return 1;
  474. }
  475. /***
  476. * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
  477. * Creates 1D convolution layer
  478. * @param {kann_node} in kann node
  479. * @param {int} n_flt number of filters
  480. * @param {int} kern_size kernel rows
  481. * @param {int} stride_size stride rows
  482. * @param {int} pad_size padding rows
  483. * @param {table|int} flags optional flags
  484. * @return {kann_node} kann node object (should be used to combine ANN)
  485. */
  486. static int
  487. lua_kann_layer_conv1d (lua_State *L)
  488. {
  489. kad_node_t *in = lua_check_kann_node (L, 1);
  490. int n_flt = luaL_checkinteger (L, 2);
  491. int k_size = luaL_checkinteger (L, 3);
  492. int stride = luaL_checkinteger (L, 4);
  493. int pad = luaL_checkinteger (L, 5);
  494. if (in != NULL) {
  495. kad_node_t *t;
  496. t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
  497. PROCESS_KAD_FLAGS (t, 6);
  498. PUSH_KAD_NODE (t);
  499. }
  500. else {
  501. return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
  502. }
  503. return 1;
  504. }
  505. /***
  506. * @function kann.layer.cost(in, nout, cost_type[, flags])
  507. * Creates 1D convolution layer
  508. * @param {kann_node} in kann node
  509. * @param {int} nout number of outputs
  510. * @param {int} cost_type see kann.cost table
  511. * @param {table|int} flags optional flags
  512. * @return {kann_node} kann node object (should be used to combine ANN)
  513. */
  514. static int
  515. lua_kann_layer_cost (lua_State *L)
  516. {
  517. kad_node_t *in = lua_check_kann_node (L, 1);
  518. int nout = luaL_checkinteger (L, 2);
  519. int cost_type = luaL_checkinteger (L, 3);
  520. if (in != NULL && nout > 0) {
  521. kad_node_t *t;
  522. t = kann_layer_cost (in, nout, cost_type);
  523. PROCESS_KAD_FLAGS (t, 4);
  524. PUSH_KAD_NODE (t);
  525. }
  526. else {
  527. return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
  528. }
  529. return 1;
  530. }
  531. /* Generic helpers */
  532. static int
  533. lua_kann_call_unary_function (lua_State *L, const char *name,
  534. kad_node_t *(*func)(kad_node_t *))
  535. {
  536. kad_node_t *in = lua_check_kann_node (L, 1);
  537. if (in != NULL) {
  538. kad_node_t *t;
  539. t = func (in);
  540. PUSH_KAD_NODE (t);
  541. }
  542. else {
  543. return luaL_error (L, "invalid arguments for %s, input required", name);
  544. }
  545. return 1;
  546. }
  547. static int
  548. lua_kann_call_binary_function (lua_State *L, const char *name,
  549. kad_node_t *(*func)(kad_node_t *, kad_node_t *))
  550. {
  551. kad_node_t *x = lua_check_kann_node (L, 1);
  552. kad_node_t *y = lua_check_kann_node (L, 2);
  553. if (x != NULL && y != NULL) {
  554. kad_node_t *t;
  555. t = func (x, y);
  556. PUSH_KAD_NODE (t);
  557. }
  558. else {
  559. return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
  560. }
  561. return 1;
  562. }
  563. #define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \
  564. static int lua_kann_transform_ ##name (lua_State *L) \
  565. { \
  566. return lua_kann_call_unary_function(L, #name, kad_##name); \
  567. }
  568. #define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \
  569. static int lua_kann_transform_ ##name (lua_State *L) \
  570. { \
  571. return lua_kann_call_binary_function(L, #name, kad_##name); \
  572. }
  573. #define LUA_LOSS_FUNC_IMPL(name) \
  574. static int lua_kann_loss_ ##name (lua_State *L) \
  575. { \
  576. return lua_kann_call_binary_function(L, #name, kad_##name); \
  577. }
  578. /* Transform functions registered via macro helpers */
  579. LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
  580. LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
  581. LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
  582. LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
  583. LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
  584. LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
  585. LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
  586. LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
  587. LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
  588. LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
  589. LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
  590. LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
  591. LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
  592. LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
  593. /* Generic cost functions */
  594. LUA_LOSS_FUNC_IMPL (mse)
  595. LUA_LOSS_FUNC_IMPL (ce_multi)
  596. LUA_LOSS_FUNC_IMPL (ce_bin)
  597. LUA_LOSS_FUNC_IMPL (ce_bin_neg)
  598. /* The only case of ternary weight function */
  599. static int
  600. lua_kann_loss_ce_multi_weighted (lua_State *L)
  601. {
  602. kad_node_t *pred = lua_check_kann_node (L, 1);
  603. kad_node_t *truth = lua_check_kann_node (L, 2);
  604. kad_node_t *weight = lua_check_kann_node (L, 3);
  605. if (pred != NULL && truth != NULL && weight != NULL) {
  606. kad_node_t *t;
  607. t = kad_ce_multi_weighted (pred, truth, weight);
  608. PUSH_KAD_NODE (t);
  609. }
  610. else {
  611. return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
  612. }
  613. return 1;
  614. }
  615. /* Creation functions */
  616. static int
  617. lua_kann_new_scalar (lua_State *L)
  618. {
  619. gint flag = luaL_checkinteger (L, 1);
  620. double x = luaL_checknumber (L, 2);
  621. kad_node_t *t;
  622. t = kann_new_scalar (flag, x);
  623. PROCESS_KAD_FLAGS (t, 3);
  624. PUSH_KAD_NODE (t);
  625. return 1;
  626. }
  627. static int
  628. lua_kann_new_weight (lua_State *L)
  629. {
  630. gint nrow = luaL_checkinteger (L, 1);
  631. gint ncol = luaL_checkinteger (L, 2);
  632. kad_node_t *t;
  633. t = kann_new_weight (nrow, ncol);
  634. PROCESS_KAD_FLAGS (t, 3);
  635. PUSH_KAD_NODE (t);
  636. return 1;
  637. }
  638. static int
  639. lua_kann_new_bias (lua_State *L)
  640. {
  641. gint n = luaL_checkinteger (L, 1);
  642. kad_node_t *t;
  643. t = kann_new_bias (n);
  644. PROCESS_KAD_FLAGS (t, 2);
  645. PUSH_KAD_NODE (t);
  646. return 1;
  647. }
  648. static int
  649. lua_kann_new_weight_conv2d (lua_State *L)
  650. {
  651. gint nout = luaL_checkinteger (L, 1);
  652. gint nin = luaL_checkinteger (L, 2);
  653. gint krow = luaL_checkinteger (L, 3);
  654. gint kcol = luaL_checkinteger (L, 4);
  655. kad_node_t *t;
  656. t = kann_new_weight_conv2d (nout, nin, krow, kcol);
  657. PROCESS_KAD_FLAGS (t, 5);
  658. PUSH_KAD_NODE (t);
  659. return 1;
  660. }
  661. static int
  662. lua_kann_new_weight_conv1d (lua_State *L)
  663. {
  664. gint nout = luaL_checkinteger (L, 1);
  665. gint nin = luaL_checkinteger (L, 2);
  666. gint klen = luaL_checkinteger (L, 3);
  667. kad_node_t *t;
  668. t = kann_new_weight_conv1d (nout, nin, klen);
  669. PROCESS_KAD_FLAGS (t, 4);
  670. PUSH_KAD_NODE (t);
  671. return 1;
  672. }
  673. static int
  674. lua_kann_new_leaf (lua_State *L)
  675. {
  676. gint dim = luaL_checkinteger (L, 1), i, *ar;
  677. kad_node_t *t;
  678. if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
  679. ar = g_malloc0 (sizeof (ar) * dim);
  680. for (i = 0; i < dim; i ++) {
  681. lua_rawgeti (L, 2, i + 1);
  682. ar[i] = lua_tointeger (L, -1);
  683. lua_pop (L, 1);
  684. }
  685. t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
  686. PROCESS_KAD_FLAGS (t, 3);
  687. PUSH_KAD_NODE (t);
  688. g_free (ar);
  689. }
  690. else {
  691. return luaL_error (L, "invalid arguments for new.leaf, "
  692. "dim and vector of elements are required");
  693. }
  694. return 1;
  695. }
  696. static int
  697. lua_kann_new_kann (lua_State *L)
  698. {
  699. kad_node_t *cost = lua_check_kann_node (L, 1);
  700. kann_t *k;
  701. if (cost) {
  702. k = kann_new (cost, 0);
  703. PUSH_KAN_NETWORK (k);
  704. }
  705. else {
  706. return luaL_error (L, "invalid arguments for new.kann, "
  707. "cost node is required");
  708. }
  709. return 1;
  710. }
  711. static int
  712. lua_kann_destroy (lua_State *L)
  713. {
  714. kann_t *k = lua_check_kann (L, 1);
  715. kann_delete (k);
  716. return 0;
  717. }
  718. static int
  719. lua_kann_save (lua_State *L)
  720. {
  721. kann_t *k = lua_check_kann (L, 1);
  722. if (k) {
  723. if (lua_istable (L, 2)) {
  724. lua_getfield (L, 2, "filename");
  725. if (lua_isstring (L, -1)) {
  726. const gchar *fname = lua_tostring (L, -1);
  727. FILE *f;
  728. f = fopen (fname, "w");
  729. if (!f) {
  730. lua_pop (L, 1);
  731. return luaL_error (L, "cannot open %s for writing: %s",
  732. fname, strerror (errno));
  733. }
  734. kann_save_fp (f, k);
  735. fclose (f);
  736. lua_pushboolean (L, true);
  737. }
  738. else {
  739. lua_pop (L, 1);
  740. return luaL_error (L, "invalid arguments: missing filename");
  741. }
  742. lua_pop (L, 1);
  743. }
  744. else {
  745. /* Save to Rspamd text */
  746. #ifndef HAVE_OPENMEMSTREAM
  747. return luaL_error (L, "no support of saving to memory on your system");
  748. #endif
  749. FILE *f;
  750. char *buf = NULL;
  751. size_t buflen;
  752. struct rspamd_lua_text *t;
  753. f = open_memstream (&buf, &buflen);
  754. g_assert (f != NULL);
  755. kann_save_fp (f, k);
  756. fclose (f);
  757. t = lua_newuserdata (L, sizeof (*t));
  758. rspamd_lua_setclass (L, "rspamd{text}", -1);
  759. t->flags = RSPAMD_TEXT_FLAG_OWN;
  760. t->start = (const gchar *)buf;
  761. t->len = buflen;
  762. }
  763. }
  764. else {
  765. return luaL_error (L, "invalid arguments");
  766. }
  767. return 1;
  768. }
  769. static int
  770. lua_kann_load (lua_State *L)
  771. {
  772. kann_t *k;
  773. FILE *f = NULL;
  774. if (lua_istable (L, 1)) {
  775. lua_getfield (L, 2, "filename");
  776. if (lua_isstring (L, -1)) {
  777. const gchar *fname = lua_tostring (L, -1);
  778. f = fopen (fname, "rb");
  779. }
  780. else {
  781. lua_pop (L, 1);
  782. return luaL_error (L, "invalid arguments: missing filename");
  783. }
  784. lua_pop (L, 1);
  785. }
  786. else if (lua_isstring (L, 1)) {
  787. gsize dlen;
  788. const gchar *data;
  789. data = lua_tolstring (L, 1, &dlen);
  790. #ifndef HAVE_FMEMOPEN
  791. return luaL_error (L, "no support of loading from memory on your system");
  792. #endif
  793. f = fmemopen ((void *)data, dlen, "rb");
  794. }
  795. else if (lua_isuserdata (L, 1)) {
  796. struct rspamd_lua_text *t;
  797. t = lua_check_text (L, 1);
  798. #ifndef HAVE_FMEMOPEN
  799. return luaL_error (L, "no support of loading from memory on your system");
  800. #endif
  801. f = fmemopen ((void *)t->start, t->len, "rb");
  802. }
  803. if (f == NULL) {
  804. return luaL_error (L, "invalid arguments or cannot open file");
  805. }
  806. k = kann_load_fp (f);
  807. fclose (f);
  808. if (k == NULL) {
  809. lua_pushnil (L);
  810. }
  811. else {
  812. PUSH_KAN_NETWORK (k);
  813. }
  814. return 1;
  815. }
  816. struct rspamd_kann_train_cbdata {
  817. lua_State *L;
  818. kann_t *k;
  819. gint cbref;
  820. };
  821. static void
  822. lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
  823. {
  824. struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
  825. if (cbd->cbref != -1) {
  826. gint err_idx;
  827. lua_State *L = cbd->L;
  828. lua_pushcfunction (L, &rspamd_lua_traceback);
  829. err_idx = lua_gettop (L);
  830. lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
  831. lua_pushinteger (L, iter);
  832. lua_pushnumber (L, train_cost);
  833. lua_pushnumber (L, val_cost);
  834. if (lua_pcall (L, 3, 0, err_idx) != 0) {
  835. msg_err ("cannot run lua train callback: %s",
  836. lua_tostring (L, -1));
  837. }
  838. lua_settop (L, err_idx - 1);
  839. }
  840. }
  841. #define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0)
  842. static int
  843. lua_kann_train1 (lua_State *L)
  844. {
  845. kann_t *k = lua_check_kann (L, 1);
  846. struct rspamd_lua_tensor *pca = NULL;
  847. /* Default train params */
  848. double lr = 0.001;
  849. gint64 mini_size = 64;
  850. gint64 max_epoch = 25;
  851. gint64 max_drop_streak = 10;
  852. double frac_val = 0.1;
  853. gint cbref = -1;
  854. if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
  855. int n = rspamd_lua_table_size (L, 2);
  856. int n_in = kann_dim_in (k);
  857. int n_out = kann_dim_out (k);
  858. if (n_in <= 0) {
  859. return luaL_error (L, "invalid inputs count: %d", n_in);
  860. }
  861. if (n_out <= 0) {
  862. return luaL_error (L, "invalid outputs count: %d", n_in);
  863. }
  864. if (n != rspamd_lua_table_size (L, 3) || n == 0) {
  865. return luaL_error (L, "invalid dimensions: outputs size must be "
  866. "equal to inputs and non zero");
  867. }
  868. if (lua_istable (L, 4)) {
  869. GError *err = NULL;
  870. if (!rspamd_lua_parse_table_arguments (L, 4, &err,
  871. RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
  872. "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
  873. &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
  874. n = luaL_error (L, "invalid params: %s",
  875. err ? err->message : "unknown error");
  876. g_error_free (err);
  877. return n;
  878. }
  879. }
  880. if (pca) {
  881. /* Check pca matrix validity */
  882. if (pca->ndims != 2) {
  883. return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
  884. }
  885. if (pca->dim[0] != n_in) {
  886. return luaL_error (L, "invalid pca tensor: "
  887. "matrix must have %d rows and it has %d rows instead",
  888. n_in, pca->dim[0]);
  889. }
  890. }
  891. float **x, **y, *tmp_row = NULL;
  892. /* Fill vectors row by row */
  893. x = (float **)g_malloc0 (sizeof (float *) * n);
  894. y = (float **)g_malloc0 (sizeof (float *) * n);
  895. if (pca) {
  896. tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
  897. }
  898. for (int s = 0; s < n; s ++) {
  899. /* Inputs */
  900. lua_rawgeti (L, 2, s + 1);
  901. x[s] = (float *)g_malloc (sizeof (float) * n_in);
  902. if (pca == NULL) {
  903. if (rspamd_lua_table_size (L, -1) != n_in) {
  904. FREE_VEC (x, n);
  905. FREE_VEC (y, n);
  906. n = luaL_error (L, "invalid params at pos %d: "
  907. "bad input dimension %d; %d expected",
  908. s + 1,
  909. (int) rspamd_lua_table_size (L, -1),
  910. n_in);
  911. lua_pop (L, 1);
  912. return n;
  913. }
  914. for (int i = 0; i < n_in; i++) {
  915. lua_rawgeti (L, -1, i + 1);
  916. x[s][i] = lua_tonumber (L, -1);
  917. lua_pop (L, 1);
  918. }
  919. }
  920. else {
  921. if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
  922. FREE_VEC (x, n);
  923. FREE_VEC (y, n);
  924. g_free (tmp_row);
  925. n = luaL_error (L, "(pca on) invalid params at pos %d: "
  926. "bad input dimension %d; %d expected",
  927. s + 1,
  928. (int) rspamd_lua_table_size (L, -1),
  929. pca->dim[1]);
  930. lua_pop (L, 1);
  931. return n;
  932. }
  933. for (int i = 0; i < pca->dim[1]; i++) {
  934. lua_rawgeti (L, -1, i + 1);
  935. tmp_row[i] = lua_tonumber (L, -1);
  936. lua_pop (L, 1);
  937. }
  938. kad_sgemm_simple (0, 1, 1, n_in,
  939. pca->dim[1], tmp_row, pca->data,
  940. x[s]);
  941. }
  942. lua_pop (L, 1);
  943. /* Outputs */
  944. y[s] = (float *)g_malloc (sizeof (float) * n_out);
  945. lua_rawgeti (L, 3, s + 1);
  946. if (rspamd_lua_table_size (L, -1) != n_out) {
  947. FREE_VEC (x, n);
  948. FREE_VEC (y, n);
  949. g_free (tmp_row);
  950. n = luaL_error (L, "invalid params at pos %d: "
  951. "bad output dimension %d; "
  952. "%d expected",
  953. s + 1,
  954. (int)rspamd_lua_table_size (L, -1),
  955. n_out);
  956. lua_pop (L, 1);
  957. return n;
  958. }
  959. for (int i = 0; i < n_out; i ++) {
  960. lua_rawgeti (L, -1, i + 1);
  961. y[s][i] = lua_tonumber (L, -1);
  962. lua_pop (L, 1);
  963. }
  964. lua_pop (L, 1);
  965. }
  966. struct rspamd_kann_train_cbdata cbd;
  967. cbd.cbref = cbref;
  968. cbd.k = k;
  969. cbd.L = L;
  970. int niters = kann_train_fnn1 (k, lr,
  971. mini_size, max_epoch, max_drop_streak,
  972. frac_val, n, x, y, lua_kann_train_cb, &cbd);
  973. lua_pushinteger (L, niters);
  974. FREE_VEC (x, n);
  975. FREE_VEC (y, n);
  976. g_free (tmp_row);
  977. }
  978. else {
  979. return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
  980. " optional params are expected");
  981. }
  982. return 1;
  983. }
  984. static int
  985. lua_kann_apply1 (lua_State *L)
  986. {
  987. kann_t *k = lua_check_kann (L, 1);
  988. struct rspamd_lua_tensor *pca = NULL;
  989. if (k) {
  990. if (lua_istable (L, 2)) {
  991. gsize vec_len = rspamd_lua_table_size (L, 2);
  992. float *vec = (float *) g_malloc (sizeof (float) * vec_len),
  993. *pca_out = NULL;
  994. int i_out;
  995. int n_in = kann_dim_in (k);
  996. if (n_in <= 0) {
  997. g_free (vec);
  998. return luaL_error (L, "invalid inputs count: %d", n_in);
  999. }
  1000. if (lua_isuserdata (L, 3)) {
  1001. pca = lua_check_tensor (L, 3);
  1002. if (pca) {
  1003. if (pca->ndims != 2) {
  1004. g_free (vec);
  1005. return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
  1006. }
  1007. if (pca->dim[0] != n_in) {
  1008. g_free (vec);
  1009. return luaL_error (L, "invalid pca tensor: "
  1010. "matrix must have %d rows and it has %d rows instead",
  1011. n_in, pca->dim[0]);
  1012. }
  1013. }
  1014. else {
  1015. g_free (vec);
  1016. return luaL_error (L, "invalid params: pca matrix expected");
  1017. }
  1018. }
  1019. else {
  1020. if (n_in != vec_len) {
  1021. g_free (vec);
  1022. return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
  1023. (int) vec_len, n_in);
  1024. }
  1025. }
  1026. for (gsize i = 0; i < vec_len; i++) {
  1027. lua_rawgeti (L, 2, i + 1);
  1028. vec[i] = lua_tonumber (L, -1);
  1029. lua_pop (L, 1);
  1030. }
  1031. i_out = kann_find (k, KANN_F_OUT, 0);
  1032. if (i_out <= 0) {
  1033. g_free (vec);
  1034. return luaL_error (L, "invalid ANN: output layer is missing or is "
  1035. "at the input pos");
  1036. }
  1037. kann_set_batch_size (k, 1);
  1038. if (pca) {
  1039. pca_out = g_malloc (sizeof (float) * n_in);
  1040. kad_sgemm_simple (0, 1, 1, n_in,
  1041. vec_len, vec, pca->data,
  1042. pca_out);
  1043. kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
  1044. }
  1045. else {
  1046. kann_feed_bind (k, KANN_F_IN, 0, &vec);
  1047. }
  1048. kad_eval_at (k->n, k->v, i_out);
  1049. gsize outlen = kad_len (k->v[i_out]);
  1050. lua_createtable (L, outlen, 0);
  1051. for (gsize i = 0; i < outlen; i++) {
  1052. lua_pushnumber (L, k->v[i_out]->x[i]);
  1053. lua_rawseti (L, -2, i + 1);
  1054. }
  1055. g_free (vec);
  1056. g_free (pca_out);
  1057. }
  1058. else if (lua_isuserdata (L, 2)) {
  1059. struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
  1060. if (t && t->ndims == 1) {
  1061. int i_out;
  1062. int n_in = kann_dim_in (k);
  1063. if (n_in != t->dim[0]) {
  1064. return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
  1065. (int) t->dim[0], n_in);
  1066. }
  1067. i_out = kann_find (k, KANN_F_OUT, 0);
  1068. if (i_out <= 0) {
  1069. return luaL_error (L, "invalid ANN: output layer is missing or is "
  1070. "at the input pos");
  1071. }
  1072. kann_set_batch_size (k, 1);
  1073. kann_feed_bind (k, KANN_F_IN, 0, &t->data);
  1074. kad_eval_at (k->n, k->v, i_out);
  1075. gint outlen = kad_len (k->v[i_out]);
  1076. struct rspamd_lua_tensor *out;
  1077. out = lua_newtensor (L, 1, &outlen, false, false);
  1078. /* Ensure that kann and tensor have the same understanding of floats */
  1079. G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
  1080. memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
  1081. }
  1082. else {
  1083. return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
  1084. }
  1085. }
  1086. else {
  1087. return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
  1088. }
  1089. }
  1090. else {
  1091. return luaL_error (L, "invalid arguments: rspamd{kann} expected");
  1092. }
  1093. return 1;
  1094. }