
lua_kann.c

/*
 * Copyright 2024 Vsevolod Stakhov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lua_common.h"
#include "lua_tensor.h"
#include "contrib/kann/kann.h"

/***
 * @module rspamd_kann
 * `rspamd_kann` is a Lua interface to the kann library
 */
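
/*
 * Illustrative end-to-end usage of the resulting Lua module (a sketch, not
 * taken from the Rspamd sources; variable names, layer sizes and training
 * data are made up):
 *
 *   local kann = require "rspamd_kann"
 *
 *   -- a tiny feed-forward network: 4 inputs, 8 hidden nodes, 2 outputs
 *   local input = kann.layer.input(4)
 *   local hidden = kann.transform.tanh(kann.layer.dense(input, 8))
 *   local cost = kann.layer.cost(hidden, 2, kann.cost.cem)
 *   local net = kann.new.kann(cost)
 *
 *   -- `inputs` is an array of 4-element tables, `outputs` an array of
 *   -- 2-element tables (both hypothetical training data)
 *   local niters = net:train1(inputs, outputs, {lr = 0.01, max_epoch = 100})
 *   local prediction = net:apply1(inputs[1])
 */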

#define KANN_NODE_CLASS rspamd_kann_node_classname
#define KANN_NETWORK_CLASS rspamd_kann_classname

/* Simple macros to define behaviour */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L)
#define KANN_LAYER_INTERFACE(name) \
	{ \
		#name, lua_kann_layer_##name \
	}

#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L)
#define KANN_TRANSFORM_INTERFACE(name) \
	{ \
		#name, lua_kann_transform_##name \
	}

#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L)
#define KANN_LOSS_INTERFACE(name) \
	{ \
		#name, lua_kann_loss_##name \
	}

#define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L)
#define KANN_NEW_INTERFACE(name) \
	{ \
		#name, lua_kann_new_##name \
	}

/*
 * Forward declarations
 */
static kad_node_t *lua_check_kann_node(lua_State *L, int pos);

/* Layers */
KANN_LAYER_DEF(input);
KANN_LAYER_DEF(dense);
KANN_LAYER_DEF(layernorm);
KANN_LAYER_DEF(rnn);
KANN_LAYER_DEF(lstm);
KANN_LAYER_DEF(gru);
KANN_LAYER_DEF(conv2d);
KANN_LAYER_DEF(conv1d);
KANN_LAYER_DEF(cost);

static luaL_reg rspamd_kann_layers_f[] = {
	KANN_LAYER_INTERFACE(input),
	KANN_LAYER_INTERFACE(dense),
	KANN_LAYER_INTERFACE(layernorm),
	KANN_LAYER_INTERFACE(rnn),
	KANN_LAYER_INTERFACE(lstm),
	KANN_LAYER_INTERFACE(gru),
	KANN_LAYER_INTERFACE(conv2d),
	KANN_LAYER_INTERFACE(conv1d),
	KANN_LAYER_INTERFACE(cost),
	{NULL, NULL},
};

/* Transition and composition functions */

/* General transform */
KANN_TRANSFORM_DEF(add);
KANN_TRANSFORM_DEF(sub);
KANN_TRANSFORM_DEF(mul);
KANN_TRANSFORM_DEF(cmul);
KANN_TRANSFORM_DEF(matmul);
KANN_TRANSFORM_DEF(square);
KANN_TRANSFORM_DEF(sigm);
KANN_TRANSFORM_DEF(tanh);
KANN_TRANSFORM_DEF(relu);
KANN_TRANSFORM_DEF(softmax);
KANN_TRANSFORM_DEF(1minus);
KANN_TRANSFORM_DEF(exp);
KANN_TRANSFORM_DEF(log);
KANN_TRANSFORM_DEF(sin);

static luaL_reg rspamd_kann_transform_f[] = {
	KANN_TRANSFORM_INTERFACE(add),
	KANN_TRANSFORM_INTERFACE(sub),
	KANN_TRANSFORM_INTERFACE(mul),
	KANN_TRANSFORM_INTERFACE(cmul),
	KANN_TRANSFORM_INTERFACE(matmul),
	KANN_TRANSFORM_INTERFACE(square),
	KANN_TRANSFORM_INTERFACE(sigm),
	KANN_TRANSFORM_INTERFACE(tanh),
	KANN_TRANSFORM_INTERFACE(relu),
	KANN_TRANSFORM_INTERFACE(softmax),
	KANN_TRANSFORM_INTERFACE(1minus),
	KANN_TRANSFORM_INTERFACE(exp),
	KANN_TRANSFORM_INTERFACE(log),
	KANN_TRANSFORM_INTERFACE(sin),
	{NULL, NULL},
};

/* Loss functions */
KANN_LOSS_DEF(mse);
KANN_LOSS_DEF(ce_multi);
KANN_LOSS_DEF(ce_bin);
KANN_LOSS_DEF(ce_bin_neg);
KANN_LOSS_DEF(ce_multi_weighted);

static luaL_reg rspamd_kann_loss_f[] = {
	KANN_LOSS_INTERFACE(mse),
	KANN_LOSS_INTERFACE(ce_multi),
	KANN_LOSS_INTERFACE(ce_bin),
	KANN_LOSS_INTERFACE(ce_bin_neg),
	KANN_LOSS_INTERFACE(ce_multi_weighted),
	{NULL, NULL},
};

/* Creation functions */
KANN_NEW_DEF(leaf);
KANN_NEW_DEF(scalar);
KANN_NEW_DEF(weight);
KANN_NEW_DEF(bias);
KANN_NEW_DEF(weight_conv2d);
KANN_NEW_DEF(weight_conv1d);
KANN_NEW_DEF(kann);

static luaL_reg rspamd_kann_new_f[] = {
	KANN_NEW_INTERFACE(leaf),
	KANN_NEW_INTERFACE(scalar),
	KANN_NEW_INTERFACE(weight),
	KANN_NEW_INTERFACE(bias),
	KANN_NEW_INTERFACE(weight_conv2d),
	KANN_NEW_INTERFACE(weight_conv1d),
	KANN_NEW_INTERFACE(kann),
	{NULL, NULL},
};

LUA_FUNCTION_DEF(kann, load);
LUA_FUNCTION_DEF(kann, destroy);
LUA_FUNCTION_DEF(kann, save);
LUA_FUNCTION_DEF(kann, train1);
LUA_FUNCTION_DEF(kann, apply1);

static luaL_reg rspamd_kann_m[] = {
	LUA_INTERFACE_DEF(kann, save),
	LUA_INTERFACE_DEF(kann, train1),
	LUA_INTERFACE_DEF(kann, apply1),
	{"__gc", lua_kann_destroy},
	{NULL, NULL},
};

static int
rspamd_kann_table_to_flags(lua_State *L, int table_pos)
{
	int result = 0;

	lua_pushvalue(L, table_pos);

	for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
		int fl = lua_tointeger(L, -1);

		result |= fl;
	}

	lua_pop(L, 1);

	return result;
}

static int
lua_load_kann(lua_State *L)
{
	lua_newtable(L);

	/* Flags */
	lua_pushstring(L, "flag");
	lua_newtable(L);
	lua_pushinteger(L, KANN_F_IN);
	lua_setfield(L, -2, "in");
	lua_pushinteger(L, KANN_F_COST);
	lua_setfield(L, -2, "cost");
	lua_pushinteger(L, KANN_F_OUT);
	lua_setfield(L, -2, "out");
	lua_pushinteger(L, KANN_F_TRUTH);
	lua_setfield(L, -2, "truth");
	lua_settable(L, -3);

	/* Cost type */
	lua_pushstring(L, "cost");
	lua_newtable(L);
	/* binary cross-entropy cost, used with sigmoid */
	lua_pushinteger(L, KANN_C_CEB);
	lua_setfield(L, -2, "ceb");
	/* multi-class cross-entropy cost, used with softmax */
	lua_pushinteger(L, KANN_C_CEM);
	lua_setfield(L, -2, "cem");
	/* binary cross-entropy-like cost, used with tanh */
	lua_pushinteger(L, KANN_C_CEB_NEG);
	lua_setfield(L, -2, "ceb_neg");
	lua_pushinteger(L, KANN_C_MSE);
	lua_setfield(L, -2, "mse");
	lua_settable(L, -3);

	/* RNN flag */
	lua_pushstring(L, "rnn");
	lua_newtable(L);
	/* apply layer normalization */
	lua_pushinteger(L, KANN_RNN_NORM);
	lua_setfield(L, -2, "norm");
	/* take the initial hidden values as variables */
	lua_pushinteger(L, KANN_RNN_VAR_H0);
	lua_setfield(L, -2, "var_h0");
	lua_settable(L, -3);

	/* Layers */
	lua_pushstring(L, "layer");
	lua_newtable(L);
	luaL_register(L, NULL, rspamd_kann_layers_f);
	lua_settable(L, -3);

	/* Transforms */
	lua_pushstring(L, "transform");
	lua_newtable(L);
	luaL_register(L, NULL, rspamd_kann_transform_f);
	lua_settable(L, -3);

	/* Cost */
	lua_pushstring(L, "loss");
	lua_newtable(L);
	luaL_register(L, NULL, rspamd_kann_loss_f);
	lua_settable(L, -3);

	/* Create functions */
	lua_pushstring(L, "new");
	lua_newtable(L);
	luaL_register(L, NULL, rspamd_kann_new_f);
	lua_settable(L, -3);

	/* Load ann from memory or file */
	lua_pushstring(L, "load");
	lua_pushcfunction(L, lua_kann_load);
	lua_settable(L, -3);

	return 1;
}
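
/*
 * The module table built above exposes the kann/kautodiff constants and
 * constructors to Lua. A hedged sketch of what is available (the `some_node`
 * variable below is hypothetical):
 *
 *   local kann = require "rspamd_kann"
 *   -- kann.flag.*  : node flags (in, out, cost, truth); table values are OR-ed
 *   -- kann.cost.*  : cost types (ceb, cem, ceb_neg, mse) for kann.layer.cost
 *   -- kann.rnn.*   : rnn construction flags (norm, var_h0)
 *   -- kann.layer.* : layer constructors, kann.transform.* : node transforms
 *   -- kann.loss.*  : loss nodes, kann.new.* : low-level node constructors
 *   -- kann.load    : deserialise a previously saved network
 *
 *   local out_node = kann.layer.dense(some_node, 2, {kann.flag.out})
 */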

static kad_node_t *
lua_check_kann_node(lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata(L, pos, KANN_NODE_CLASS);
	luaL_argcheck(L, ud != NULL, pos, "'kann_node' expected");
	return ud ? *((kad_node_t **) ud) : NULL;
}

static kann_t *
lua_check_kann(lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata(L, pos, KANN_NETWORK_CLASS);
	luaL_argcheck(L, ud != NULL, pos, "'kann' expected");
	return ud ? *((kann_t **) ud) : NULL;
}

void luaopen_kann(lua_State *L)
{
	/* Metatables */
	rspamd_lua_new_class(L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
	lua_pop(L, 1);                                  /* No need in metatable... */
	rspamd_lua_new_class(L, KANN_NETWORK_CLASS, rspamd_kann_m);
	lua_pop(L, 1); /* No need in metatable... */
	rspamd_lua_add_preload(L, "rspamd_kann", lua_load_kann);
	lua_settop(L, 0);
}

/* Layers implementation */

#define PUSH_KAD_NODE(n) \
	do { \
		kad_node_t **pt; \
		pt = lua_newuserdata(L, sizeof(kad_node_t *)); \
		*pt = (n); \
		rspamd_lua_setclass(L, KANN_NODE_CLASS, -1); \
	} while (0)

#define PUSH_KAN_NETWORK(n) \
	do { \
		kann_t **pn; \
		pn = lua_newuserdata(L, sizeof(kann_t *)); \
		*pn = (n); \
		rspamd_lua_setclass(L, KANN_NETWORK_CLASS, -1); \
	} while (0)

#define PROCESS_KAD_FLAGS(n, pos) \
	do { \
		int fl = 0; \
		if (lua_type(L, (pos)) == LUA_TTABLE) { \
			fl = rspamd_kann_table_to_flags(L, (pos)); \
		} \
		else if (lua_type(L, (pos)) == LUA_TNUMBER) { \
			fl = lua_tointeger(L, (pos)); \
		} \
		(n)->ext_flag |= fl; \
	} while (0)

/***
 * @function kann.layer.input(ninputs[, flags])
 * Creates an input layer for ANN
 * @param {int} ninputs number of inputs
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_input(lua_State *L)
{
	int nnodes = luaL_checkinteger(L, 1);

	if (nnodes > 0) {
		kad_node_t *t;

		t = kann_layer_input(nnodes);

		PROCESS_KAD_FLAGS(t, 2);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, nnodes required");
	}

	return 1;
}

/***
 * @function kann.layer.dense(in, nnodes[, flags])
 * Creates a dense (fully connected) layer, e.g. for a hidden layer
 * @param {kann_node} in kann node
 * @param {int} nnodes number of nodes in the dense layer
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_dense(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int nnodes = luaL_checkinteger(L, 2);

	if (in != NULL && nnodes > 0) {
		kad_node_t *t;

		t = kann_layer_dense(in, nnodes);

		PROCESS_KAD_FLAGS(t, 3);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input + nnodes required");
	}

	return 1;
}
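
/*
 * Illustrative Lua sketch of stacking dense layers (sizes are made up, not
 * taken from Rspamd itself):
 *
 *   local kann = require "rspamd_kann"
 *   local input = kann.layer.input(16)
 *   local h1 = kann.transform.relu(kann.layer.dense(input, 32))
 *   local h2 = kann.transform.relu(kann.layer.dense(h1, 8))
 */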

/***
 * @function kann.layer.dropout(in, ratio[, flags])
 * Creates a dropout layer
 * @param {kann_node} in kann node
 * @param {float} ratio drop ratio
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_layerdropout(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	double r = luaL_checknumber(L, 2);

	if (in != NULL) {
		kad_node_t *t;

		t = kann_layer_dropout(in, r);

		PROCESS_KAD_FLAGS(t, 3);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input + rate required");
	}

	return 1;
}

/***
 * @function kann.layer.layernorm(in[, flags])
 * Creates a layer normalisation layer
 * @param {kann_node} in kann node
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_layernorm(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);

	if (in != NULL) {
		kad_node_t *t;

		t = kann_layer_layernorm(in);

		PROCESS_KAD_FLAGS(t, 2);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input required");
	}

	return 1;
}

/***
 * @function kann.layer.rnn(in, nnodes[, rnn_flags[, flags]])
 * Creates a recurrent NN layer
 * @param {kann_node} in kann node
 * @param {int} nnodes number of cells
 * @param {int} rnn_flags rnn flags
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_rnn(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int nnodes = luaL_checkinteger(L, 2);
	int rnnflags = 0;

	if (in != NULL && nnodes > 0) {
		kad_node_t *t;

		if (lua_type(L, 3) == LUA_TNUMBER) {
			rnnflags = lua_tointeger(L, 3);
		}

		t = kann_layer_rnn(in, nnodes, rnnflags);

		PROCESS_KAD_FLAGS(t, 4);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input + nnodes required");
	}

	return 1;
}

/***
 * @function kann.layer.lstm(in, nnodes[, rnn_flags[, flags]])
 * Creates a recurrent NN layer using LSTM cells
 * @param {kann_node} in kann node
 * @param {int} nnodes number of cells
 * @param {int} rnn_flags rnn flags
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_lstm(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int nnodes = luaL_checkinteger(L, 2);
	int rnnflags = 0;

	if (in != NULL && nnodes > 0) {
		kad_node_t *t;

		if (lua_type(L, 3) == LUA_TNUMBER) {
			rnnflags = lua_tointeger(L, 3);
		}

		t = kann_layer_lstm(in, nnodes, rnnflags);

		PROCESS_KAD_FLAGS(t, 4);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input + nnodes required");
	}

	return 1;
}

/***
 * @function kann.layer.gru(in, nnodes[, rnn_flags[, flags]])
 * Creates a recurrent NN layer using GRU cells
 * @param {kann_node} in kann node
 * @param {int} nnodes number of cells
 * @param {int} rnn_flags rnn flags
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_gru(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int nnodes = luaL_checkinteger(L, 2);
	int rnnflags = 0;

	if (in != NULL && nnodes > 0) {
		kad_node_t *t;

		if (lua_type(L, 3) == LUA_TNUMBER) {
			rnnflags = lua_tointeger(L, 3);
		}

		t = kann_layer_gru(in, nnodes, rnnflags);

		PROCESS_KAD_FLAGS(t, 4);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input + nnodes required");
	}

	return 1;
}
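
/*
 * Illustrative Lua sketch for the recurrent layers above (node names and
 * sizes are hypothetical):
 *
 *   local kann = require "rspamd_kann"
 *   local input = kann.layer.input(8)
 *   -- GRU with 16 cells and layer normalisation enabled via the rnn flags
 *   local gru = kann.layer.gru(input, 16, kann.rnn.norm)
 *   -- LSTM variant with a trainable initial hidden state
 *   local lstm = kann.layer.lstm(input, 16, kann.rnn.var_h0)
 */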

/***
 * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
 * Creates a 2D convolution layer
 * @param {kann_node} in kann node
 * @param {int} n_flt number of filters
 * @param {int} k_rows kernel rows
 * @param {int} k_cols kernel columns
 * @param {int} stride_rows stride rows
 * @param {int} stride_cols stride columns
 * @param {int} pad_rows padding rows
 * @param {int} pad_columns padding columns
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_conv2d(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int n_flt = luaL_checkinteger(L, 2);
	int k_rows = luaL_checkinteger(L, 3);
	int k_cols = luaL_checkinteger(L, 4);
	int stride_r = luaL_checkinteger(L, 5);
	int stride_c = luaL_checkinteger(L, 6);
	int pad_r = luaL_checkinteger(L, 7);
	int pad_c = luaL_checkinteger(L, 8);

	if (in != NULL) {
		kad_node_t *t;

		t = kann_layer_conv2d(in, n_flt, k_rows, k_cols, stride_r, stride_c,
							  pad_r, pad_c);

		PROCESS_KAD_FLAGS(t, 9);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
	}

	return 1;
}
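
/*
 * Illustrative Lua call for the 2D convolution binding above, showing the
 * argument order only (`img_node` stands for a suitably shaped input node and
 * is hypothetical, as are the sizes):
 *
 *   -- 8 filters, 3x3 kernel, 1x1 stride, no padding
 *   local conv = kann.layer.conv2d(img_node, 8, 3, 3, 1, 1, 0, 0)
 */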

/***
 * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
 * Creates a 1D convolution layer
 * @param {kann_node} in kann node
 * @param {int} n_flt number of filters
 * @param {int} kern_size kernel size
 * @param {int} stride_size stride size
 * @param {int} pad_size padding size
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_conv1d(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int n_flt = luaL_checkinteger(L, 2);
	int k_size = luaL_checkinteger(L, 3);
	int stride = luaL_checkinteger(L, 4);
	int pad = luaL_checkinteger(L, 5);

	if (in != NULL) {
		kad_node_t *t;

		t = kann_layer_conv1d(in, n_flt, k_size, stride, pad);

		PROCESS_KAD_FLAGS(t, 6);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input, nflt, k, stride, pad required");
	}

	return 1;
}

/***
 * @function kann.layer.cost(in, nout, cost_type[, flags])
 * Creates a cost (output) layer using the specified cost function
 * @param {kann_node} in kann node
 * @param {int} nout number of outputs
 * @param {int} cost_type see kann.cost table
 * @param {table|int} flags optional flags
 * @return {kann_node} kann node object (should be used to combine ANN)
 */
static int
lua_kann_layer_cost(lua_State *L)
{
	kad_node_t *in = lua_check_kann_node(L, 1);
	int nout = luaL_checkinteger(L, 2);
	int cost_type = luaL_checkinteger(L, 3);

	if (in != NULL && nout > 0) {
		kad_node_t *t;

		t = kann_layer_cost(in, nout, cost_type);

		PROCESS_KAD_FLAGS(t, 4);
		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments, input, nout and cost_type are required");
	}

	return 1;
}
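
/*
 * Illustrative choice of cost types for the cost layer above (a sketch;
 * `hidden` is a hypothetical node):
 *
 *   -- binary classification: one sigmoid output with binary cross-entropy
 *   local bin_cost = kann.layer.cost(hidden, 1, kann.cost.ceb)
 *
 *   -- multi-class classification: softmax outputs with multi-class cross-entropy
 *   local multi_cost = kann.layer.cost(hidden, 3, kann.cost.cem)
 */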

/* Generic helpers */
static int
lua_kann_call_unary_function(lua_State *L, const char *name,
							 kad_node_t *(*func)(kad_node_t *))
{
	kad_node_t *in = lua_check_kann_node(L, 1);

	if (in != NULL) {
		kad_node_t *t;

		t = func(in);

		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments for %s, input required", name);
	}

	return 1;
}

static int
lua_kann_call_binary_function(lua_State *L, const char *name,
							  kad_node_t *(*func)(kad_node_t *, kad_node_t *))
{
	kad_node_t *x = lua_check_kann_node(L, 1);
	kad_node_t *y = lua_check_kann_node(L, 2);

	if (x != NULL && y != NULL) {
		kad_node_t *t;

		t = func(x, y);

		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments for %s, 2 inputs required", name);
	}

	return 1;
}

#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \
	static int lua_kann_transform_##name(lua_State *L) \
	{ \
		return lua_kann_call_unary_function(L, #name, kad_##name); \
	}

#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \
	static int lua_kann_transform_##name(lua_State *L) \
	{ \
		return lua_kann_call_binary_function(L, #name, kad_##name); \
	}

#define LUA_LOSS_FUNC_IMPL(name) \
	static int lua_kann_loss_##name(lua_State *L) \
	{ \
		return lua_kann_call_binary_function(L, #name, kad_##name); \
	}

/* Transform functions registered via macro helpers */
LUA_BINARY_TRANSFORM_FUNC_IMPL(add)
LUA_BINARY_TRANSFORM_FUNC_IMPL(sub)
LUA_BINARY_TRANSFORM_FUNC_IMPL(mul)
LUA_BINARY_TRANSFORM_FUNC_IMPL(cmul)
LUA_BINARY_TRANSFORM_FUNC_IMPL(matmul)
LUA_UNARY_TRANSFORM_FUNC_IMPL(square)
LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)
LUA_UNARY_TRANSFORM_FUNC_IMPL(log)
LUA_UNARY_TRANSFORM_FUNC_IMPL(sin)
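
/*
 * Illustrative Lua sketch of the transforms generated above (node names are
 * hypothetical):
 *
 *   local kann = require "rspamd_kann"
 *   local a = kann.layer.input(4)
 *   local b = kann.layer.dense(a, 4)
 *   local sum = kann.transform.add(a, b)       -- binary transform
 *   local squashed = kann.transform.sigm(sum)  -- unary transform
 */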

/* Generic cost functions */
LUA_LOSS_FUNC_IMPL(mse)
LUA_LOSS_FUNC_IMPL(ce_multi)
LUA_LOSS_FUNC_IMPL(ce_bin)
LUA_LOSS_FUNC_IMPL(ce_bin_neg)

/* The only ternary loss function: it also takes a weight node */
static int
lua_kann_loss_ce_multi_weighted(lua_State *L)
{
	kad_node_t *pred = lua_check_kann_node(L, 1);
	kad_node_t *truth = lua_check_kann_node(L, 2);
	kad_node_t *weight = lua_check_kann_node(L, 3);

	if (pred != NULL && truth != NULL && weight != NULL) {
		kad_node_t *t;

		t = kad_ce_multi_weighted(pred, truth, weight);

		PUSH_KAD_NODE(t);
	}
	else {
		return luaL_error(L, "invalid arguments for ce_multi_weighted, 3 inputs required");
	}

	return 1;
}

/* Creation functions */
static int
lua_kann_new_scalar(lua_State *L)
{
	int flag = luaL_checkinteger(L, 1);
	double x = luaL_checknumber(L, 2);
	kad_node_t *t;

	t = kann_new_scalar(flag, x);

	PROCESS_KAD_FLAGS(t, 3);
	PUSH_KAD_NODE(t);

	return 1;
}

static int
lua_kann_new_weight(lua_State *L)
{
	int nrow = luaL_checkinteger(L, 1);
	int ncol = luaL_checkinteger(L, 2);
	kad_node_t *t;

	t = kann_new_weight(nrow, ncol);

	PROCESS_KAD_FLAGS(t, 3);
	PUSH_KAD_NODE(t);

	return 1;
}

static int
lua_kann_new_bias(lua_State *L)
{
	int n = luaL_checkinteger(L, 1);
	kad_node_t *t;

	t = kann_new_bias(n);

	PROCESS_KAD_FLAGS(t, 2);
	PUSH_KAD_NODE(t);

	return 1;
}

static int
lua_kann_new_weight_conv2d(lua_State *L)
{
	int nout = luaL_checkinteger(L, 1);
	int nin = luaL_checkinteger(L, 2);
	int krow = luaL_checkinteger(L, 3);
	int kcol = luaL_checkinteger(L, 4);
	kad_node_t *t;

	t = kann_new_weight_conv2d(nout, nin, krow, kcol);

	PROCESS_KAD_FLAGS(t, 5);
	PUSH_KAD_NODE(t);

	return 1;
}

static int
lua_kann_new_weight_conv1d(lua_State *L)
{
	int nout = luaL_checkinteger(L, 1);
	int nin = luaL_checkinteger(L, 2);
	int klen = luaL_checkinteger(L, 3);
	kad_node_t *t;

	t = kann_new_weight_conv1d(nout, nin, klen);

	PROCESS_KAD_FLAGS(t, 4);
	PUSH_KAD_NODE(t);

	return 1;
}

static int
lua_kann_new_leaf(lua_State *L)
{
	int dim = luaL_checkinteger(L, 1), i, *ar;
	kad_node_t *t;

	if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable(L, 2)) {
		ar = g_new0(int, KAD_MAX_DIM);

		for (i = 0; i < dim; i++) {
			lua_rawgeti(L, 2, i + 1);
			ar[i] = lua_tointeger(L, -1);
			lua_pop(L, 1);
		}

		t = kann_new_leaf_array(NULL, NULL, 0, 0.0, dim, ar);

		PROCESS_KAD_FLAGS(t, 3);
		PUSH_KAD_NODE(t);

		g_free(ar);
	}
	else {
		return luaL_error(L, "invalid arguments for new.leaf, "
							 "dim and vector of elements are required");
	}

	return 1;
}
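
/*
 * Illustrative Lua call for kann.new.leaf: the first argument is the number
 * of dimensions, the second a table with the size of each dimension (the
 * values below are made up):
 *
 *   local kann = require "rspamd_kann"
 *   -- a 3x4 leaf node
 *   local leaf = kann.new.leaf(2, {3, 4})
 */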

static int
lua_kann_new_kann(lua_State *L)
{
	kad_node_t *cost = lua_check_kann_node(L, 1);
	kann_t *k;

	if (cost) {
		k = kann_new(cost, 0);

		PUSH_KAN_NETWORK(k);
	}
	else {
		return luaL_error(L, "invalid arguments for new.kann, "
							 "cost node is required");
	}

	return 1;
}

static int
lua_kann_destroy(lua_State *L)
{
	kann_t *k = lua_check_kann(L, 1);

	kann_delete(k);

	return 0;
}

static int
lua_kann_save(lua_State *L)
{
	kann_t *k = lua_check_kann(L, 1);

	if (k) {
		if (lua_istable(L, 2)) {
			lua_getfield(L, 2, "filename");

			if (lua_isstring(L, -1)) {
				const char *fname = lua_tostring(L, -1);
				FILE *f;

				f = fopen(fname, "w");

				if (!f) {
					lua_pop(L, 1);
					return luaL_error(L, "cannot open %s for writing: %s",
									  fname, strerror(errno));
				}

				kann_save_fp(f, k);
				fclose(f);
				/* Pop the filename so the boolean below is what gets returned */
				lua_pop(L, 1);
				lua_pushboolean(L, true);
			}
			else {
				lua_pop(L, 1);
				return luaL_error(L, "invalid arguments: missing filename");
			}
		}
		else {
			/* Save to Rspamd text */
#ifndef HAVE_OPENMEMSTREAM
			return luaL_error(L, "no support of saving to memory on your system");
#endif
			FILE *f;
			char *buf = NULL;
			size_t buflen;
			struct rspamd_lua_text *t;

			f = open_memstream(&buf, &buflen);
			g_assert(f != NULL);

			kann_save_fp(f, k);
			fclose(f);

			t = lua_newuserdata(L, sizeof(*t));
			rspamd_lua_setclass(L, rspamd_text_classname, -1);
			t->flags = RSPAMD_TEXT_FLAG_OWN;
			t->start = (const char *) buf;
			t->len = buflen;
		}
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}
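
/*
 * Illustrative Lua usage of the save method (a sketch; `net` and the path
 * are hypothetical):
 *
 *   -- save to a file
 *   net:save({ filename = '/tmp/model.kann' })
 *
 *   -- or, where open_memstream() is available, serialise to an rspamd_text
 *   local blob = net:save()
 */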

static int
lua_kann_load(lua_State *L)
{
	kann_t *k;
	FILE *f = NULL;

	if (lua_istable(L, 1)) {
		lua_getfield(L, 1, "filename");

		if (lua_isstring(L, -1)) {
			const char *fname = lua_tostring(L, -1);

			f = fopen(fname, "rb");
		}
		else {
			lua_pop(L, 1);
			return luaL_error(L, "invalid arguments: missing filename");
		}

		lua_pop(L, 1);
	}
	else if (lua_isstring(L, 1)) {
		gsize dlen;
		const char *data;

		data = lua_tolstring(L, 1, &dlen);
#ifndef HAVE_FMEMOPEN
		return luaL_error(L, "no support of loading from memory on your system");
#endif
		f = fmemopen((void *) data, dlen, "rb");
	}
	else if (lua_isuserdata(L, 1)) {
		struct rspamd_lua_text *t;

		t = lua_check_text(L, 1);

		if (!t) {
			return luaL_error(L, "invalid arguments");
		}
#ifndef HAVE_FMEMOPEN
		return luaL_error(L, "no support of loading from memory on your system");
#endif
		f = fmemopen((void *) t->start, t->len, "rb");
	}

	if (f == NULL) {
		return luaL_error(L, "invalid arguments or cannot open file");
	}

	k = kann_load_fp(f);
	fclose(f);

	if (k == NULL) {
		lua_pushnil(L);
	}
	else {
		PUSH_KAN_NETWORK(k);
	}

	return 1;
}
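
/*
 * Illustrative Lua usage of kann.load (paths and names are hypothetical):
 *
 *   local kann = require "rspamd_kann"
 *
 *   -- from a file
 *   local net = kann.load({ filename = '/tmp/model.kann' })
 *
 *   -- or from a string / rspamd_text previously produced by net:save()
 *   local net2 = kann.load(blob)
 */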

struct rspamd_kann_train_cbdata {
	lua_State *L;
	kann_t *k;
	int cbref;
};

static void
lua_kann_train_cb(int iter, float train_cost, float val_cost, void *ud)
{
	struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *) ud;

	if (cbd->cbref != -1) {
		int err_idx;
		lua_State *L = cbd->L;

		lua_pushcfunction(L, &rspamd_lua_traceback);
		err_idx = lua_gettop(L);

		lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref);
		lua_pushinteger(L, iter);
		lua_pushnumber(L, train_cost);
		lua_pushnumber(L, val_cost);

		if (lua_pcall(L, 3, 0, err_idx) != 0) {
			msg_err("cannot run lua train callback: %s",
					lua_tostring(L, -1));
		}

		lua_settop(L, err_idx - 1);
	}
}

#define FREE_VEC(a, n) \
	do { \
		for (int i = 0; i < (n); i++) g_free((a)[i]); \
		g_free(a); \
	} while (0)

static int
lua_kann_train1(lua_State *L)
{
	kann_t *k = lua_check_kann(L, 1);
	struct rspamd_lua_tensor *pca = NULL;

	/* Default train params */
	double lr = 0.001;
	int64_t mini_size = 64;
	int64_t max_epoch = 25;
	int64_t max_drop_streak = 10;
	double frac_val = 0.1;
	int cbref = -1;

	if (k && lua_istable(L, 2) && lua_istable(L, 3)) {
		int n = rspamd_lua_table_size(L, 2);
		int n_in = kann_dim_in(k);
		int n_out = kann_dim_out(k);

		if (n_in <= 0) {
			return luaL_error(L, "invalid inputs count: %d", n_in);
		}

		if (n_out <= 0) {
			return luaL_error(L, "invalid outputs count: %d", n_out);
		}

		if (n != rspamd_lua_table_size(L, 3) || n == 0) {
			return luaL_error(L, "invalid dimensions: outputs size must be "
								 "equal to inputs and non zero");
		}

		if (lua_istable(L, 4)) {
			GError *err = NULL;

			if (!rspamd_lua_parse_table_arguments(L, 4, &err,
												  RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
												  "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
												  &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
				n = luaL_error(L, "invalid params: %s",
							   err ? err->message : "unknown error");
				g_error_free(err);

				return n;
			}
		}

		if (pca) {
			/* Check pca matrix validity */
			if (pca->ndims != 2) {
				return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
			}

			if (pca->dim[0] != n_in) {
				return luaL_error(L, "invalid pca tensor: "
									 "matrix must have %d rows and it has %d rows instead",
								  n_in, pca->dim[0]);
			}
		}

		float **x, **y, *tmp_row = NULL;

		/* Fill vectors row by row */
		x = (float **) g_malloc0(sizeof(float *) * n);
		y = (float **) g_malloc0(sizeof(float *) * n);

		if (pca) {
			tmp_row = g_malloc(sizeof(float) * pca->dim[1]);
		}

		for (int s = 0; s < n; s++) {
			/* Inputs */
			lua_rawgeti(L, 2, s + 1);
			x[s] = (float *) g_malloc(sizeof(float) * n_in);

			if (pca == NULL) {
				if (rspamd_lua_table_size(L, -1) != n_in) {
					FREE_VEC(x, n);
					FREE_VEC(y, n);

					n = luaL_error(L, "invalid params at pos %d: "
									  "bad input dimension %d; %d expected",
								   s + 1,
								   (int) rspamd_lua_table_size(L, -1),
								   n_in);
					lua_pop(L, 1);

					return n;
				}

				for (int i = 0; i < n_in; i++) {
					lua_rawgeti(L, -1, i + 1);
					x[s][i] = lua_tonumber(L, -1);
					lua_pop(L, 1);
				}
			}
			else {
				if (rspamd_lua_table_size(L, -1) != pca->dim[1]) {
					FREE_VEC(x, n);
					FREE_VEC(y, n);
					g_free(tmp_row);

					n = luaL_error(L, "(pca on) invalid params at pos %d: "
									  "bad input dimension %d; %d expected",
								   s + 1,
								   (int) rspamd_lua_table_size(L, -1),
								   pca->dim[1]);
					lua_pop(L, 1);

					return n;
				}

				for (int i = 0; i < pca->dim[1]; i++) {
					lua_rawgeti(L, -1, i + 1);
					tmp_row[i] = lua_tonumber(L, -1);
					lua_pop(L, 1);
				}

				kad_sgemm_simple(0, 1, 1, n_in,
								 pca->dim[1], tmp_row, pca->data,
								 x[s]);
			}

			lua_pop(L, 1);

			/* Outputs */
			y[s] = (float *) g_malloc(sizeof(float) * n_out);
			lua_rawgeti(L, 3, s + 1);

			if (rspamd_lua_table_size(L, -1) != n_out) {
				FREE_VEC(x, n);
				FREE_VEC(y, n);
				g_free(tmp_row);

				n = luaL_error(L, "invalid params at pos %d: "
								  "bad output dimension %d; "
								  "%d expected",
							   s + 1,
							   (int) rspamd_lua_table_size(L, -1),
							   n_out);
				lua_pop(L, 1);

				return n;
			}

			for (int i = 0; i < n_out; i++) {
				lua_rawgeti(L, -1, i + 1);
				y[s][i] = lua_tonumber(L, -1);
				lua_pop(L, 1);
			}

			lua_pop(L, 1);
		}

		struct rspamd_kann_train_cbdata cbd;

		cbd.cbref = cbref;
		cbd.k = k;
		cbd.L = L;

		int niters = kann_train_fnn1(k, lr,
									 mini_size, max_epoch, max_drop_streak,
									 frac_val, n, x, y, lua_kann_train_cb, &cbd);

		lua_pushinteger(L, niters);

		FREE_VEC(x, n);
		FREE_VEC(y, n);
		g_free(tmp_row);
	}
	else {
		return luaL_error(L, "invalid arguments: kann, inputs, outputs and"
							 " optional params are expected");
	}

	return 1;
}
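
/*
 * Illustrative Lua usage of the train1 method (a sketch; `net`, the data and
 * the parameter values are made up):
 *
 *   -- `inputs` is an array of tables sized to the input layer,
 *   -- `outputs` an array of tables sized to the output layer
 *   local niters = net:train1(inputs, outputs, {
 *     lr = 0.01,
 *     max_epoch = 100,
 *     mini_size = 32,
 *     cb = function(iter, train_cost, val_cost)
 *       print(iter, train_cost, val_cost)
 *     end,
 *   })
 */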

static int
lua_kann_apply1(lua_State *L)
{
	kann_t *k = lua_check_kann(L, 1);
	struct rspamd_lua_tensor *pca = NULL;

	if (k) {
		if (lua_istable(L, 2)) {
			gsize vec_len = rspamd_lua_table_size(L, 2);
			float *vec = (float *) g_malloc(sizeof(float) * vec_len),
				  *pca_out = NULL;
			int i_out;
			int n_in = kann_dim_in(k);

			if (n_in <= 0) {
				g_free(vec);
				return luaL_error(L, "invalid inputs count: %d", n_in);
			}

			if (lua_isuserdata(L, 3)) {
				pca = lua_check_tensor(L, 3);

				if (pca) {
					if (pca->ndims != 2) {
						g_free(vec);
						return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
					}

					if (pca->dim[0] != n_in) {
						g_free(vec);
						return luaL_error(L, "invalid pca tensor: "
											 "matrix must have %d rows and it has %d rows instead",
										  n_in, pca->dim[0]);
					}
				}
				else {
					g_free(vec);
					return luaL_error(L, "invalid params: pca matrix expected");
				}
			}
			else {
				if (n_in != vec_len) {
					g_free(vec);
					return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
									  (int) vec_len, n_in);
				}
			}

			for (gsize i = 0; i < vec_len; i++) {
				lua_rawgeti(L, 2, i + 1);
				vec[i] = lua_tonumber(L, -1);
				lua_pop(L, 1);
			}

			i_out = kann_find(k, KANN_F_OUT, 0);

			if (i_out <= 0) {
				g_free(vec);
				return luaL_error(L, "invalid ANN: output layer is missing or is "
									 "at the input pos");
			}

			kann_set_batch_size(k, 1);

			if (pca) {
				pca_out = g_malloc(sizeof(float) * n_in);

				kad_sgemm_simple(0, 1, 1, n_in,
								 vec_len, vec, pca->data,
								 pca_out);

				kann_feed_bind(k, KANN_F_IN, 0, &pca_out);
			}
			else {
				kann_feed_bind(k, KANN_F_IN, 0, &vec);
			}

			kad_eval_at(k->n, k->v, i_out);

			gsize outlen = kad_len(k->v[i_out]);
			lua_createtable(L, outlen, 0);

			for (gsize i = 0; i < outlen; i++) {
				lua_pushnumber(L, k->v[i_out]->x[i]);
				lua_rawseti(L, -2, i + 1);
			}

			g_free(vec);
			g_free(pca_out);
		}
		else if (lua_isuserdata(L, 2)) {
			struct rspamd_lua_tensor *t = lua_check_tensor(L, 2);

			if (t && t->ndims == 1) {
				int i_out;
				int n_in = kann_dim_in(k);

				if (n_in != t->dim[0]) {
					return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
									  (int) t->dim[0], n_in);
				}

				i_out = kann_find(k, KANN_F_OUT, 0);

				if (i_out <= 0) {
					return luaL_error(L, "invalid ANN: output layer is missing or is "
										 "at the input pos");
				}

				kann_set_batch_size(k, 1);
				kann_feed_bind(k, KANN_F_IN, 0, &t->data);
				kad_eval_at(k->n, k->v, i_out);

				int outlen = kad_len(k->v[i_out]);
				struct rspamd_lua_tensor *out;

				out = lua_newtensor(L, 1, &outlen, false, false);
				/* Ensure that kann and tensor have the same understanding of floats */
				G_STATIC_ASSERT(sizeof(float) == sizeof(rspamd_tensor_num_t));
				memcpy(out->data, k->v[i_out]->x, outlen * sizeof(float));
			}
			else {
				return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
			}
		}
		else {
			return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
		}
	}
	else {
		return luaL_error(L, "invalid arguments: rspamd{kann} expected");
	}

	return 1;
}
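
/*
 * Illustrative Lua usage of the apply1 method (a sketch; `net`, the input
 * values and `some_tensor` are hypothetical):
 *
 *   -- plain table input; returns a table of outputs
 *   local out = net:apply1({0.1, 0.5, 0.3, 0.2})
 *
 *   -- 1D rspamd{tensor} input; returns a 1D tensor instead
 *   local out_tensor = net:apply1(some_tensor)
 */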