aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-05 21:04:32 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-05 21:05:09 +0100
commitb8216839b2a9f083259b71947bf0caa4b4eef091 (patch)
tree12f672dbf121b02fa94c627e69d502217c05fb38 /src/lua
parent09fc651620d5d55949fb0ad60e402fb5e62668e2 (diff)
downloadrspamd-b8216839b2a9f083259b71947bf0caa4b4eef091.tar.gz
rspamd-b8216839b2a9f083259b71947bf0caa4b4eef091.zip
[Project] Add tensors index method
Diffstat (limited to 'src/lua')
-rw-r--r--src/lua/lua_tensor.c50
1 files changed, 49 insertions, 1 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 21bdf9673..85aaa2e95 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -32,6 +32,7 @@ LUA_FUNCTION_DEF (tensor, fromtable);
LUA_FUNCTION_DEF (tensor, destroy);
LUA_FUNCTION_DEF (tensor, mul);
LUA_FUNCTION_DEF (tensor, tostring);
+LUA_FUNCTION_DEF (tensor, index);
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
@@ -45,8 +46,9 @@ static luaL_reg rspamd_tensor_m[] = {
{"__gc", lua_tensor_destroy},
{"__mul", lua_tensor_mul},
{"mul", lua_tensor_mul},
- {"__tostring", lua_tensor_tostring},
{"tostring", lua_tensor_tostring},
+ {"__tostring", lua_tensor_tostring},
+ {"__index", lua_tensor_index},
{NULL, NULL},
};
@@ -284,6 +286,52 @@ lua_tensor_tostring (lua_State *L)
return 1;
}
+static gint
+lua_tensor_index (lua_State *L)
+{
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+ gint idx;
+
+ if (t) {
+ if (lua_isnumber (L, 2)) {
+ idx = lua_tointeger (L, 2);
+
+ if (t->ndims == 1) {
+ /* Individual element */
+ if (idx <= t->dim[0]) {
+ lua_pushnumber (L, t->data[idx - 1]);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ else {
+ /* Push row */
+ gint dim = t->dim[1];
+
+
+ if (idx <= t->dim[0]) {
+ struct rspamd_lua_tensor *res =
+ lua_newtensor (L, 1, &dim, false);
+ for (gint i = 0; i < dim; i++) {
+ res->data[i] = t->data[(idx - 1) * t->dim[1] + i];
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ }
+ else if (lua_isstring (L, 2)) {
+ lua_getmetatable (L, 1);
+ lua_pushvalue (L, 2);
+ lua_rawget (L, -2);
+ }
+ }
+
+ return 1;
+}
+
/***
* @method tensor:mul(other, [transA, [transB]])
* Multiply two tensors (optionally transposed) and return a new tensor