
lua_tensor.c

/*
 * Copyright 2024 Vsevolod Stakhov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lua_common.h"
#include "lua_tensor.h"
#include "contrib/kann/kautodiff.h"
#include "blas-config.h"

/***
 * @module rspamd_tensor
 * `rspamd_tensor` is a simple Lua library to abstract matrices and vectors.
 * Internally, they are represented as flat arrays of float values.
 * So far, only 1D and 2D tensors are supported.
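 * @example
 * -- Illustrative usage only; exact behaviour depends on the Rspamd build
 * local tensor = require "rspamd_tensor"
 * local m = tensor.fromtable{{1, 2}, {3, 4}}
 * print(m) -- prints the matrix row by row via __tostring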
 */

LUA_FUNCTION_DEF(tensor, load);
LUA_FUNCTION_DEF(tensor, save);
LUA_FUNCTION_DEF(tensor, new);
LUA_FUNCTION_DEF(tensor, fromtable);
LUA_FUNCTION_DEF(tensor, destroy);
LUA_FUNCTION_DEF(tensor, mul);
LUA_FUNCTION_DEF(tensor, tostring);
LUA_FUNCTION_DEF(tensor, index);
LUA_FUNCTION_DEF(tensor, newindex);
LUA_FUNCTION_DEF(tensor, len);
LUA_FUNCTION_DEF(tensor, eigen);
LUA_FUNCTION_DEF(tensor, mean);
LUA_FUNCTION_DEF(tensor, transpose);
LUA_FUNCTION_DEF(tensor, has_blas);
LUA_FUNCTION_DEF(tensor, scatter_matrix);

static luaL_reg rspamd_tensor_f[] = {
	LUA_INTERFACE_DEF(tensor, load),
	LUA_INTERFACE_DEF(tensor, new),
	LUA_INTERFACE_DEF(tensor, fromtable),
	LUA_INTERFACE_DEF(tensor, has_blas),
	LUA_INTERFACE_DEF(tensor, scatter_matrix),
	{NULL, NULL},
};

static luaL_reg rspamd_tensor_m[] = {
	LUA_INTERFACE_DEF(tensor, save),
	{"__gc", lua_tensor_destroy},
	{"__mul", lua_tensor_mul},
	{"mul", lua_tensor_mul},
	{"tostring", lua_tensor_tostring},
	{"__tostring", lua_tensor_tostring},
	{"__index", lua_tensor_index},
	{"__newindex", lua_tensor_newindex},
	{"__len", lua_tensor_len},
	LUA_INTERFACE_DEF(tensor, eigen),
	LUA_INTERFACE_DEF(tensor, mean),
	LUA_INTERFACE_DEF(tensor, transpose),
	{NULL, NULL},
};
struct rspamd_lua_tensor *
lua_newtensor(lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
{
	struct rspamd_lua_tensor *res;

	res = lua_newuserdata(L, sizeof(struct rspamd_lua_tensor));
	memset(res, 0, sizeof(*res));
	res->ndims = ndims;
	res->size = 1;

	for (unsigned int i = 0; i < ndims; i++) {
		res->size *= dim[i];
		res->dim[i] = dim[i];
	}

	/* Allocate data outside of Lua to avoid allocating large chunks in the Lua GC heap */
	if (own) {
		res->data = g_malloc(sizeof(rspamd_tensor_num_t) * res->size);

		if (zero_fill) {
			memset(res->data, 0, sizeof(rspamd_tensor_num_t) * res->size);
		}
	}
	else {
		/* Mark size negative to distinguish a non-owning tensor */
		res->size = -(res->size);
	}

	rspamd_lua_setclass(L, rspamd_tensor_classname, -1);

	return res;
}

/***
 * @function tensor.new(ndims, [dim1, ... dimN])
 * Creates a new zero-filled tensor with the specified number of dimensions and dimension sizes
 * @return {tensor} new zero-filled tensor
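 * @example
 * -- Illustrative only: a 2x3 matrix filled with zeros
 * local t = tensor.new(2, 2, 3)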
 */
static int
lua_tensor_new(lua_State *L)
{
	int ndims = luaL_checkinteger(L, 1);

	if (ndims > 0 && ndims <= 2) {
		int *dims = g_alloca(sizeof(int) * ndims);

		for (unsigned int i = 0; i < ndims; i++) {
			dims[i] = lua_tointeger(L, i + 2);
		}

		(void) lua_newtensor(L, ndims, dims, true, true);
	}
	else {
		return luaL_error(L, "incorrect dimensions number: %d", ndims);
	}

	return 1;
}

/***
 * @function tensor.fromtable(tbl)
 * Creates a new tensor from a Lua table: a flat table of numbers produces
 * a 1xN vector, a table of equally sized tables produces a matrix
 * @return {tensor} new tensor
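 * @example
 * -- Illustrative only
 * local v = tensor.fromtable{1.0, 2.0, 3.0}  -- 1x3 vector
 * local m = tensor.fromtable{{1, 2}, {3, 4}} -- 2x2 matrix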
 */
static int
lua_tensor_fromtable(lua_State *L)
{
	if (lua_istable(L, 1)) {
		lua_rawgeti(L, 1, 1);

		if (lua_isnumber(L, -1)) {
			lua_pop(L, 1);

			/* Input vector */
			int dims[2];
			dims[0] = 1;
			dims[1] = rspamd_lua_table_size(L, 1);

			struct rspamd_lua_tensor *res = lua_newtensor(L, 2,
														  dims, false, true);

			for (unsigned int i = 0; i < dims[1]; i++) {
				lua_rawgeti(L, 1, i + 1);
				res->data[i] = lua_tonumber(L, -1);
				lua_pop(L, 1);
			}
		}
		else if (lua_istable(L, -1)) {
			/* Input matrix */
			lua_pop(L, 1);

			/* Calculate the overall size */
			int nrows = rspamd_lua_table_size(L, 1), ncols = 0;
			int err;

			for (int i = 0; i < nrows; i++) {
				lua_rawgeti(L, 1, i + 1);

				if (ncols == 0) {
					ncols = rspamd_lua_table_size(L, -1);

					if (ncols == 0) {
						lua_pop(L, 1);
						err = luaL_error(L, "invalid params at pos %d: "
											"bad input dimension %d",
										 i,
										 (int) ncols);

						return err;
					}
				}
				else {
					if (ncols != rspamd_lua_table_size(L, -1)) {
						int t = rspamd_lua_table_size(L, -1);
						lua_pop(L, 1);
						err = luaL_error(L, "invalid params at pos %d: "
											"bad input dimension %d; %d expected",
										 i,
										 t,
										 ncols);

						return err;
					}
				}

				lua_pop(L, 1);
			}

			int dims[2];
			dims[0] = nrows;
			dims[1] = ncols;

			struct rspamd_lua_tensor *res = lua_newtensor(L, 2,
														  dims, false, true);

			for (int i = 0; i < nrows; i++) {
				lua_rawgeti(L, 1, i + 1);

				for (int j = 0; j < ncols; j++) {
					lua_rawgeti(L, -1, j + 1);
					res->data[i * ncols + j] = lua_tonumber(L, -1);
					lua_pop(L, 1);
				}

				lua_pop(L, 1);
			}
		}
		else {
			lua_pop(L, 1);

			return luaL_error(L, "incorrect table");
		}
	}
	else {
		return luaL_error(L, "incorrect input");
	}

	return 1;
}

/***
 * @method tensor:destroy()
 * Tensor destructor (invoked automatically as the __gc metamethod)
 * @return
 */
static int
lua_tensor_destroy(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);

	if (t) {
		if (t->size > 0) {
			g_free(t->data);
		}
	}

	return 0;
}

/***
 * @method tensor:save()
 * Serialises the tensor into a binary blob that can be restored with `tensor.load()`
 * @return {rspamd_text} serialised representation
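 * @example
 * -- Illustrative round-trip; `t` is assumed to be an existing tensor
 * local blob = t:save()
 * local restored = tensor.load(blob)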
 */
static int
lua_tensor_save(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);
	int size;

	if (t) {
		if (t->size > 0) {
			size = t->size;
		}
		else {
			size = -(t->size);
		}

		gsize sz = sizeof(int) * 4 + size * sizeof(rspamd_tensor_num_t);
		unsigned char *data;

		struct rspamd_lua_text *out = lua_new_text(L, NULL, 0, TRUE);

		data = g_malloc(sz);
		memcpy(data, &t->ndims, sizeof(int));
		memcpy(data + sizeof(int), &size, sizeof(int));
		memcpy(data + 2 * sizeof(int), t->dim, sizeof(int) * 2);
		memcpy(data + 4 * sizeof(int), t->data,
			   size * sizeof(rspamd_tensor_num_t));
		out->start = (const char *) data;
		out->len = sz;
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}

static int
lua_tensor_tostring(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);

	if (t) {
		GString *out = g_string_sized_new(128);

		if (t->ndims == 1) {
			/* Print as a vector */
			for (int i = 0; i < t->dim[0]; i++) {
				rspamd_printf_gstring(out, "%.4f ", t->data[i]);
			}

			/* Trim last space */
			out->len--;
		}
		else {
			for (int i = 0; i < t->dim[0]; i++) {
				for (int j = 0; j < t->dim[1]; j++) {
					rspamd_printf_gstring(out, "%.4f ",
										  t->data[i * t->dim[1] + j]);
				}

				/* Trim last space */
				out->len--;
				rspamd_printf_gstring(out, "\n");
			}

			/* Trim last newline */
			out->len--;
		}

		lua_pushlstring(L, out->str, out->len);
		g_string_free(out, TRUE);
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}

static int
lua_tensor_index(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);
	int idx;

	if (t) {
		if (lua_isnumber(L, 2)) {
			idx = lua_tointeger(L, 2);

			if (t->ndims == 1) {
				/* Individual element */
				if (idx <= t->dim[0]) {
					lua_pushnumber(L, t->data[idx - 1]);
				}
				else {
					lua_pushnil(L);
				}
			}
			else {
				/* Push row */
				int dim = t->dim[1];

				if (idx <= t->dim[0]) {
					/* Non-owning tensor */
					struct rspamd_lua_tensor *res =
						lua_newtensor(L, 1, &dim, false, false);
					res->data = &t->data[(idx - 1) * t->dim[1]];
				}
				else {
					lua_pushnil(L);
				}
			}
		}
		else if (lua_isstring(L, 2)) {
			/* Access to methods */
			lua_getmetatable(L, 1);
			lua_pushvalue(L, 2);
			lua_rawget(L, -2);
		}
	}

	return 1;
}

static int
lua_tensor_newindex(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);
	int idx;

	if (t) {
		if (lua_isnumber(L, 2)) {
			idx = lua_tointeger(L, 2);

			if (t->ndims == 1) {
				/* Individual element */
				if (idx <= t->dim[0] && idx > 0) {
					rspamd_tensor_num_t value = lua_tonumber(L, 3), old;

					old = t->data[idx - 1];
					t->data[idx - 1] = value;

					lua_pushnumber(L, old);
				}
				else {
					return luaL_error(L, "invalid index: %d", idx);
				}
			}
			else {
				if (lua_isnumber(L, 3)) {
					return luaL_error(L, "cannot assign number to a row");
				}
				else if (lua_isuserdata(L, 3)) {
					/* Tensor assignment */
					struct rspamd_lua_tensor *row = lua_check_tensor(L, 3);

					if (row) {
						if (row->ndims == 1) {
							if (row->dim[0] == t->dim[1]) {
								if (idx > 0 && idx <= t->dim[0]) {
									idx--; /* Zero based index */
									memcpy(&t->data[idx * t->dim[1]],
										   row->data,
										   t->dim[1] * sizeof(rspamd_tensor_num_t));

									return 0;
								}
								else {
									return luaL_error(L, "invalid index: %d", idx);
								}
							}
						}
						else {
							return luaL_error(L, "cannot assign matrix to row");
						}
					}
					else {
						return luaL_error(L, "cannot assign row, invalid tensor");
					}
				}
				else {
					/* TODO: add table assignment */
					return luaL_error(L, "cannot assign row, not a tensor");
				}
			}
		}
		else {
			/* Access to methods? NYI */
			return luaL_error(L, "cannot assign method of a tensor");
		}
	}

	return 1;
}

/***
 * @method tensor:mul(other, [transA, [transB]])
 * Multiplies two tensors (optionally transposed) and returns a new tensor
 * @return {tensor} product of the two tensors
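 * @example
 * -- Illustrative only; `a` and `b` are assumed to be compatible matrices
 * local c = a * b          -- same as a:mul(b)
 * local d = a:mul(b, true) -- multiply `a` transposed by `b`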
 */
static int
lua_tensor_mul(lua_State *L)
{
	struct rspamd_lua_tensor *t1 = lua_check_tensor(L, 1),
							 *t2 = lua_check_tensor(L, 2), *res;
	int transA = 0, transB = 0;

	if (lua_isboolean(L, 3)) {
		transA = lua_toboolean(L, 3);
	}

	if (lua_isboolean(L, 4)) {
		transB = lua_toboolean(L, 4);
	}

	if (t1 && t2) {
		int dims[2], shadow_dims[2];
		dims[0] = abs(transA ? t1->dim[1] : t1->dim[0]);
		shadow_dims[0] = abs(transB ? t2->dim[1] : t2->dim[0]);
		dims[1] = abs(transB ? t2->dim[0] : t2->dim[1]);
		shadow_dims[1] = abs(transA ? t1->dim[0] : t1->dim[1]);

		if (shadow_dims[0] != shadow_dims[1]) {
			return luaL_error(L, "incompatible dimensions %d x %d * %d x %d",
							  dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
		}
		else if (shadow_dims[0] == 0) {
			/* Row * Column -> matrix */
			shadow_dims[0] = 1;
			shadow_dims[1] = 1;
		}

		if (dims[0] == 0) {
			/* Column */
			dims[0] = 1;

			if (dims[1] == 0) {
				/* Column * row -> number */
				dims[1] = 1;
			}

			res = lua_newtensor(L, 2, dims, true, true);
		}
		else if (dims[1] == 0) {
			/* Row */
			res = lua_newtensor(L, 1, dims, true, true);
			dims[1] = 1;
		}
		else {
			res = lua_newtensor(L, 2, dims, true, true);
		}

		kad_sgemm_simple(transA, transB, dims[0], dims[1], shadow_dims[0],
						 t1->data, t2->data, res->data);
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}

/***
 * @function tensor.load(rspamd_text)
 * Deserialises a tensor from a binary blob produced by `tensor:save()`
 * @return {tensor} restored tensor
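 * @example
 * -- Illustrative only; `blob` may be an rspamd_text or a Lua string
 * local t = tensor.load(blob)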
 */
static int
lua_tensor_load(lua_State *L)
{
	const unsigned char *data;
	gsize sz;

	if (lua_type(L, 1) == LUA_TUSERDATA) {
		struct rspamd_lua_text *t = lua_check_text(L, 1);

		if (!t) {
			return luaL_error(L, "invalid argument");
		}

		data = (const unsigned char *) t->start;
		sz = t->len;
	}
	else {
		data = (const unsigned char *) lua_tolstring(L, 1, &sz);
	}

	if (sz >= sizeof(int) * 4) {
		int ndims, nelts, dims[2];

		memcpy(&ndims, data, sizeof(int));
		memcpy(&nelts, data + sizeof(int), sizeof(int));
		memcpy(dims, data + sizeof(int) * 2, sizeof(int) * 2);

		if (sz == nelts * sizeof(rspamd_tensor_num_t) + sizeof(int) * 4) {
			if (ndims == 1) {
				if (nelts == dims[0]) {
					struct rspamd_lua_tensor *t = lua_newtensor(L, ndims, dims, false, true);
					memcpy(t->data, data + sizeof(int) * 4, nelts * sizeof(rspamd_tensor_num_t));
				}
				else {
					return luaL_error(L, "invalid argument: bad dims: %d x %d != %d",
									  dims[0], 1, nelts);
				}
			}
			else if (ndims == 2) {
				if (nelts == dims[0] * dims[1]) {
					struct rspamd_lua_tensor *t = lua_newtensor(L, ndims, dims, false, true);
					memcpy(t->data, data + sizeof(int) * 4, nelts * sizeof(rspamd_tensor_num_t));
				}
				else {
					return luaL_error(L, "invalid argument: bad dims: %d x %d != %d",
									  dims[0], dims[1], nelts);
				}
			}
			else {
				return luaL_error(L, "invalid argument: bad ndims: %d", ndims);
			}
		}
		else {
			return luaL_error(L, "invalid size: %d, %d required, %d elts", (int) sz,
							  (int) (nelts * sizeof(rspamd_tensor_num_t) + sizeof(int) * 4),
							  nelts);
		}
	}
	else {
		return luaL_error(L, "invalid arguments; sz = %d", (int) sz);
	}

	return 1;
}

static int
lua_tensor_len(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);
	int nret = 1;

	if (t) {
		/* Return the main dimension first */
		if (t->ndims == 1) {
			lua_pushinteger(L, t->dim[0]);
		}
		else {
			lua_pushinteger(L, t->dim[0]);
			lua_pushinteger(L, t->dim[1]);
			nret = 2;
		}
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return nret;
}
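
/***
 * @method tensor:eigen()
 * Computes the eigenvalues of a square symmetric matrix and returns them as a
 * 1D tensor; requires BLAS/LAPACK support (see `tensor.has_blas()`)
 * @return {tensor} vector of eigenvalues
 * @example
 * -- Illustrative only; `m` is assumed to be a square symmetric matrix
 * local ev = m:eigen()
 */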
static int
lua_tensor_eigen(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *eigen;

	if (t) {
		if (t->ndims != 2 || t->dim[0] != t->dim[1]) {
			return luaL_error(L, "expected square matrix NxN but got %dx%d",
							  t->dim[0], t->dim[1]);
		}

		eigen = lua_newtensor(L, 1, &t->dim[0], true, true);

		if (!kad_ssyev_simple(t->dim[0], t->data, eigen->data)) {
			lua_pop(L, 1);

			return luaL_error(L, "kad_ssyev_simple failed (no blas?)");
		}
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}

static inline rspamd_tensor_num_t
mean_vec(rspamd_tensor_num_t *x, gsize n)
{
	float sum = rspamd_sum_floats(x, &n);

	return sum / (rspamd_tensor_num_t) n;
}
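
/***
 * @method tensor:mean()
 * For a 1D tensor, returns the mean of all elements as a number; for a 2D
 * tensor, returns a 1D tensor of per-row means
 * @return {number|tensor} mean value or vector of row means
 * @example
 * -- Illustrative only; `m` is assumed to be an existing matrix
 * local row_means = m:mean()
 */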
static int
lua_tensor_mean(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1);

	if (t) {
		if (t->ndims == 1) {
			/* Mean of all elements in a vector */
			lua_pushnumber(L, mean_vec(t->data, t->dim[0]));
		}
		else {
			/* Row-wise mean vector output */
			struct rspamd_lua_tensor *res;

			res = lua_newtensor(L, 1, &t->dim[0], false, true);

			for (int i = 0; i < t->dim[0]; i++) {
				res->data[i] = mean_vec(&t->data[i * t->dim[1]], t->dim[1]);
			}
		}
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}
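
/***
 * @method tensor:transpose()
 * Returns a transposed copy of the tensor
 * @return {tensor} transposed tensor
 * @example
 * -- Illustrative only; `m` is assumed to be an existing matrix
 * local mt = m:transpose()
 */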
static int
lua_tensor_transpose(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *res;
	int dims[2];

	if (t) {
		if (t->ndims == 1) {
			/* Row to column */
			dims[0] = 1;
			dims[1] = t->dim[0];
			res = lua_newtensor(L, 2, dims, false, true);
			memcpy(res->data, t->data, t->dim[0] * sizeof(rspamd_tensor_num_t));
		}
		else {
			/* Cache friendly algorithm */
			dims[0] = t->dim[1];
			dims[1] = t->dim[0];
			res = lua_newtensor(L, 2, dims, false, true);

			static const int block = 32;

			for (int i = 0; i < t->dim[0]; i += block) {
				for (int j = 0; j < t->dim[1]; ++j) {
					for (int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) {
						res->data[j * t->dim[0] + i + boff] =
							t->data[(i + boff) * t->dim[1] + j];
					}
				}
			}
		}
	}
	else {
		return luaL_error(L, "invalid arguments");
	}

	return 1;
}
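
/***
 * @function tensor.has_blas()
 * Returns true if Rspamd has been built with BLAS support
 * @return {boolean} whether BLAS is available
 * @example
 * -- Illustrative only
 * if tensor.has_blas() then print("BLAS is available") end
 */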
static int
lua_tensor_has_blas(lua_State *L)
{
#ifdef HAVE_CBLAS
	lua_pushboolean(L, true);
#else
	lua_pushboolean(L, false);
#endif

	return 1;
}
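
/***
 * @method tensor:scatter_matrix()
 * Computes the scatter matrix of a 2D tensor (rows are treated as samples and
 * columns as variables): the sum of outer products of the mean-centred rows
 * @return {tensor} square scatter matrix of size ncols x ncols
 * @example
 * -- Illustrative only; `m` is assumed to be a samples x features matrix
 * local s = m:scatter_matrix()
 */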
static int
lua_tensor_scatter_matrix(lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *res;
	int dims[2];

	if (t) {
		if (t->ndims != 2) {
			return luaL_error(L, "matrix required");
		}

		/* Output is a dim[1] x dim[1] square matrix */
		dims[0] = t->dim[1];
		dims[1] = t->dim[1];
		res = lua_newtensor(L, 2, dims, true, true);

		/* Auxiliary vars */
		rspamd_tensor_num_t *means, /* column means vector */
			*tmp_row,               /* temporary row for Kahan's summation */
			*tmp_square             /* temporary matrix for multiplications */;
		means = g_malloc0(sizeof(rspamd_tensor_num_t) * t->dim[1]);
		tmp_row = g_malloc0(sizeof(rspamd_tensor_num_t) * t->dim[1]);
		tmp_square = g_malloc(sizeof(rspamd_tensor_num_t) * t->dim[1] * t->dim[1]);

		/*
		 * Column-wise means using Kahan's compensated summation:
		 * `means` holds the running sums (s), `tmp_row` holds the compensation terms (c)
		 */
		for (int i = 0; i < t->dim[0]; i++) {
			/* Cycle by rows */
			for (int j = 0; j < t->dim[1]; j++) {
				rspamd_tensor_num_t v = t->data[i * t->dim[1] + j];
				rspamd_tensor_num_t y = v - tmp_row[j];
				rspamd_tensor_num_t st = means[j] + y;
				tmp_row[j] = (st - means[j]) - y;
				means[j] = st;
			}
		}

		for (int j = 0; j < t->dim[1]; j++) {
			means[j] /= t->dim[0];
		}

		for (int i = 0; i < t->dim[0]; i++) {
			/* Update for each sample */
			for (int j = 0; j < t->dim[1]; j++) {
				tmp_row[j] = t->data[i * t->dim[1] + j] - means[j];
			}

			memset(tmp_square, 0, t->dim[1] * t->dim[1] * sizeof(rspamd_tensor_num_t));
			kad_sgemm_simple(1, 0, t->dim[1], t->dim[1], 1,
							 tmp_row, tmp_row, tmp_square);

			for (int j = 0; j < t->dim[1]; j++) {
				kad_saxpy(t->dim[1], 1.0, &tmp_square[j * t->dim[1]],
						  &res->data[j * t->dim[1]]);
			}
		}

		g_free(tmp_row);
		g_free(means);
		g_free(tmp_square);
	}
	else {
		return luaL_error(L, "tensor required");
	}

	return 1;
}

static int
lua_load_tensor(lua_State *L)
{
	lua_newtable(L);
	luaL_register(L, NULL, rspamd_tensor_f);

	return 1;
}

void luaopen_tensor(lua_State *L)
{
	/* Metatables */
	rspamd_lua_new_class(L, rspamd_tensor_classname, rspamd_tensor_m);
	lua_pop(L, 1); /* No need to keep the metatable on the stack */
	rspamd_lua_add_preload(L, "rspamd_tensor", lua_load_tensor);
	lua_settop(L, 0);
}

struct rspamd_lua_tensor *
lua_check_tensor(lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata(L, pos, rspamd_tensor_classname);
	luaL_argcheck(L, ud != NULL, pos, "'tensor' expected");
	return ud ? ((struct rspamd_lua_tensor *) ud) : NULL;
}