aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-19 13:51:10 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-08-19 14:03:29 +0100
commit723294cbaad55a9d738adae263d347e95faca049 (patch)
treec0680b9d14654cde020b2f57775d53e6c07cdbe5 /src
parentad97a143fd5dae292d901402828e4fb059de0b7e (diff)
downloadrspamd-723294cbaad55a9d738adae263d347e95faca049.tar.gz
rspamd-723294cbaad55a9d738adae263d347e95faca049.zip
[Minor] Fix tensor multiplication for the vectors case
Diffstat (limited to 'src')
-rw-r--r--src/lua/lua_tensor.c20
1 files changed, 17 insertions, 3 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index cf91006d0..16bba985b 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -351,6 +351,7 @@ lua_tensor_index (lua_State *L)
}
}
else if (lua_isstring (L, 2)) {
+ /* Access to methods */
lua_getmetatable (L, 1);
lua_pushvalue (L, 2);
lua_rawget (L, -2);
@@ -392,7 +393,20 @@ lua_tensor_mul (lua_State *L)
dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
}
- res = lua_newtensor (L, 2, dims, true);
+ if (dims[0] == 0) {
+ /* Column */
+ dims[0] = 1;
+ res = lua_newtensor (L, 2, dims, true, true);
+ }
+ else if (dims[1] == 0) {
+ /* Row */
+ res = lua_newtensor (L, 1, dims, true, true);
+ dims[1] = 1;
+ }
+ else {
+ res = lua_newtensor (L, 2, dims, true, true);
+ }
+
kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
t1->data, t2->data, res->data);
}
@@ -438,7 +452,7 @@ lua_tensor_load (lua_State *L)
if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
if (ndims == 1) {
if (nelts == dims[0]) {
- struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+ struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}
@@ -449,7 +463,7 @@ lua_tensor_load (lua_State *L)
}
else if (ndims == 2) {
if (nelts == dims[0] * dims[1]) {
- struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+ struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
memcpy (t->data, data + sizeof (int) * 4, nelts *
sizeof (rspamd_tensor_num_t));
}