You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_tensor.c 16KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. /*-
  2. * Copyright 2020 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "lua_tensor.h"
  18. #include "contrib/kann/kautodiff.h"
  /***
   * @module rspamd_tensor
   * `rspamd_tensor` is a simple Lua library to abstract matrices and vectors
   * Internally, they are represented as arrays of float variables
   * So far, only 1D and 2D tensors are supported
   */

  /* Module-level constructors (see rspamd_tensor_f) */
  LUA_FUNCTION_DEF (tensor, load);
  LUA_FUNCTION_DEF (tensor, new);
  LUA_FUNCTION_DEF (tensor, fromtable);

  /* Tensor methods and metamethods (see rspamd_tensor_m) */
  LUA_FUNCTION_DEF (tensor, save);
  LUA_FUNCTION_DEF (tensor, destroy);
  LUA_FUNCTION_DEF (tensor, mul);
  LUA_FUNCTION_DEF (tensor, tostring);
  LUA_FUNCTION_DEF (tensor, index);
  LUA_FUNCTION_DEF (tensor, newindex);
  LUA_FUNCTION_DEF (tensor, len);
  LUA_FUNCTION_DEF (tensor, eigen);
  LUA_FUNCTION_DEF (tensor, mean);
  LUA_FUNCTION_DEF (tensor, transpose);
  38. static luaL_reg rspamd_tensor_f[] = {
  39. LUA_INTERFACE_DEF (tensor, load),
  40. LUA_INTERFACE_DEF (tensor, new),
  41. LUA_INTERFACE_DEF (tensor, fromtable),
  42. {NULL, NULL},
  43. };
  44. static luaL_reg rspamd_tensor_m[] = {
  45. LUA_INTERFACE_DEF (tensor, save),
  46. {"__gc", lua_tensor_destroy},
  47. {"__mul", lua_tensor_mul},
  48. {"mul", lua_tensor_mul},
  49. {"tostring", lua_tensor_tostring},
  50. {"__tostring", lua_tensor_tostring},
  51. {"__index", lua_tensor_index},
  52. {"__newindex", lua_tensor_newindex},
  53. {"__len", lua_tensor_len},
  54. LUA_INTERFACE_DEF (tensor, eigen),
  55. LUA_INTERFACE_DEF (tensor, mean),
  56. LUA_INTERFACE_DEF (tensor, transpose),
  57. {NULL, NULL},
  58. };
  59. struct rspamd_lua_tensor *
  60. lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
  61. {
  62. struct rspamd_lua_tensor *res;
  63. res = lua_newuserdata (L, sizeof (struct rspamd_lua_tensor));
  64. memset (res, 0, sizeof (*res));
  65. res->ndims = ndims;
  66. res->size = 1;
  67. for (guint i = 0; i < ndims; i ++) {
  68. res->size *= dim[i];
  69. res->dim[i] = dim[i];
  70. }
  71. /* To avoid allocating large stuff in Lua */
  72. if (own) {
  73. res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
  74. if (zero_fill) {
  75. memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
  76. }
  77. }
  78. else {
  79. /* Mark size negative to distinguish */
  80. res->size = -(res->size);
  81. }
  82. rspamd_lua_setclass (L, TENSOR_CLASS, -1);
  83. return res;
  84. }
  85. /***
  86. * @function tensor.new(ndims, [dim1, ... dimN])
  87. * Creates a new zero filled tensor with the specific number of dimensions
  88. * @return
  89. */
  90. static gint
  91. lua_tensor_new (lua_State *L)
  92. {
  93. gint ndims = luaL_checkinteger (L, 1);
  94. if (ndims > 0 && ndims <= 2) {
  95. gint *dims = g_alloca (sizeof (gint) * ndims);
  96. for (guint i = 0; i < ndims; i ++) {
  97. dims[i] = lua_tointeger (L, i + 2);
  98. }
  99. (void)lua_newtensor (L, ndims, dims, true, true);
  100. }
  101. else {
  102. return luaL_error (L, "incorrect dimensions number: %d", ndims);
  103. }
  104. return 1;
  105. }
  106. /***
  107. * @function tensor.fromtable(tbl)
  108. * Creates a new zero filled tensor with the specific number of dimensions
  109. * @return
  110. */
  111. static gint
  112. lua_tensor_fromtable (lua_State *L)
  113. {
  114. if (lua_istable (L, 1)) {
  115. lua_rawgeti (L, 1, 1);
  116. if (lua_isnumber (L, -1)) {
  117. lua_pop (L, 1);
  118. /* Input vector */
  119. gint dims[2];
  120. dims[0] = 1;
  121. dims[1] = rspamd_lua_table_size (L, 1);
  122. struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
  123. dims, false, true);
  124. for (guint i = 0; i < dims[1]; i ++) {
  125. lua_rawgeti (L, 1, i + 1);
  126. res->data[i] = lua_tonumber (L, -1);
  127. lua_pop (L, 1);
  128. }
  129. }
  130. else if (lua_istable (L, -1)) {
  131. /* Input matrix */
  132. lua_pop (L, 1);
  133. /* Calculate the overall size */
  134. gint nrows = rspamd_lua_table_size (L, 1), ncols = 0;
  135. gint err;
  136. for (gint i = 0; i < nrows; i ++) {
  137. lua_rawgeti (L, 1, i + 1);
  138. if (ncols == 0) {
  139. ncols = rspamd_lua_table_size (L, -1);
  140. if (ncols == 0) {
  141. lua_pop (L, 1);
  142. err = luaL_error (L, "invalid params at pos %d: "
  143. "bad input dimension %d",
  144. i,
  145. (int)ncols);
  146. return err;
  147. }
  148. }
  149. else {
  150. if (ncols != rspamd_lua_table_size (L, -1)) {
  151. gint t = rspamd_lua_table_size (L, -1);
  152. lua_pop (L, 1);
  153. err = luaL_error (L, "invalid params at pos %d: "
  154. "bad input dimension %d; %d expected",
  155. i,
  156. t,
  157. ncols);
  158. return err;
  159. }
  160. }
  161. lua_pop (L, 1);
  162. }
  163. gint dims[2];
  164. dims[0] = nrows;
  165. dims[1] = ncols;
  166. struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
  167. dims, false, true);
  168. for (gint i = 0; i < nrows; i ++) {
  169. lua_rawgeti (L, 1, i + 1);
  170. for (gint j = 0; j < ncols; j++) {
  171. lua_rawgeti (L, -1, j + 1);
  172. res->data[i * ncols + j] = lua_tonumber (L, -1);
  173. lua_pop (L, 1);
  174. }
  175. lua_pop (L, 1);
  176. }
  177. }
  178. else {
  179. lua_pop (L, 1);
  180. return luaL_error (L, "incorrect table");
  181. }
  182. }
  183. else {
  184. return luaL_error (L, "incorrect input");
  185. }
  186. return 1;
  187. }
  188. /***
  189. * @method tensor:destroy()
  190. * Tensor destructor
  191. * @return
  192. */
  193. static gint
  194. lua_tensor_destroy (lua_State *L)
  195. {
  196. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  197. if (t) {
  198. if (t->size > 0) {
  199. g_free (t->data);
  200. }
  201. }
  202. return 0;
  203. }
  204. /***
  205. * @method tensor:save()
  206. * Tensor serialisation function
  207. * @return
  208. */
  209. static gint
  210. lua_tensor_save (lua_State *L)
  211. {
  212. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  213. gint size;
  214. if (t) {
  215. if (t->size > 0) {
  216. size = t->size;
  217. }
  218. else {
  219. size = -(t->size);
  220. }
  221. gsize sz = sizeof (gint) * 4 + size * sizeof (rspamd_tensor_num_t);
  222. guchar *data;
  223. struct rspamd_lua_text *out = lua_new_text (L, NULL, 0, TRUE);
  224. data = g_malloc (sz);
  225. memcpy (data, &t->ndims, sizeof (int));
  226. memcpy (data + sizeof (int), &size, sizeof (int));
  227. memcpy (data + 2 * sizeof (int), t->dim, sizeof (int) * 2);
  228. memcpy (data + 4 * sizeof (int), t->data,
  229. size * sizeof (rspamd_tensor_num_t));
  230. out->start = (const gchar *)data;
  231. out->len = sz;
  232. }
  233. else {
  234. return luaL_error (L, "invalid arguments");
  235. }
  236. return 1;
  237. }
  238. static gint
  239. lua_tensor_tostring (lua_State *L)
  240. {
  241. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  242. if (t) {
  243. GString *out = g_string_sized_new (128);
  244. if (t->ndims == 1) {
  245. /* Print as a vector */
  246. for (gint i = 0; i < t->dim[0]; i ++) {
  247. rspamd_printf_gstring (out, "%.4f ", t->data[i]);
  248. }
  249. /* Trim last space */
  250. out->len --;
  251. }
  252. else {
  253. for (gint i = 0; i < t->dim[0]; i ++) {
  254. for (gint j = 0; j < t->dim[1]; j ++) {
  255. rspamd_printf_gstring (out, "%.4f ",
  256. t->data[i * t->dim[1] + j]);
  257. }
  258. /* Trim last space */
  259. out->len --;
  260. rspamd_printf_gstring (out, "\n");
  261. }
  262. /* Trim last ; */
  263. out->len --;
  264. }
  265. lua_pushlstring (L, out->str, out->len);
  266. g_string_free (out, TRUE);
  267. }
  268. else {
  269. return luaL_error (L, "invalid arguments");
  270. }
  271. return 1;
  272. }
  273. static gint
  274. lua_tensor_index (lua_State *L)
  275. {
  276. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  277. gint idx;
  278. if (t) {
  279. if (lua_isnumber (L, 2)) {
  280. idx = lua_tointeger (L, 2);
  281. if (t->ndims == 1) {
  282. /* Individual element */
  283. if (idx <= t->dim[0]) {
  284. lua_pushnumber (L, t->data[idx - 1]);
  285. }
  286. else {
  287. lua_pushnil (L);
  288. }
  289. }
  290. else {
  291. /* Push row */
  292. gint dim = t->dim[1];
  293. if (idx <= t->dim[0]) {
  294. /* Non-owning tensor */
  295. struct rspamd_lua_tensor *res =
  296. lua_newtensor (L, 1, &dim, false, false);
  297. res->data = &t->data[(idx - 1) * t->dim[1]];
  298. }
  299. else {
  300. lua_pushnil (L);
  301. }
  302. }
  303. }
  304. else if (lua_isstring (L, 2)) {
  305. /* Access to methods */
  306. lua_getmetatable (L, 1);
  307. lua_pushvalue (L, 2);
  308. lua_rawget (L, -2);
  309. }
  310. }
  311. return 1;
  312. }
  313. static gint
  314. lua_tensor_newindex (lua_State *L)
  315. {
  316. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  317. gint idx;
  318. if (t) {
  319. if (lua_isnumber (L, 2)) {
  320. idx = lua_tointeger (L, 2);
  321. if (t->ndims == 1) {
  322. /* Individual element */
  323. if (idx <= t->dim[0] && idx > 0) {
  324. rspamd_tensor_num_t value = lua_tonumber (L, 3), old;
  325. old = t->data[idx - 1];
  326. t->data[idx - 1] = value;
  327. lua_pushnumber (L, old);
  328. }
  329. else {
  330. return luaL_error (L, "invalid index: %d", idx);
  331. }
  332. }
  333. else {
  334. if (lua_isnumber (L, 3)) {
  335. return luaL_error (L, "cannot assign number to a row");
  336. }
  337. else if (lua_isuserdata (L, 3)) {
  338. /* Tensor assignment */
  339. struct rspamd_lua_tensor *row = lua_check_tensor (L, 3);
  340. if (row) {
  341. if (row->ndims == 1) {
  342. if (row->dim[0] == t->dim[1]) {
  343. if (idx > 0 && idx <= t->dim[0]) {
  344. idx --; /* Zero based index */
  345. memcpy (&t->data[idx * t->dim[1]],
  346. row->data,
  347. t->dim[1] * sizeof (rspamd_tensor_num_t));
  348. return 0;
  349. }
  350. else {
  351. return luaL_error (L, "invalid index: %d", idx);
  352. }
  353. }
  354. }
  355. else {
  356. return luaL_error (L, "cannot assign matrix to row");
  357. }
  358. }
  359. else {
  360. return luaL_error (L, "cannot assign row, invalid tensor");
  361. }
  362. }
  363. else {
  364. /* TODO: add table assignment */
  365. return luaL_error (L, "cannot assign row, not a tensor");
  366. }
  367. }
  368. }
  369. else {
  370. /* Access to methods? NYI */
  371. return luaL_error (L, "cannot assign method of a tensor");
  372. }
  373. }
  374. return 1;
  375. }
  376. /***
  377. * @method tensor:mul(other, [transA, [transB]])
  378. * Multiply two tensors (optionally transposed) and return a new tensor
  379. * @return
  380. */
  381. static gint
  382. lua_tensor_mul (lua_State *L)
  383. {
  384. struct rspamd_lua_tensor *t1 = lua_check_tensor (L, 1),
  385. *t2 = lua_check_tensor (L, 2), *res;
  386. int transA = 0, transB = 0;
  387. if (lua_isboolean (L, 3)) {
  388. transA = lua_toboolean (L, 3);
  389. }
  390. if (lua_isboolean (L, 4)) {
  391. transB = lua_toboolean (L, 4);
  392. }
  393. if (t1 && t2) {
  394. gint dims[2], shadow_dims[2];
  395. dims[0] = abs (transA ? t1->dim[1] : t1->dim[0]);
  396. shadow_dims[0] = abs (transB ? t2->dim[1] : t2->dim[0]);
  397. dims[1] = abs (transB ? t2->dim[0] : t2->dim[1]);
  398. shadow_dims[1] = abs (transA ? t1->dim[0] : t1->dim[1]);
  399. if (shadow_dims[0] != shadow_dims[1]) {
  400. return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
  401. dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
  402. }
  403. else if (shadow_dims[0] == 0) {
  404. /* Row * Column -> matrix */
  405. shadow_dims[0] = 1;
  406. shadow_dims[1] = 1;
  407. }
  408. if (dims[0] == 0) {
  409. /* Column */
  410. dims[0] = 1;
  411. if (dims[1] == 0) {
  412. /* Column * row -> number */
  413. dims[1] = 1;
  414. }
  415. res = lua_newtensor (L, 2, dims, true, true);
  416. }
  417. else if (dims[1] == 0) {
  418. /* Row */
  419. res = lua_newtensor (L, 1, dims, true, true);
  420. dims[1] = 1;
  421. }
  422. else {
  423. res = lua_newtensor (L, 2, dims, true, true);
  424. }
  425. kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
  426. t1->data, t2->data, res->data);
  427. }
  428. else {
  429. return luaL_error (L, "invalid arguments");
  430. }
  431. return 1;
  432. }
  433. /***
  434. * @function tensor.load(rspamd_text)
  435. * Deserialize tensor
  436. * @return
  437. */
  438. static gint
  439. lua_tensor_load (lua_State *L)
  440. {
  441. const guchar *data;
  442. gsize sz;
  443. if (lua_type (L, 1) == LUA_TUSERDATA) {
  444. struct rspamd_lua_text *t = lua_check_text (L, 1);
  445. if (!t) {
  446. return luaL_error (L, "invalid argument");
  447. }
  448. data = (const guchar *)t->start;
  449. sz = t->len;
  450. }
  451. else {
  452. data = (const guchar *)lua_tolstring (L, 1, &sz);
  453. }
  454. if (sz >= sizeof (gint) * 4) {
  455. int ndims, nelts, dims[2];
  456. memcpy (&ndims, data, sizeof (int));
  457. memcpy (&nelts, data + sizeof (int), sizeof (int));
  458. memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2);
  459. if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
  460. if (ndims == 1) {
  461. if (nelts == dims[0]) {
  462. struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
  463. memcpy (t->data, data + sizeof (int) * 4, nelts *
  464. sizeof (rspamd_tensor_num_t));
  465. }
  466. else {
  467. return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
  468. dims[0], 1, nelts);
  469. }
  470. }
  471. else if (ndims == 2) {
  472. if (nelts == dims[0] * dims[1]) {
  473. struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
  474. memcpy (t->data, data + sizeof (int) * 4, nelts *
  475. sizeof (rspamd_tensor_num_t));
  476. }
  477. else {
  478. return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
  479. dims[0], dims[1], nelts);
  480. }
  481. }
  482. else {
  483. return luaL_error (L, "invalid argument: bad ndims: %d", ndims);
  484. }
  485. }
  486. else {
  487. return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz,
  488. (int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4),
  489. nelts);
  490. }
  491. }
  492. else {
  493. return luaL_error (L, "invalid arguments");
  494. }
  495. return 1;
  496. }
  497. static gint
  498. lua_tensor_len (lua_State *L)
  499. {
  500. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  501. gint nret = 1;
  502. if (t) {
  503. /* Return the main dimension first */
  504. if (t->ndims == 1) {
  505. lua_pushinteger (L, t->dim[0]);
  506. }
  507. else {
  508. lua_pushinteger (L, t->dim[0]);
  509. lua_pushinteger (L, t->dim[1]);
  510. nret = 2;
  511. }
  512. }
  513. else {
  514. return luaL_error (L, "invalid arguments");
  515. }
  516. return nret;
  517. }
  518. static gint
  519. lua_tensor_eigen (lua_State *L)
  520. {
  521. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *eigen;
  522. if (t) {
  523. if (t->ndims != 2 || t->dim[0] != t->dim[1]) {
  524. return luaL_error (L, "expected square matrix NxN but got %dx%d",
  525. t->dim[0], t->dim[1]);
  526. }
  527. eigen = lua_newtensor (L, 1, &t->dim[0], true, true);
  528. if (!kad_ssyev_simple (t->dim[0], t->data, eigen->data)) {
  529. lua_pop (L, 1);
  530. return luaL_error (L, "kad_ssyev_simple failed (no blas?)");
  531. }
  532. }
  533. else {
  534. return luaL_error (L, "invalid arguments");
  535. }
  536. return 1;
  537. }
  538. static inline rspamd_tensor_num_t
  539. mean_vec (rspamd_tensor_num_t *x, int n)
  540. {
  541. rspamd_tensor_num_t s = 0;
  542. rspamd_tensor_num_t c = 0;
  543. for (int i = 0; i < n; i ++) {
  544. rspamd_tensor_num_t v = x[i];
  545. rspamd_tensor_num_t y = v - c;
  546. rspamd_tensor_num_t t = s + y;
  547. c = (t - s) - y;
  548. s = t;
  549. }
  550. return s / (rspamd_tensor_num_t)n;
  551. }
  552. static gint
  553. lua_tensor_mean (lua_State *L)
  554. {
  555. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
  556. if (t) {
  557. if (t->ndims == 1) {
  558. /* Mean of all elements in a vector */
  559. lua_pushnumber (L, mean_vec (t->data, t->dim[0]));
  560. }
  561. else {
  562. /* Row-wise mean vector output */
  563. struct rspamd_lua_tensor *res;
  564. res = lua_newtensor (L, 1, &t->dim[0], false, true);
  565. for (int i = 0; i < t->dim[0]; i ++) {
  566. res->data[i] = mean_vec (&t->data[i * t->dim[1]], t->dim[1]);
  567. }
  568. }
  569. }
  570. else {
  571. return luaL_error (L, "invalid arguments");
  572. }
  573. return 1;
  574. }
  575. static gint
  576. lua_tensor_transpose (lua_State *L)
  577. {
  578. struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
  579. int dims[2];
  580. if (t) {
  581. if (t->ndims == 1) {
  582. /* Row to column */
  583. dims[0] = 1;
  584. dims[1] = t->dim[0];
  585. res = lua_newtensor (L, 2, dims, false, true);
  586. memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t));
  587. }
  588. else {
  589. /* Cache friendly algorithm */
  590. struct rspamd_lua_tensor *res;
  591. dims[0] = t->dim[1];
  592. dims[1] = t->dim[0];
  593. res = lua_newtensor (L, 2, dims, false, true);
  594. static const int block = 32;
  595. for (int i = 0; i < t->dim[0]; i += block) {
  596. for(int j = 0; j < t->dim[1]; ++j) {
  597. for(int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) {
  598. res->data[j * t->dim[0] + i + boff] =
  599. t->data[(i + boff) * t->dim[1] + j];
  600. }
  601. }
  602. }
  603. }
  604. }
  605. else {
  606. return luaL_error (L, "invalid arguments");
  607. }
  608. return 1;
  609. }
  610. static gint
  611. lua_load_tensor (lua_State * L)
  612. {
  613. lua_newtable (L);
  614. luaL_register (L, NULL, rspamd_tensor_f);
  615. return 1;
  616. }
  617. void luaopen_tensor (lua_State *L)
  618. {
  619. /* Metatables */
  620. rspamd_lua_new_class (L, TENSOR_CLASS, rspamd_tensor_m);
  621. lua_pop (L, 1); /* No need in metatable... */
  622. rspamd_lua_add_preload (L, "rspamd_tensor", lua_load_tensor);
  623. lua_settop (L, 0);
  624. }
  625. struct rspamd_lua_tensor *
  626. lua_check_tensor (lua_State *L, int pos)
  627. {
  628. void *ud = rspamd_lua_check_udata (L, pos, TENSOR_CLASS);
  629. luaL_argcheck (L, ud != NULL, pos, "'tensor' expected");
  630. return ud ? ((struct rspamd_lua_tensor *)ud) : NULL;
  631. }