diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-21 16:11:32 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-21 21:48:46 +0100 |
commit | 71e58489aa8efbc883ba961b1f1cf15eebec3c87 (patch) | |
tree | 1c5491f6a17c69464330992946aeb9db48746f17 /src | |
parent | 85acd8afd9d4a38586c6c908dd30c8eb86758fce (diff) | |
download | rspamd-71e58489aa8efbc883ba961b1f1cf15eebec3c87.tar.gz rspamd-71e58489aa8efbc883ba961b1f1cf15eebec3c87.zip |
[Minor] Allow to use lua_tensor in kann apply
Diffstat (limited to 'src')
-rw-r--r-- | src/lua/lua_kann.c | 102 | ||||
-rw-r--r-- | src/lua/lua_tensor.c | 2 | ||||
-rw-r--r-- | src/lua/lua_tensor.h | 2 |
3 files changed, 74 insertions, 32 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index 33036fe04..1827fe1ac 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -15,6 +15,7 @@ */ #include "lua_common.h" +#include "lua_tensor.h" #include "contrib/kann/kann.h" /*** @@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L) { kann_t *k = lua_check_kann (L, 1); - if (k && lua_istable (L, 2)) { - gsize vec_len = rspamd_lua_table_size (L, 2); - float *vec = (float *)g_malloc (sizeof (float) * vec_len); - int i_out; - int n_in = kann_dim_in (k); + if (k) { + if (lua_istable (L, 2)) { + gsize vec_len = rspamd_lua_table_size (L, 2); + float *vec = (float *) g_malloc (sizeof (float) * vec_len); + int i_out; + int n_in = kann_dim_in (k); - if (n_in <= 0) { - return luaL_error (L, "invalid inputs count: %d", n_in); - } + if (n_in <= 0) { + return luaL_error (L, "invalid inputs count: %d", n_in); + } - if (n_in != vec_len) { - return luaL_error (L, "invalid params: bad input dimension %d; %d expected", - (int)vec_len, n_in); - } + if (n_in != vec_len) { + return luaL_error (L, "invalid params: bad input dimension %d; %d expected", + (int) vec_len, n_in); + } - for (gsize i = 0; i < vec_len; i ++) { - lua_rawgeti (L, 2, i + 1); - vec[i] = lua_tonumber (L, -1); - lua_pop (L, 1); - } + for (gsize i = 0; i < vec_len; i++) { + lua_rawgeti (L, 2, i + 1); + vec[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } - i_out = kann_find (k, KANN_F_OUT, 0); + i_out = kann_find (k, KANN_F_OUT, 0); + + if (i_out <= 0) { + g_free (vec); + return luaL_error (L, "invalid ANN: output layer is missing or is " + "at the input pos"); + } + + kann_set_batch_size (k, 1); + kann_feed_bind (k, KANN_F_IN, 0, &vec); + kad_eval_at (k->n, k->v, i_out); + + gsize outlen = kad_len (k->v[i_out]); + lua_createtable (L, outlen, 0); + + for (gsize i = 0; i < outlen; i++) { + lua_pushnumber (L, k->v[i_out]->x[i]); + lua_rawseti (L, -2, i + 1); + } - if (i_out <= 0) { g_free (vec); - return luaL_error (L, "invalid ANN: output layer is missing or is " - "at the input pos"); } + else if (lua_isuserdata (L, 2)) { + struct rspamd_lua_tensor *t = lua_check_tensor (L, 2); - kann_set_batch_size (k, 1); - kann_feed_bind (k, KANN_F_IN, 0, &vec); - kad_eval_at (k->n, k->v, i_out); + if (t && t->ndims == 1) { + int i_out; + int n_in = kann_dim_in (k); - gsize outlen = kad_len (k->v[i_out]); - lua_createtable (L, outlen, 0); + if (n_in != t->dim[0]) { + return luaL_error (L, "invalid params: bad input dimension %d; %d expected", + (int) t->dim[0], n_in); + } - for (gsize i = 0; i < outlen; i ++) { - lua_pushnumber (L, k->v[i_out]->x[i]); - lua_rawseti (L, -2, i + 1); - } + i_out = kann_find (k, KANN_F_OUT, 0); + + if (i_out <= 0) { + return luaL_error (L, "invalid ANN: output layer is missing or is " + "at the input pos"); + } - g_free (vec); + kann_set_batch_size (k, 1); + kann_feed_bind (k, KANN_F_IN, 0, &t->data); + kad_eval_at (k->n, k->v, i_out); + + gint outlen = kad_len (k->v[i_out]); + struct rspamd_lua_tensor *out; + out = lua_newtensor (L, 1, &outlen, false, false); + /* Ensure that kann and tensor have the same understanding of floats */ + G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t)); + memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float)); + } + else { + return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected"); + } + } + else { + return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected"); + } } else { return luaL_error (L, "invalid arguments: rspamd{kann} expected"); diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c index 6e5bec7d8..1506d4548 100644 --- a/src/lua/lua_tensor.c +++ b/src/lua/lua_tensor.c @@ -54,7 +54,7 @@ static luaL_reg rspamd_tensor_m[] = { {NULL, NULL}, }; -static struct rspamd_lua_tensor * +struct rspamd_lua_tensor * lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own) { struct rspamd_lua_tensor *res; diff --git a/src/lua/lua_tensor.h b/src/lua/lua_tensor.h index e4c110011..e022f64b9 100644 --- a/src/lua/lua_tensor.h +++ b/src/lua/lua_tensor.h @@ -28,5 +28,7 @@ struct rspamd_lua_tensor { }; struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos); +struct rspamd_lua_tensor *lua_newtensor (lua_State *L, int ndims, + const int *dim, bool zero_fill, bool own); #endif |