|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "lua_common.h" |
|
|
|
#include "lua_tensor.h" |
|
|
|
#include "contrib/kann/kann.h" |
|
|
|
|
|
|
|
/*** |
|
|
@@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L) |
|
|
|
{ |
|
|
|
kann_t *k = lua_check_kann (L, 1); |
|
|
|
|
|
|
|
if (k && lua_istable (L, 2)) { |
|
|
|
gsize vec_len = rspamd_lua_table_size (L, 2); |
|
|
|
float *vec = (float *)g_malloc (sizeof (float) * vec_len); |
|
|
|
int i_out; |
|
|
|
int n_in = kann_dim_in (k); |
|
|
|
if (k) { |
|
|
|
if (lua_istable (L, 2)) { |
|
|
|
gsize vec_len = rspamd_lua_table_size (L, 2); |
|
|
|
float *vec = (float *) g_malloc (sizeof (float) * vec_len); |
|
|
|
int i_out; |
|
|
|
int n_in = kann_dim_in (k); |
|
|
|
|
|
|
|
if (n_in <= 0) { |
|
|
|
return luaL_error (L, "invalid inputs count: %d", n_in); |
|
|
|
} |
|
|
|
if (n_in <= 0) { |
|
|
|
return luaL_error (L, "invalid inputs count: %d", n_in); |
|
|
|
} |
|
|
|
|
|
|
|
if (n_in != vec_len) { |
|
|
|
return luaL_error (L, "invalid params: bad input dimension %d; %d expected", |
|
|
|
(int)vec_len, n_in); |
|
|
|
} |
|
|
|
if (n_in != vec_len) { |
|
|
|
return luaL_error (L, "invalid params: bad input dimension %d; %d expected", |
|
|
|
(int) vec_len, n_in); |
|
|
|
} |
|
|
|
|
|
|
|
for (gsize i = 0; i < vec_len; i ++) { |
|
|
|
lua_rawgeti (L, 2, i + 1); |
|
|
|
vec[i] = lua_tonumber (L, -1); |
|
|
|
lua_pop (L, 1); |
|
|
|
} |
|
|
|
for (gsize i = 0; i < vec_len; i++) { |
|
|
|
lua_rawgeti (L, 2, i + 1); |
|
|
|
vec[i] = lua_tonumber (L, -1); |
|
|
|
lua_pop (L, 1); |
|
|
|
} |
|
|
|
|
|
|
|
i_out = kann_find (k, KANN_F_OUT, 0); |
|
|
|
i_out = kann_find (k, KANN_F_OUT, 0); |
|
|
|
|
|
|
|
if (i_out <= 0) { |
|
|
|
g_free (vec); |
|
|
|
return luaL_error (L, "invalid ANN: output layer is missing or is " |
|
|
|
"at the input pos"); |
|
|
|
} |
|
|
|
|
|
|
|
kann_set_batch_size (k, 1); |
|
|
|
kann_feed_bind (k, KANN_F_IN, 0, &vec); |
|
|
|
kad_eval_at (k->n, k->v, i_out); |
|
|
|
|
|
|
|
gsize outlen = kad_len (k->v[i_out]); |
|
|
|
lua_createtable (L, outlen, 0); |
|
|
|
|
|
|
|
for (gsize i = 0; i < outlen; i++) { |
|
|
|
lua_pushnumber (L, k->v[i_out]->x[i]); |
|
|
|
lua_rawseti (L, -2, i + 1); |
|
|
|
} |
|
|
|
|
|
|
|
if (i_out <= 0) { |
|
|
|
g_free (vec); |
|
|
|
return luaL_error (L, "invalid ANN: output layer is missing or is " |
|
|
|
"at the input pos"); |
|
|
|
} |
|
|
|
else if (lua_isuserdata (L, 2)) { |
|
|
|
struct rspamd_lua_tensor *t = lua_check_tensor (L, 2); |
|
|
|
|
|
|
|
kann_set_batch_size (k, 1); |
|
|
|
kann_feed_bind (k, KANN_F_IN, 0, &vec); |
|
|
|
kad_eval_at (k->n, k->v, i_out); |
|
|
|
if (t && t->ndims == 1) { |
|
|
|
int i_out; |
|
|
|
int n_in = kann_dim_in (k); |
|
|
|
|
|
|
|
gsize outlen = kad_len (k->v[i_out]); |
|
|
|
lua_createtable (L, outlen, 0); |
|
|
|
if (n_in != t->dim[0]) { |
|
|
|
return luaL_error (L, "invalid params: bad input dimension %d; %d expected", |
|
|
|
(int) t->dim[0], n_in); |
|
|
|
} |
|
|
|
|
|
|
|
for (gsize i = 0; i < outlen; i ++) { |
|
|
|
lua_pushnumber (L, k->v[i_out]->x[i]); |
|
|
|
lua_rawseti (L, -2, i + 1); |
|
|
|
} |
|
|
|
i_out = kann_find (k, KANN_F_OUT, 0); |
|
|
|
|
|
|
|
if (i_out <= 0) { |
|
|
|
return luaL_error (L, "invalid ANN: output layer is missing or is " |
|
|
|
"at the input pos"); |
|
|
|
} |
|
|
|
|
|
|
|
g_free (vec); |
|
|
|
kann_set_batch_size (k, 1); |
|
|
|
kann_feed_bind (k, KANN_F_IN, 0, &t->data); |
|
|
|
kad_eval_at (k->n, k->v, i_out); |
|
|
|
|
|
|
|
gint outlen = kad_len (k->v[i_out]); |
|
|
|
struct rspamd_lua_tensor *out; |
|
|
|
out = lua_newtensor (L, 1, &outlen, false, false); |
|
|
|
/* Ensure that kann and tensor have the same understanding of floats */ |
|
|
|
G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t)); |
|
|
|
memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float)); |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected"); |
|
|
|
} |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected"); |
|
|
|
} |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments: rspamd{kann} expected"); |