Browse Source

[Minor] Fix tensor projections multiplication

tags/2.6
Vsevolod Stakhov 3 years ago
parent
commit
50a60a54fd
1 changed files with 5 additions and 4 deletions
  1. 5
    4
      src/lua/lua_tensor.c

+ 5
- 4
src/lua/lua_tensor.c View File

@@ -60,6 +60,7 @@ lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own
struct rspamd_lua_tensor *res;

res = lua_newuserdata (L, sizeof (struct rspamd_lua_tensor));
memset (res, 0, sizeof (*res));

res->ndims = ndims;
res->size = 1;
@@ -453,10 +454,10 @@ lua_tensor_mul (lua_State *L)

if (t1 && t2) {
gint dims[2], shadow_dims[2];
dims[0] = transA ? t1->dim[1] : t1->dim[0];
shadow_dims[0] = transB ? t2->dim[1] : t2->dim[0];
dims[1] = transB ? t2->dim[0] : t2->dim[1];
shadow_dims[1] = transA ? t1->dim[0] : t1->dim[1];
dims[0] = abs (transA ? t1->dim[1] : t1->dim[0]);
shadow_dims[0] = abs (transB ? t2->dim[1] : t2->dim[0]);
dims[1] = abs (transB ? t2->dim[0] : t2->dim[1]);
shadow_dims[1] = abs (transA ? t1->dim[0] : t1->dim[1]);

if (shadow_dims[0] != shadow_dims[1]) {
return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",

Loading…
Cancel
Save