diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-19 13:51:10 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-19 14:03:29 +0100 |
commit | 723294cbaad55a9d738adae263d347e95faca049 (patch) | |
tree | c0680b9d14654cde020b2f57775d53e6c07cdbe5 /src | |
parent | ad97a143fd5dae292d901402828e4fb059de0b7e (diff) | |
download | rspamd-723294cbaad55a9d738adae263d347e95faca049.tar.gz rspamd-723294cbaad55a9d738adae263d347e95faca049.zip |
[Minor] Fix tensor multiplication for the vectors case
Diffstat (limited to 'src')
-rw-r--r-- | src/lua/lua_tensor.c | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index cf91006d0..16bba985b 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -351,6 +351,7 @@ lua_tensor_index (lua_State *L) } } else if (lua_isstring (L, 2)) { + /* Access to methods */ lua_getmetatable (L, 1); lua_pushvalue (L, 2); lua_rawget (L, -2); @@ -392,7 +393,20 @@ lua_tensor_mul (lua_State *L) dims[0], shadow_dims[1], shadow_dims[0], dims[1]); } - res = lua_newtensor (L, 2, dims, true); + if (dims[0] == 0) { + /* Column */ + dims[0] = 1; + res = lua_newtensor (L, 2, dims, true, true); + } + else if (dims[1] == 0) { + /* Row */ + res = lua_newtensor (L, 1, dims, true, true); + dims[1] = 1; + } + else { + res = lua_newtensor (L, 2, dims, true, true); + } + kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0], t1->data, t2->data, res->data); } @@ -438,7 +452,7 @@ lua_tensor_load (lua_State *L) if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) { if (ndims == 1) { if (nelts == dims[0]) { - struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); + struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true); memcpy (t->data, data + sizeof (int) * 4, nelts * sizeof (rspamd_tensor_num_t)); } @@ -449,7 +463,7 @@ lua_tensor_load (lua_State *L) } else if (ndims == 2) { if (nelts == dims[0] * dims[1]) { - struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false); + struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true); memcpy (t->data, data + sizeof (int) * 4, nelts * sizeof (rspamd_tensor_num_t)); } |