You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

lua_kann.c 27KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204
  1. /*-
  2. * Copyright 2019 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "contrib/kann/kann.h"
  18. /***
  19. * @module rspamd_kann
  20. * `rspamd_kann` is a Lua interface to kann library
  21. */
  22. #define KANN_NODE_CLASS "rspamd{kann_node}"
  23. #define KANN_NETWORK_CLASS "rspamd{kann}"
  24. /* Simple macros to define behaviour */
  25. #define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
  26. #define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_ ## name}
  27. #define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_ ## name (lua_State *L)
  28. #define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_ ## name}
  29. #define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
  30. #define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
  31. #define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
  32. #define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
  33. /*
  34. * Forwarded declarations
  35. */
  36. static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
  37. /* Layers */
  38. KANN_LAYER_DEF(input);
  39. KANN_LAYER_DEF(dense);
  40. KANN_LAYER_DEF(layernorm);
  41. KANN_LAYER_DEF(rnn);
  42. KANN_LAYER_DEF(lstm);
  43. KANN_LAYER_DEF(gru);
  44. KANN_LAYER_DEF(conv2d);
  45. KANN_LAYER_DEF(conv1d);
  46. KANN_LAYER_DEF(cost);
  47. static luaL_reg rspamd_kann_layers_f[] = {
  48. KANN_LAYER_INTERFACE(input),
  49. KANN_LAYER_INTERFACE(dense),
  50. KANN_LAYER_INTERFACE(layernorm),
  51. KANN_LAYER_INTERFACE(rnn),
  52. KANN_LAYER_INTERFACE(lstm),
  53. KANN_LAYER_INTERFACE(gru),
  54. KANN_LAYER_INTERFACE(conv2d),
  55. KANN_LAYER_INTERFACE(conv1d),
  56. KANN_LAYER_INTERFACE(cost),
  57. {NULL, NULL},
  58. };
  59. /* Transition and composition functions */
  60. /* General transform */
  61. KANN_TRANSFORM_DEF (add);
  62. KANN_TRANSFORM_DEF (sub);
  63. KANN_TRANSFORM_DEF (mul);
  64. KANN_TRANSFORM_DEF (cmul);
  65. KANN_TRANSFORM_DEF (matmul);
  66. KANN_TRANSFORM_DEF (square);
  67. KANN_TRANSFORM_DEF (sigm);
  68. KANN_TRANSFORM_DEF (tanh);
  69. KANN_TRANSFORM_DEF (relu);
  70. KANN_TRANSFORM_DEF (softmax);
  71. KANN_TRANSFORM_DEF (1minus);
  72. KANN_TRANSFORM_DEF (exp);
  73. KANN_TRANSFORM_DEF (log);
  74. KANN_TRANSFORM_DEF (sin);
  75. static luaL_reg rspamd_kann_transform_f[] = {
  76. KANN_TRANSFORM_INTERFACE (add),
  77. KANN_TRANSFORM_INTERFACE (sub),
  78. KANN_TRANSFORM_INTERFACE (mul),
  79. KANN_TRANSFORM_INTERFACE (cmul),
  80. KANN_TRANSFORM_INTERFACE (matmul),
  81. KANN_TRANSFORM_INTERFACE (square),
  82. KANN_TRANSFORM_INTERFACE (sigm),
  83. KANN_TRANSFORM_INTERFACE (tanh),
  84. KANN_TRANSFORM_INTERFACE (relu),
  85. KANN_TRANSFORM_INTERFACE (softmax),
  86. KANN_TRANSFORM_INTERFACE (1minus),
  87. KANN_TRANSFORM_INTERFACE (exp),
  88. KANN_TRANSFORM_INTERFACE (log),
  89. KANN_TRANSFORM_INTERFACE (sin),
  90. {NULL, NULL},
  91. };
  92. /* Loss functions */
  93. KANN_LOSS_DEF (mse);
  94. KANN_LOSS_DEF (ce_multi);
  95. KANN_LOSS_DEF (ce_bin);
  96. KANN_LOSS_DEF (ce_bin_neg);
  97. KANN_LOSS_DEF (ce_multi_weighted);
  98. static luaL_reg rspamd_kann_loss_f[] = {
  99. KANN_LOSS_INTERFACE (mse),
  100. KANN_LOSS_INTERFACE (ce_multi),
  101. KANN_LOSS_INTERFACE (ce_bin),
  102. KANN_LOSS_INTERFACE (ce_bin_neg),
  103. KANN_LOSS_INTERFACE (ce_multi_weighted),
  104. {NULL, NULL},
  105. };
  106. /* Creation functions */
  107. KANN_NEW_DEF (leaf);
  108. KANN_NEW_DEF (scalar);
  109. KANN_NEW_DEF (weight);
  110. KANN_NEW_DEF (bias);
  111. KANN_NEW_DEF (weight_conv2d);
  112. KANN_NEW_DEF (weight_conv1d);
  113. KANN_NEW_DEF (kann);
  114. static luaL_reg rspamd_kann_new_f[] = {
  115. KANN_NEW_INTERFACE (leaf),
  116. KANN_NEW_INTERFACE (scalar),
  117. KANN_NEW_INTERFACE (weight),
  118. KANN_NEW_INTERFACE (bias),
  119. KANN_NEW_INTERFACE (weight_conv2d),
  120. KANN_NEW_INTERFACE (weight_conv1d),
  121. KANN_NEW_INTERFACE (kann),
  122. {NULL, NULL},
  123. };
  124. LUA_FUNCTION_DEF (kann, load);
  125. LUA_FUNCTION_DEF (kann, destroy);
  126. LUA_FUNCTION_DEF (kann, save);
  127. LUA_FUNCTION_DEF (kann, train1);
  128. LUA_FUNCTION_DEF (kann, apply1);
  129. static luaL_reg rspamd_kann_m[] = {
  130. LUA_INTERFACE_DEF (kann, save),
  131. LUA_INTERFACE_DEF (kann, train1),
  132. LUA_INTERFACE_DEF (kann, apply1),
  133. {"__gc", lua_kann_destroy},
  134. {NULL, NULL},
  135. };
  136. static int
  137. rspamd_kann_table_to_flags (lua_State *L, int table_pos)
  138. {
  139. int result = 0;
  140. lua_pushvalue (L, table_pos);
  141. for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
  142. int fl = lua_tointeger (L, -1);
  143. result |= fl;
  144. }
  145. lua_pop (L, 1);
  146. return result;
  147. }
  148. static gint
  149. lua_load_kann (lua_State * L)
  150. {
  151. lua_newtable (L);
  152. /* Flags */
  153. lua_pushstring (L, "flag");
  154. lua_newtable (L);
  155. lua_pushinteger (L, KANN_F_IN);
  156. lua_setfield (L, -2, "in");
  157. lua_pushinteger (L, KANN_F_COST);
  158. lua_setfield (L, -2, "cost");
  159. lua_pushinteger (L, KANN_F_OUT);
  160. lua_setfield (L, -2, "out");
  161. lua_pushinteger (L, KANN_F_TRUTH);
  162. lua_setfield (L, -2, "truth");
  163. lua_settable (L, -3);
  164. /* Cost type */
  165. lua_pushstring (L, "cost");
  166. lua_newtable (L);
  167. /* binary cross-entropy cost, used with sigmoid */
  168. lua_pushinteger (L, KANN_C_CEB);
  169. lua_setfield (L, -2, "ceb");
  170. /* multi-class cross-entropy cost, used with softmax */
  171. lua_pushinteger (L, KANN_C_CEM);
  172. lua_setfield (L, -2, "cem");
  173. /* binary cross-entropy-like cost, used with tanh */
  174. lua_pushinteger (L, KANN_C_CEB_NEG);
  175. lua_setfield (L, -2, "ceb_neg");
  176. lua_pushinteger (L, KANN_C_MSE);
  177. lua_setfield (L, -2, "mse");
  178. lua_settable (L, -3);
  179. /* RNN flag */
  180. lua_pushstring (L, "rnn");
  181. lua_newtable (L);
  182. /* apply layer normalization */
  183. lua_pushinteger (L, KANN_RNN_NORM);
  184. lua_setfield (L, -2, "norm");
  185. /* take the initial hidden values as variables */
  186. lua_pushinteger (L, KANN_RNN_VAR_H0);
  187. lua_setfield (L, -2, "var_h0");
  188. lua_settable (L, -3);
  189. /* Layers */
  190. lua_pushstring (L, "layer");
  191. lua_newtable (L);
  192. luaL_register (L, NULL, rspamd_kann_layers_f);
  193. lua_settable (L, -3);
  194. /* Transforms */
  195. lua_pushstring (L, "transform");
  196. lua_newtable (L);
  197. luaL_register (L, NULL, rspamd_kann_transform_f);
  198. lua_settable (L, -3);
  199. /* Cost */
  200. lua_pushstring (L, "loss");
  201. lua_newtable (L);
  202. luaL_register (L, NULL, rspamd_kann_loss_f);
  203. lua_settable (L, -3);
  204. /* Create functions */
  205. lua_pushstring (L, "new");
  206. lua_newtable (L);
  207. luaL_register (L, NULL, rspamd_kann_new_f);
  208. lua_settable (L, -3);
  209. /* Load ann from memory or file */
  210. lua_pushstring (L, "load");
  211. lua_pushcfunction (L, lua_kann_load);
  212. lua_settable (L, -3);
  213. return 1;
  214. }
  215. static kad_node_t *
  216. lua_check_kann_node (lua_State *L, int pos)
  217. {
  218. void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
  219. luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
  220. return ud ? *((kad_node_t **)ud) : NULL;
  221. }
  222. static kann_t *
  223. lua_check_kann (lua_State *L, int pos)
  224. {
  225. void *ud = rspamd_lua_check_udata (L, pos, KANN_NETWORK_CLASS);
  226. luaL_argcheck (L, ud != NULL, pos, "'kann' expected");
  227. return ud ? *((kann_t **)ud) : NULL;
  228. }
  229. void luaopen_kann (lua_State *L)
  230. {
  231. /* Metatables */
  232. rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
  233. lua_pop (L, 1); /* No need in metatable... */
  234. rspamd_lua_new_class (L, KANN_NETWORK_CLASS, rspamd_kann_m);
  235. lua_pop (L, 1); /* No need in metatable... */
  236. rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
  237. lua_settop (L, 0);
  238. }
  /* Layers implementation */

/*
 * The helper macros below all expect a `lua_State *L` in scope.
 *
 * PUSH_KAD_NODE(n): pushes kad_node_t *n as a KANN_NODE_CLASS userdata.
 */
#define PUSH_KAD_NODE(n) do { \
	kad_node_t **ud_node = lua_newuserdata (L, sizeof (kad_node_t *)); \
	*ud_node = (n); \
	rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
} while (0)

/* PUSH_KAN_NETWORK(n): pushes kann_t *n as a KANN_NETWORK_CLASS userdata */
#define PUSH_KAN_NETWORK(n) do { \
	kann_t **ud_net = lua_newuserdata (L, sizeof (kann_t *)); \
	*ud_net = (n); \
	rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
} while (0)

/*
 * PROCESS_KAD_FLAGS(n, pos): ORs optional extra flags into n->ext_flag.
 * The argument at `pos` may be a table of flags, a number, or absent
 * (any other type is ignored).
 */
#define PROCESS_KAD_FLAGS(n, pos) do { \
	int extra_fl = 0; \
	switch (lua_type (L, (pos))) { \
	case LUA_TTABLE: \
		extra_fl = rspamd_kann_table_to_flags (L, (pos)); \
		break; \
	case LUA_TNUMBER: \
		extra_fl = lua_tointeger (L, (pos)); \
		break; \
	default: \
		break; \
	} \
	(n)->ext_flag |= extra_fl; \
} while (0)
  258. /***
  259. * @function kann.layer.input(ninputs[, flags])
  260. * Creates an input layer for ANN
  261. * @param {int} ninputs number of inputs
  262. * @param {table|int} flags optional flags
  263. * @return {kann_node} kann node object (should be used to combine ANN)
  264. */
  265. static int
  266. lua_kann_layer_input (lua_State *L)
  267. {
  268. gint nnodes = luaL_checkinteger (L, 1);
  269. if (nnodes > 0) {
  270. kad_node_t *t;
  271. t = kann_layer_input (nnodes);
  272. PROCESS_KAD_FLAGS (t, 2);
  273. PUSH_KAD_NODE (t);
  274. }
  275. else {
  276. return luaL_error (L, "invalid arguments, nnodes required");
  277. }
  278. return 1;
  279. }
  280. /***
  281. * @function kann.layer.dense(in, ninputs[, flags])
  282. * Creates a dense layer (e.g. for hidden layer)
  283. * @param {kann_node} in kann node
  284. * @param {int} ninputs number of dense nodes
  285. * @param {table|int} flags optional flags
  286. * @return {kann_node} kann node object (should be used to combine ANN)
  287. */
  288. static int
  289. lua_kann_layer_dense (lua_State *L)
  290. {
  291. kad_node_t *in = lua_check_kann_node (L, 1);
  292. gint nnodes = luaL_checkinteger (L, 2);
  293. if (in != NULL && nnodes > 0) {
  294. kad_node_t *t;
  295. t = kann_layer_dense (in, nnodes);
  296. PROCESS_KAD_FLAGS (t, 3);
  297. PUSH_KAD_NODE (t);
  298. }
  299. else {
  300. return luaL_error (L, "invalid arguments, input + nnodes required");
  301. }
  302. return 1;
  303. }
  304. /***
  305. * @function kann.layer.dropout(in, ratio[, flags])
  306. * Creates a dropout layer
  307. * @param {kann_node} in kann node
  308. * @param {float} ratio drop ratio
  309. * @param {table|int} flags optional flags
  310. * @return {kann_node} kann node object (should be used to combine ANN)
  311. */
  312. static int
  313. lua_kann_layer_layerdropout (lua_State *L)
  314. {
  315. kad_node_t *in = lua_check_kann_node (L, 1);
  316. double r = luaL_checknumber (L, 2);
  317. if (in != NULL) {
  318. kad_node_t *t;
  319. t = kann_layer_dropout (in, r);
  320. PROCESS_KAD_FLAGS (t, 3);
  321. PUSH_KAD_NODE (t);
  322. }
  323. else {
  324. return luaL_error (L, "invalid arguments, input + rate required");
  325. }
  326. return 1;
  327. }
  328. /***
  329. * @function kann.layer.dropout(in [, flags])
  330. * Creates a normalisation layer
  331. * @param {kann_node} in kann node
  332. * @param {table|int} flags optional flags
  333. * @return {kann_node} kann node object (should be used to combine ANN)
  334. */
  335. static int
  336. lua_kann_layer_layernorm (lua_State *L)
  337. {
  338. kad_node_t *in = lua_check_kann_node (L, 1);
  339. if (in != NULL) {
  340. kad_node_t *t;
  341. t = kann_layer_layernorm (in);
  342. PROCESS_KAD_FLAGS (t, 2);
  343. PUSH_KAD_NODE (t);
  344. }
  345. else {
  346. return luaL_error (L, "invalid arguments, input required");
  347. }
  348. return 1;
  349. }
  350. /***
  351. * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
  352. * Creates a recursive NN layer
  353. * @param {kann_node} in kann node
  354. * @param {int} nnodes number of cells
  355. * @param {int} rnnflags rnn flags
  356. * @param {table|int} flags optional flags
  357. * @return {kann_node} kann node object (should be used to combine ANN)
  358. */
  359. static int
  360. lua_kann_layer_rnn (lua_State *L)
  361. {
  362. kad_node_t *in = lua_check_kann_node (L, 1);
  363. gint nnodes = luaL_checkinteger (L, 2);
  364. gint rnnflags = 0;
  365. if (in != NULL && nnodes > 0) {
  366. kad_node_t *t;
  367. if (lua_type (L, 3) == LUA_TNUMBER) {
  368. rnnflags = lua_tointeger (L, 3);
  369. }
  370. t = kann_layer_rnn (in, nnodes, rnnflags);
  371. PROCESS_KAD_FLAGS (t, 4);
  372. PUSH_KAD_NODE (t);
  373. }
  374. else {
  375. return luaL_error (L, "invalid arguments, input + nnodes required");
  376. }
  377. return 1;
  378. }
  379. /***
  380. * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]])
  381. * Creates a recursive NN layer using LSTM cells
  382. * @param {kann_node} in kann node
  383. * @param {int} nnodes number of cells
  384. * @param {int} rnnflags rnn flags
  385. * @param {table|int} flags optional flags
  386. * @return {kann_node} kann node object (should be used to combine ANN)
  387. */
  388. static int
  389. lua_kann_layer_lstm (lua_State *L)
  390. {
  391. kad_node_t *in = lua_check_kann_node (L, 1);
  392. gint nnodes = luaL_checkinteger (L, 2);
  393. gint rnnflags = 0;
  394. if (in != NULL && nnodes > 0) {
  395. kad_node_t *t;
  396. if (lua_type (L, 3) == LUA_TNUMBER) {
  397. rnnflags = lua_tointeger (L, 3);
  398. }
  399. t = kann_layer_lstm (in, nnodes, rnnflags);
  400. PROCESS_KAD_FLAGS (t, 4);
  401. PUSH_KAD_NODE (t);
  402. }
  403. else {
  404. return luaL_error (L, "invalid arguments, input + nnodes required");
  405. }
  406. return 1;
  407. }
  408. /***
  409. * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
  410. * Creates a recursive NN layer using GRU cells
  411. * @param {kann_node} in kann node
  412. * @param {int} nnodes number of cells
  413. * @param {int} rnnflags rnn flags
  414. * @param {table|int} flags optional flags
  415. * @return {kann_node} kann node object (should be used to combine ANN)
  416. */
  417. static int
  418. lua_kann_layer_gru (lua_State *L)
  419. {
  420. kad_node_t *in = lua_check_kann_node (L, 1);
  421. gint nnodes = luaL_checkinteger (L, 2);
  422. gint rnnflags = 0;
  423. if (in != NULL && nnodes > 0) {
  424. kad_node_t *t;
  425. if (lua_type (L, 3) == LUA_TNUMBER) {
  426. rnnflags = lua_tointeger (L, 3);
  427. }
  428. t = kann_layer_gru (in, nnodes, rnnflags);
  429. PROCESS_KAD_FLAGS (t, 4);
  430. PUSH_KAD_NODE (t);
  431. }
  432. else {
  433. return luaL_error (L, "invalid arguments, input + nnodes required");
  434. }
  435. return 1;
  436. }
  437. /***
  438. * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
  439. * Creates a 2D convolution layer
  440. * @param {kann_node} in kann node
  441. * @param {int} n_flt number of filters
  442. * @param {int} k_rows kernel rows
  443. * @param {int} k_cols kernel columns
  444. * @param {int} stride_rows stride rows
  445. * @param {int} stride_cols stride columns
  446. * @param {int} pad_rows padding rows
  447. * @param {int} pad_columns padding columns
  448. * @param {table|int} flags optional flags
  449. * @return {kann_node} kann node object (should be used to combine ANN)
  450. */
  451. static int
  452. lua_kann_layer_conv2d (lua_State *L)
  453. {
  454. kad_node_t *in = lua_check_kann_node (L, 1);
  455. int n_flt = luaL_checkinteger (L, 2);
  456. int k_rows = luaL_checkinteger (L, 3);
  457. int k_cols = luaL_checkinteger (L, 4);
  458. int stride_r = luaL_checkinteger (L, 5);
  459. int stride_c = luaL_checkinteger (L, 6);
  460. int pad_r = luaL_checkinteger (L, 7);
  461. int pad_c = luaL_checkinteger (L, 8);
  462. if (in != NULL) {
  463. kad_node_t *t;
  464. t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
  465. pad_r, pad_c);
  466. PROCESS_KAD_FLAGS (t, 9);
  467. PUSH_KAD_NODE (t);
  468. }
  469. else {
  470. return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
  471. }
  472. return 1;
  473. }
  474. /***
  475. * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
  476. * Creates 1D convolution layer
  477. * @param {kann_node} in kann node
  478. * @param {int} n_flt number of filters
  479. * @param {int} kern_size kernel rows
  480. * @param {int} stride_size stride rows
  481. * @param {int} pad_size padding rows
  482. * @param {table|int} flags optional flags
  483. * @return {kann_node} kann node object (should be used to combine ANN)
  484. */
  485. static int
  486. lua_kann_layer_conv1d (lua_State *L)
  487. {
  488. kad_node_t *in = lua_check_kann_node (L, 1);
  489. int n_flt = luaL_checkinteger (L, 2);
  490. int k_size = luaL_checkinteger (L, 3);
  491. int stride = luaL_checkinteger (L, 4);
  492. int pad = luaL_checkinteger (L, 5);
  493. if (in != NULL) {
  494. kad_node_t *t;
  495. t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
  496. PROCESS_KAD_FLAGS (t, 6);
  497. PUSH_KAD_NODE (t);
  498. }
  499. else {
  500. return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
  501. }
  502. return 1;
  503. }
  504. /***
  505. * @function kann.layer.cost(in, nout, cost_type[, flags])
  506. * Creates 1D convolution layer
  507. * @param {kann_node} in kann node
  508. * @param {int} nout number of outputs
  509. * @param {int} cost_type see kann.cost table
  510. * @param {table|int} flags optional flags
  511. * @return {kann_node} kann node object (should be used to combine ANN)
  512. */
  513. static int
  514. lua_kann_layer_cost (lua_State *L)
  515. {
  516. kad_node_t *in = lua_check_kann_node (L, 1);
  517. int nout = luaL_checkinteger (L, 2);
  518. int cost_type = luaL_checkinteger (L, 3);
  519. if (in != NULL && nout > 0) {
  520. kad_node_t *t;
  521. t = kann_layer_cost (in, nout, cost_type);
  522. PROCESS_KAD_FLAGS (t, 4);
  523. PUSH_KAD_NODE (t);
  524. }
  525. else {
  526. return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
  527. }
  528. return 1;
  529. }
  530. /* Generic helpers */
  531. static int
  532. lua_kann_call_unary_function (lua_State *L, const char *name,
  533. kad_node_t *(*func)(kad_node_t *))
  534. {
  535. kad_node_t *in = lua_check_kann_node (L, 1);
  536. if (in != NULL) {
  537. kad_node_t *t;
  538. t = func (in);
  539. PUSH_KAD_NODE (t);
  540. }
  541. else {
  542. return luaL_error (L, "invalid arguments for %s, input required", name);
  543. }
  544. return 1;
  545. }
  546. static int
  547. lua_kann_call_binary_function (lua_State *L, const char *name,
  548. kad_node_t *(*func)(kad_node_t *, kad_node_t *))
  549. {
  550. kad_node_t *x = lua_check_kann_node (L, 1);
  551. kad_node_t *y = lua_check_kann_node (L, 2);
  552. if (x != NULL && y != NULL) {
  553. kad_node_t *t;
  554. t = func (x, y);
  555. PUSH_KAD_NODE (t);
  556. }
  557. else {
  558. return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
  559. }
  560. return 1;
  561. }
  562. #define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \
  563. static int lua_kann_transform_ ##name (lua_State *L) \
  564. { \
  565. return lua_kann_call_unary_function(L, #name, kad_##name); \
  566. }
  567. #define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \
  568. static int lua_kann_transform_ ##name (lua_State *L) \
  569. { \
  570. return lua_kann_call_binary_function(L, #name, kad_##name); \
  571. }
  572. #define LUA_LOSS_FUNC_IMPL(name) \
  573. static int lua_kann_loss_ ##name (lua_State *L) \
  574. { \
  575. return lua_kann_call_binary_function(L, #name, kad_##name); \
  576. }
  577. /* Transform functions registered via macro helpers */
  578. LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
  579. LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
  580. LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
  581. LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
  582. LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
  583. LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
  584. LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
  585. LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
  586. LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
  587. LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
  588. LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
  589. LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
  590. LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
  591. LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
  592. /* Generic cost functions */
  593. LUA_LOSS_FUNC_IMPL (mse)
  594. LUA_LOSS_FUNC_IMPL (ce_multi)
  595. LUA_LOSS_FUNC_IMPL (ce_bin)
  596. LUA_LOSS_FUNC_IMPL (ce_bin_neg)
  597. /* The only case of ternary weight function */
  598. static int
  599. lua_kann_loss_ce_multi_weighted (lua_State *L)
  600. {
  601. kad_node_t *pred = lua_check_kann_node (L, 1);
  602. kad_node_t *truth = lua_check_kann_node (L, 2);
  603. kad_node_t *weight = lua_check_kann_node (L, 3);
  604. if (pred != NULL && truth != NULL && weight != NULL) {
  605. kad_node_t *t;
  606. t = kad_ce_multi_weighted (pred, truth, weight);
  607. PUSH_KAD_NODE (t);
  608. }
  609. else {
  610. return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
  611. }
  612. return 1;
  613. }
  614. /* Creation functions */
  615. static int
  616. lua_kann_new_scalar (lua_State *L)
  617. {
  618. gint flag = luaL_checkinteger (L, 1);
  619. double x = luaL_checknumber (L, 2);
  620. kad_node_t *t;
  621. t = kann_new_scalar (flag, x);
  622. PROCESS_KAD_FLAGS (t, 3);
  623. PUSH_KAD_NODE (t);
  624. return 1;
  625. }
  626. static int
  627. lua_kann_new_weight (lua_State *L)
  628. {
  629. gint nrow = luaL_checkinteger (L, 1);
  630. gint ncol = luaL_checkinteger (L, 2);
  631. kad_node_t *t;
  632. t = kann_new_weight (nrow, ncol);
  633. PROCESS_KAD_FLAGS (t, 3);
  634. PUSH_KAD_NODE (t);
  635. return 1;
  636. }
  637. static int
  638. lua_kann_new_bias (lua_State *L)
  639. {
  640. gint n = luaL_checkinteger (L, 1);
  641. kad_node_t *t;
  642. t = kann_new_bias (n);
  643. PROCESS_KAD_FLAGS (t, 2);
  644. PUSH_KAD_NODE (t);
  645. return 1;
  646. }
  647. static int
  648. lua_kann_new_weight_conv2d (lua_State *L)
  649. {
  650. gint nout = luaL_checkinteger (L, 1);
  651. gint nin = luaL_checkinteger (L, 2);
  652. gint krow = luaL_checkinteger (L, 3);
  653. gint kcol = luaL_checkinteger (L, 4);
  654. kad_node_t *t;
  655. t = kann_new_weight_conv2d (nout, nin, krow, kcol);
  656. PROCESS_KAD_FLAGS (t, 5);
  657. PUSH_KAD_NODE (t);
  658. return 1;
  659. }
  660. static int
  661. lua_kann_new_weight_conv1d (lua_State *L)
  662. {
  663. gint nout = luaL_checkinteger (L, 1);
  664. gint nin = luaL_checkinteger (L, 2);
  665. gint klen = luaL_checkinteger (L, 3);
  666. kad_node_t *t;
  667. t = kann_new_weight_conv1d (nout, nin, klen);
  668. PROCESS_KAD_FLAGS (t, 4);
  669. PUSH_KAD_NODE (t);
  670. return 1;
  671. }
  672. static int
  673. lua_kann_new_leaf (lua_State *L)
  674. {
  675. gint dim = luaL_checkinteger (L, 1), i, *ar;
  676. kad_node_t *t;
  677. if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
  678. ar = g_malloc0 (sizeof (ar) * dim);
  679. for (i = 0; i < dim; i ++) {
  680. lua_rawgeti (L, 2, i + 1);
  681. ar[i] = lua_tointeger (L, -1);
  682. lua_pop (L, 1);
  683. }
  684. t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
  685. PROCESS_KAD_FLAGS (t, 3);
  686. PUSH_KAD_NODE (t);
  687. g_free (ar);
  688. }
  689. else {
  690. return luaL_error (L, "invalid arguments for new.leaf, "
  691. "dim and vector of elements are required");
  692. }
  693. return 1;
  694. }
  695. static int
  696. lua_kann_new_kann (lua_State *L)
  697. {
  698. kad_node_t *cost = lua_check_kann_node (L, 1);
  699. kann_t *k;
  700. if (cost) {
  701. k = kann_new (cost, 0);
  702. PUSH_KAN_NETWORK (k);
  703. }
  704. else {
  705. return luaL_error (L, "invalid arguments for new.kann, "
  706. "cost node is required");
  707. }
  708. return 1;
  709. }
  710. static int
  711. lua_kann_destroy (lua_State *L)
  712. {
  713. kann_t *k = lua_check_kann (L, 1);
  714. kann_delete (k);
  715. return 0;
  716. }
  717. static int
  718. lua_kann_save (lua_State *L)
  719. {
  720. kann_t *k = lua_check_kann (L, 1);
  721. if (k) {
  722. if (lua_istable (L, 2)) {
  723. lua_getfield (L, 2, "filename");
  724. if (lua_isstring (L, -1)) {
  725. const gchar *fname = lua_tostring (L, -1);
  726. FILE *f;
  727. f = fopen (fname, "w");
  728. if (!f) {
  729. lua_pop (L, 1);
  730. return luaL_error (L, "cannot open %s for writing: %s",
  731. fname, strerror (errno));
  732. }
  733. kann_save_fp (f, k);
  734. fclose (f);
  735. lua_pushboolean (L, true);
  736. }
  737. else {
  738. lua_pop (L, 1);
  739. return luaL_error (L, "invalid arguments: missing filename");
  740. }
  741. lua_pop (L, 1);
  742. }
  743. else {
  744. /* Save to Rspamd text */
  745. #ifndef HAVE_OPENMEMSTREAM
  746. return luaL_error (L, "no support of saving to memory on your system");
  747. #endif
  748. FILE *f;
  749. char *buf = NULL;
  750. size_t buflen;
  751. struct rspamd_lua_text *t;
  752. f = open_memstream (&buf, &buflen);
  753. g_assert (f != NULL);
  754. kann_save_fp (f, k);
  755. fclose (f);
  756. t = lua_newuserdata (L, sizeof (*t));
  757. rspamd_lua_setclass (L, "rspamd{text}", -1);
  758. t->flags = RSPAMD_TEXT_FLAG_OWN;
  759. t->start = (const gchar *)buf;
  760. t->len = buflen;
  761. }
  762. }
  763. else {
  764. return luaL_error (L, "invalid arguments");
  765. }
  766. return 1;
  767. }
  768. static int
  769. lua_kann_load (lua_State *L)
  770. {
  771. kann_t *k;
  772. FILE *f = NULL;
  773. if (lua_istable (L, 1)) {
  774. lua_getfield (L, 2, "filename");
  775. if (lua_isstring (L, -1)) {
  776. const gchar *fname = lua_tostring (L, -1);
  777. f = fopen (fname, "rb");
  778. }
  779. else {
  780. lua_pop (L, 1);
  781. return luaL_error (L, "invalid arguments: missing filename");
  782. }
  783. lua_pop (L, 1);
  784. }
  785. else if (lua_isstring (L, 1)) {
  786. gsize dlen;
  787. const gchar *data;
  788. data = lua_tolstring (L, 1, &dlen);
  789. #ifndef HAVE_FMEMOPEN
  790. return luaL_error (L, "no support of loading from memory on your system");
  791. #endif
  792. f = fmemopen ((void *)data, dlen, "rb");
  793. }
  794. else if (lua_isuserdata (L, 1)) {
  795. struct rspamd_lua_text *t;
  796. t = lua_check_text (L, 1);
  797. #ifndef HAVE_FMEMOPEN
  798. return luaL_error (L, "no support of loading from memory on your system");
  799. #endif
  800. f = fmemopen ((void *)t->start, t->len, "rb");
  801. }
  802. if (f == NULL) {
  803. return luaL_error (L, "invalid arguments or cannot open file");
  804. }
  805. k = kann_load_fp (f);
  806. fclose (f);
  807. if (k == NULL) {
  808. lua_pushnil (L);
  809. }
  810. else {
  811. PUSH_KAN_NETWORK (k);
  812. }
  813. return 1;
  814. }
  815. struct rspamd_kann_train_cbdata {
  816. lua_State *L;
  817. kann_t *k;
  818. gint cbref;
  819. };
  820. static void
  821. lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
  822. {
  823. struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
  824. if (cbd->cbref != -1) {
  825. gint err_idx;
  826. lua_State *L = cbd->L;
  827. lua_pushcfunction (L, &rspamd_lua_traceback);
  828. err_idx = lua_gettop (L);
  829. lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
  830. lua_pushinteger (L, iter);
  831. lua_pushnumber (L, train_cost);
  832. lua_pushnumber (L, val_cost);
  833. if (lua_pcall (L, 3, 0, err_idx) != 0) {
  834. msg_err ("cannot run lua train callback: %s",
  835. lua_tostring (L, -1));
  836. }
  837. lua_settop (L, err_idx - 1);
  838. }
  839. }
  840. #define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0)
  841. static int
  842. lua_kann_train1 (lua_State *L)
  843. {
  844. kann_t *k = lua_check_kann (L, 1);
  845. /* Default train params */
  846. double lr = 0.001;
  847. gint64 mini_size = 64;
  848. gint64 max_epoch = 25;
  849. gint64 max_drop_streak = 10;
  850. double frac_val = 0.1;
  851. gint cbref = -1;
  852. if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
  853. int n = rspamd_lua_table_size (L, 2);
  854. int n_in = kann_dim_in (k);
  855. int n_out = kann_dim_out (k);
  856. if (n_in <= 0) {
  857. return luaL_error (L, "invalid inputs count: %d", n_in);
  858. }
  859. if (n_out <= 0) {
  860. return luaL_error (L, "invalid outputs count: %d", n_in);
  861. }
  862. if (n != rspamd_lua_table_size (L, 3) || n == 0) {
  863. return luaL_error (L, "invalid dimensions: outputs size must be "
  864. "equal to inputs and non zero");
  865. }
  866. if (lua_istable (L, 4)) {
  867. GError *err = NULL;
  868. if (!rspamd_lua_parse_table_arguments (L, 4, &err,
  869. RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
  870. "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
  871. &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
  872. n = luaL_error (L, "invalid params: %s",
  873. err ? err->message : "unknown error");
  874. g_error_free (err);
  875. return n;
  876. }
  877. }
  878. float **x, **y;
  879. /* Fill vectors */
  880. x = (float **)g_malloc0 (sizeof (float *) * n);
  881. y = (float **)g_malloc0 (sizeof (float *) * n);
  882. for (int s = 0; s < n; s ++) {
  883. /* Inputs */
  884. lua_rawgeti (L, 2, s + 1);
  885. x[s] = (float *)g_malloc (sizeof (float) * n_in);
  886. if (rspamd_lua_table_size (L, -1) != n_in) {
  887. FREE_VEC (x, n);
  888. FREE_VEC (y, n);
  889. n = luaL_error (L, "invalid params at pos %d: "
  890. "bad input dimension %d; %d expected",
  891. s + 1,
  892. (int)rspamd_lua_table_size (L, -1),
  893. n_in);
  894. return n;
  895. }
  896. for (int i = 0; i < n_in; i ++) {
  897. lua_rawgeti (L, -1, i + 1);
  898. x[s][i] = lua_tonumber (L, -1);
  899. lua_pop (L, 1);
  900. }
  901. lua_pop (L, 1);
  902. /* Outputs */
  903. y[s] = (float *)g_malloc (sizeof (float) * n_out);
  904. lua_rawgeti (L, 3, s + 1);
  905. if (rspamd_lua_table_size (L, -1) != n_out) {
  906. FREE_VEC (x, n);
  907. FREE_VEC (y, n);
  908. n = luaL_error (L, "invalid params at pos %d: "
  909. "bad output dimension %d; "
  910. "%d expected",
  911. s + 1,
  912. (int)rspamd_lua_table_size (L, -1),
  913. n_out);
  914. return n;
  915. }
  916. for (int i = 0; i < n_out; i ++) {
  917. lua_rawgeti (L, -1, i + 1);
  918. y[s][i] = lua_tonumber (L, -1);
  919. lua_pop (L, 1);
  920. }
  921. lua_pop (L, 1);
  922. }
  923. struct rspamd_kann_train_cbdata cbd;
  924. cbd.cbref = cbref;
  925. cbd.k = k;
  926. cbd.L = L;
  927. int niters = kann_train_fnn1 (k, lr,
  928. mini_size, max_epoch, max_drop_streak,
  929. frac_val, n, x, y, lua_kann_train_cb, &cbd);
  930. lua_pushinteger (L, niters);
  931. FREE_VEC (x, n);
  932. FREE_VEC (y, n);
  933. }
  934. else {
  935. return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
  936. " optional params are expected");
  937. }
  938. return 1;
  939. }
  940. static int
  941. lua_kann_apply1 (lua_State *L)
  942. {
  943. kann_t *k = lua_check_kann (L, 1);
  944. if (k && lua_istable (L, 2)) {
  945. gsize vec_len = rspamd_lua_table_size (L, 2);
  946. float *vec = (float *)g_malloc (sizeof (float) * vec_len);
  947. int i_out;
  948. int n_in = kann_dim_in (k);
  949. if (n_in <= 0) {
  950. return luaL_error (L, "invalid inputs count: %d", n_in);
  951. }
  952. if (n_in != vec_len) {
  953. return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
  954. (int)vec_len, n_in);
  955. }
  956. for (gsize i = 0; i < vec_len; i ++) {
  957. lua_rawgeti (L, 2, i + 1);
  958. vec[i] = lua_tonumber (L, -1);
  959. lua_pop (L, 1);
  960. }
  961. i_out = kann_find (k, KANN_F_OUT, 0);
  962. if (i_out <= 0) {
  963. g_free (vec);
  964. return luaL_error (L, "invalid ANN: output layer is missing or is "
  965. "at the input pos");
  966. }
  967. kann_set_batch_size (k, 1);
  968. kann_feed_bind (k, KANN_F_IN, 0, &vec);
  969. kad_eval_at (k->n, k->v, i_out);
  970. gsize outlen = kad_len (k->v[i_out]);
  971. lua_createtable (L, outlen, 0);
  972. for (gsize i = 0; i < outlen; i ++) {
  973. lua_pushnumber (L, k->v[i_out]->x[i]);
  974. lua_rawseti (L, -2, i + 1);
  975. }
  976. g_free (vec);
  977. }
  978. else {
  979. return luaL_error (L, "invalid arguments: rspamd{kann} expected");
  980. }
  981. return 1;
  982. }