
[Minor] Allow using lua_tensor in kann apply

tags/2.6
Vsevolod Stakhov, 3 years ago
commit 71e58489aa
3 changed files with 74 additions and 32 deletions:
  1. src/lua/lua_kann.c (+71, -31)
  2. src/lua/lua_tensor.c (+1, -1)
  3. src/lua/lua_tensor.h (+2, -0)
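
In Lua terms, the change means that kann:apply1() now accepts a one-dimensional rspamd{tensor} as its argument in addition to a plain Lua table, and in the tensor case it returns a tensor instead of a table. A minimal Lua sketch of the two call forms follows; the `net` object (an already trained/loaded rspamd{kann}) and the rspamd_tensor.fromtable constructor are assumptions for illustration, not part of this diff:

    -- assumed: `net` is an existing rspamd{kann} network obtained elsewhere,
    -- and rspamd_tensor provides a fromtable() constructor
    local rspamd_tensor = require "rspamd_tensor"

    -- old path, still supported: feed a plain Lua table, get a table back
    local out_table = net:apply1({0.1, 0.2, 0.3, 0.4})

    -- new path: feed a 1D rspamd{tensor}; its data pointer is bound directly
    -- to the network input and the result comes back as a 1D tensor as well
    local vec = rspamd_tensor.fromtable({0.1, 0.2, 0.3, 0.4})
    local out_tensor = net:apply1(vec)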

src/lua/lua_kann.c (+71, -31)

@@ -15,6 +15,7 @@
  */
 
 #include "lua_common.h"
+#include "lua_tensor.h"
 #include "contrib/kann/kann.h"
 
 /***
@@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L)
 {
 	kann_t *k = lua_check_kann (L, 1);
 
-	if (k && lua_istable (L, 2)) {
-		gsize vec_len = rspamd_lua_table_size (L, 2);
-		float *vec = (float *)g_malloc (sizeof (float) * vec_len);
-		int i_out;
-		int n_in = kann_dim_in (k);
+	if (k) {
+		if (lua_istable (L, 2)) {
+			gsize vec_len = rspamd_lua_table_size (L, 2);
+			float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+			int i_out;
+			int n_in = kann_dim_in (k);
 
-		if (n_in <= 0) {
-			return luaL_error (L, "invalid inputs count: %d", n_in);
-		}
+			if (n_in <= 0) {
+				return luaL_error (L, "invalid inputs count: %d", n_in);
+			}
 
-		if (n_in != vec_len) {
-			return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
-					(int)vec_len, n_in);
-		}
+			if (n_in != vec_len) {
+				return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+						(int) vec_len, n_in);
+			}
 
-		for (gsize i = 0; i < vec_len; i ++) {
-			lua_rawgeti (L, 2, i + 1);
-			vec[i] = lua_tonumber (L, -1);
-			lua_pop (L, 1);
-		}
+			for (gsize i = 0; i < vec_len; i++) {
+				lua_rawgeti (L, 2, i + 1);
+				vec[i] = lua_tonumber (L, -1);
+				lua_pop (L, 1);
+			}
 
-		i_out = kann_find (k, KANN_F_OUT, 0);
+			i_out = kann_find (k, KANN_F_OUT, 0);
 
-		if (i_out <= 0) {
-			g_free (vec);
-			return luaL_error (L, "invalid ANN: output layer is missing or is "
-					"at the input pos");
-		}
+			if (i_out <= 0) {
+				g_free (vec);
+				return luaL_error (L, "invalid ANN: output layer is missing or is "
+									  "at the input pos");
+			}
 
-		kann_set_batch_size (k, 1);
-		kann_feed_bind (k, KANN_F_IN, 0, &vec);
-		kad_eval_at (k->n, k->v, i_out);
+			kann_set_batch_size (k, 1);
+			kann_feed_bind (k, KANN_F_IN, 0, &vec);
+			kad_eval_at (k->n, k->v, i_out);
 
-		gsize outlen = kad_len (k->v[i_out]);
-		lua_createtable (L, outlen, 0);
+			gsize outlen = kad_len (k->v[i_out]);
+			lua_createtable (L, outlen, 0);
 
-		for (gsize i = 0; i < outlen; i ++) {
-			lua_pushnumber (L, k->v[i_out]->x[i]);
-			lua_rawseti (L, -2, i + 1);
-		}
+			for (gsize i = 0; i < outlen; i++) {
+				lua_pushnumber (L, k->v[i_out]->x[i]);
+				lua_rawseti (L, -2, i + 1);
+			}
 
-		g_free (vec);
+			g_free (vec);
+		}
+		else if (lua_isuserdata (L, 2)) {
+			struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
+
+			if (t && t->ndims == 1) {
+				int i_out;
+				int n_in = kann_dim_in (k);
+
+				if (n_in != t->dim[0]) {
+					return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+							(int) t->dim[0], n_in);
+				}
+
+				i_out = kann_find (k, KANN_F_OUT, 0);
+
+				if (i_out <= 0) {
+					return luaL_error (L, "invalid ANN: output layer is missing or is "
+										  "at the input pos");
+				}
+
+				kann_set_batch_size (k, 1);
+				kann_feed_bind (k, KANN_F_IN, 0, &t->data);
+				kad_eval_at (k->n, k->v, i_out);
+
+				gint outlen = kad_len (k->v[i_out]);
+				struct rspamd_lua_tensor *out;
+				out = lua_newtensor (L, 1, &outlen, false, false);
+				/* Ensure that kann and tensor have the same understanding of floats */
+				G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
+				memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
+			}
+			else {
+				return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+			}
+		}
+		else {
+			return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+		}
+	}
+	else {
+		return luaL_error (L, "invalid arguments: rspamd{kann} expected");
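
The tensor branch above keeps the same validation as the table branch: the argument must be a one-dimensional tensor whose length matches kann_dim_in(), otherwise apply1 raises a Lua error. A hedged sketch of guarding a call site, reusing the assumed names from the example above:

    -- a 2-element tensor fed to a 4-input network triggers the
    -- "invalid params: bad input dimension ..." error from lua_kann_apply1
    local bad = rspamd_tensor.fromtable({0.1, 0.2})
    local ok, err = pcall(net.apply1, net, bad)
    if not ok then
      -- err holds the dimension-mismatch message raised on the C side
    end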

src/lua/lua_tensor.c (+1, -1)

@@ -54,7 +54,7 @@ static luaL_reg rspamd_tensor_m[] = {
 	{NULL, NULL},
 };
 
-static struct rspamd_lua_tensor *
+struct rspamd_lua_tensor *
 lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
 {
 	struct rspamd_lua_tensor *res;

src/lua/lua_tensor.h (+2, -0)

@@ -28,5 +28,7 @@ struct rspamd_lua_tensor {
 };
 
 struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
+struct rspamd_lua_tensor *lua_newtensor (lua_State *L, int ndims,
+		const int *dim, bool zero_fill, bool own);
 
 #endif
