diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-05 21:04:32 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-05 21:05:09 +0100 |
commit | b8216839b2a9f083259b71947bf0caa4b4eef091 (patch) | |
tree | 12f672dbf121b02fa94c627e69d502217c05fb38 /src/lua | |
parent | 09fc651620d5d55949fb0ad60e402fb5e62668e2 (diff) | |
download | rspamd-b8216839b2a9f083259b71947bf0caa4b4eef091.tar.gz rspamd-b8216839b2a9f083259b71947bf0caa4b4eef091.zip |
[Project] Add tensors index method
Diffstat (limited to 'src/lua')
-rw-r--r-- | src/lua/lua_tensor.c | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index 21bdf9673..85aaa2e95 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -32,6 +32,7 @@ LUA_FUNCTION_DEF (tensor, fromtable); LUA_FUNCTION_DEF (tensor, destroy); LUA_FUNCTION_DEF (tensor, mul); LUA_FUNCTION_DEF (tensor, tostring); +LUA_FUNCTION_DEF (tensor, index); static luaL_reg rspamd_tensor_f[] = { LUA_INTERFACE_DEF (tensor, load), @@ -45,8 +46,9 @@ static luaL_reg rspamd_tensor_m[] = { {"__gc", lua_tensor_destroy}, {"__mul", lua_tensor_mul}, {"mul", lua_tensor_mul}, - {"__tostring", lua_tensor_tostring}, {"tostring", lua_tensor_tostring}, + {"__tostring", lua_tensor_tostring}, + {"__index", lua_tensor_index}, {NULL, NULL}, }; @@ -284,6 +286,52 @@ lua_tensor_tostring (lua_State *L) return 1; } +static gint +lua_tensor_index (lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor (L, 1); + gint idx; + + if (t) { + if (lua_isnumber (L, 2)) { + idx = lua_tointeger (L, 2); + + if (t->ndims == 1) { + /* Individual element */ + if (idx <= t->dim[0]) { + lua_pushnumber (L, t->data[idx - 1]); + } + else { + lua_pushnil (L); + } + } + else { + /* Push row */ + gint dim = t->dim[1]; + + + if (idx <= t->dim[0]) { + struct rspamd_lua_tensor *res = + lua_newtensor (L, 1, &dim, false); + for (gint i = 0; i < dim; i++) { + res->data[i] = t->data[(idx - 1) * t->dim[1] + i]; + } + } + else { + lua_pushnil (L); + } + } + } + else if (lua_isstring (L, 2)) { + lua_getmetatable (L, 1); + lua_pushvalue (L, 2); + lua_rawget (L, -2); + } + } + + return 1; +} + /*** * @method tensor:mul(other, [transA, [transB]]) * Multiply two tensors (optionally transposed) and return a new tensor |