author     Vsevolod Stakhov <vsevolod@highsecure.ru>  2020-08-21 16:11:32 +0100
committer  Vsevolod Stakhov <vsevolod@highsecure.ru>  2020-08-21 21:48:46 +0100
commit     71e58489aa8efbc883ba961b1f1cf15eebec3c87 (patch)
tree       1c5491f6a17c69464330992946aeb9db48746f17 /src
parent     85acd8afd9d4a38586c6c908dd30c8eb86758fce (diff)
download   rspamd-71e58489aa8efbc883ba961b1f1cf15eebec3c87.tar.gz
           rspamd-71e58489aa8efbc883ba961b1f1cf15eebec3c87.zip
[Minor] Allow to use lua_tensor in kann apply
Diffstat (limited to 'src')
-rw-r--r--  src/lua/lua_kann.c    102
-rw-r--r--  src/lua/lua_tensor.c      2
-rw-r--r--  src/lua/lua_tensor.h      2
3 files changed, 74 insertions, 32 deletions
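
In short, lua_kann_apply1() now accepts either a plain Lua table or a one-dimensional rspamd{tensor} userdata as the input vector; a tensor's storage is bound directly to the network input, and the result comes back as a tensor instead of a table. A condensed sketch of the new tensor branch from the diff below (error handling and the pre-existing table path are trimmed, so this is a reading aid rather than a drop-in replacement):

	struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);

	if (t && t->ndims == 1 && t->dim[0] == kann_dim_in (k)) {
		int i_out = kann_find (k, KANN_F_OUT, 0);

		kann_set_batch_size (k, 1);
		/* Bind the tensor storage directly: no per-element Lua stack traffic */
		kann_feed_bind (k, KANN_F_IN, 0, &t->data);
		kad_eval_at (k->n, k->v, i_out);

		gint outlen = kad_len (k->v[i_out]);
		/* The output is returned as a freshly created 1D tensor */
		struct rspamd_lua_tensor *out = lua_newtensor (L, 1, &outlen, false, false);
		/* kann floats and rspamd_tensor_num_t must have the same layout */
		G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
		memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
	}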
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 33036fe04..1827fe1ac 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -15,6 +15,7 @@
*/
#include "lua_common.h"
+#include "lua_tensor.h"
#include "contrib/kann/kann.h"
/***
@@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
- if (k && lua_istable (L, 2)) {
- gsize vec_len = rspamd_lua_table_size (L, 2);
- float *vec = (float *)g_malloc (sizeof (float) * vec_len);
- int i_out;
- int n_in = kann_dim_in (k);
+ if (k) {
+ if (lua_istable (L, 2)) {
+ gsize vec_len = rspamd_lua_table_size (L, 2);
+ float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+ int i_out;
+ int n_in = kann_dim_in (k);
- if (n_in <= 0) {
- return luaL_error (L, "invalid inputs count: %d", n_in);
- }
+ if (n_in <= 0) {
+ return luaL_error (L, "invalid inputs count: %d", n_in);
+ }
- if (n_in != vec_len) {
- return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
- (int)vec_len, n_in);
- }
+ if (n_in != vec_len) {
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int) vec_len, n_in);
+ }
- for (gsize i = 0; i < vec_len; i ++) {
- lua_rawgeti (L, 2, i + 1);
- vec[i] = lua_tonumber (L, -1);
- lua_pop (L, 1);
- }
+ for (gsize i = 0; i < vec_len; i++) {
+ lua_rawgeti (L, 2, i + 1);
+ vec[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
- i_out = kann_find (k, KANN_F_OUT, 0);
+ i_out = kann_find (k, KANN_F_OUT, 0);
+
+ if (i_out <= 0) {
+ g_free (vec);
+ return luaL_error (L, "invalid ANN: output layer is missing or is "
+ "at the input pos");
+ }
+
+ kann_set_batch_size (k, 1);
+ kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ kad_eval_at (k->n, k->v, i_out);
+
+ gsize outlen = kad_len (k->v[i_out]);
+ lua_createtable (L, outlen, 0);
+
+ for (gsize i = 0; i < outlen; i++) {
+ lua_pushnumber (L, k->v[i_out]->x[i]);
+ lua_rawseti (L, -2, i + 1);
+ }
- if (i_out <= 0) {
g_free (vec);
- return luaL_error (L, "invalid ANN: output layer is missing or is "
- "at the input pos");
}
+ else if (lua_isuserdata (L, 2)) {
+ struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
- kann_set_batch_size (k, 1);
- kann_feed_bind (k, KANN_F_IN, 0, &vec);
- kad_eval_at (k->n, k->v, i_out);
+ if (t && t->ndims == 1) {
+ int i_out;
+ int n_in = kann_dim_in (k);
- gsize outlen = kad_len (k->v[i_out]);
- lua_createtable (L, outlen, 0);
+ if (n_in != t->dim[0]) {
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int) t->dim[0], n_in);
+ }
- for (gsize i = 0; i < outlen; i ++) {
- lua_pushnumber (L, k->v[i_out]->x[i]);
- lua_rawseti (L, -2, i + 1);
- }
+ i_out = kann_find (k, KANN_F_OUT, 0);
+
+ if (i_out <= 0) {
+ return luaL_error (L, "invalid ANN: output layer is missing or is "
+ "at the input pos");
+ }
- g_free (vec);
+ kann_set_batch_size (k, 1);
+ kann_feed_bind (k, KANN_F_IN, 0, &t->data);
+ kad_eval_at (k->n, k->v, i_out);
+
+ gint outlen = kad_len (k->v[i_out]);
+ struct rspamd_lua_tensor *out;
+ out = lua_newtensor (L, 1, &outlen, false, false);
+ /* Ensure that kann and tensor have the same understanding of floats */
+ G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
+ memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
+ }
+ else {
+ return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+ }
+ }
+ else {
+ return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+ }
}
else {
return luaL_error (L, "invalid arguments: rspamd{kann} expected");
diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c
index 6e5bec7d8..1506d4548 100644
--- a/src/lua/lua_tensor.c
+++ b/src/lua/lua_tensor.c
@@ -54,7 +54,7 @@ static luaL_reg rspamd_tensor_m[] = {
{NULL, NULL},
};
-static struct rspamd_lua_tensor *
+struct rspamd_lua_tensor *
lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
{
struct rspamd_lua_tensor *res;
diff --git a/src/lua/lua_tensor.h b/src/lua/lua_tensor.h
index e4c110011..e022f64b9 100644
--- a/src/lua/lua_tensor.h
+++ b/src/lua/lua_tensor.h
@@ -28,5 +28,7 @@ struct rspamd_lua_tensor {
};
struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
+struct rspamd_lua_tensor *lua_newtensor (lua_State *L, int ndims,
+ const int *dim, bool zero_fill, bool own);
#endif
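
With lua_newtensor() exported here instead of being private to lua_tensor.c, other binding units can construct tensors directly, which is exactly what lua_kann.c does above. A minimal sketch of such a caller (the helper name lua_push_float_vector is hypothetical and not part of rspamd; it assumes, as the lua_kann.c hunk suggests, that lua_newtensor() leaves the new userdata on the Lua stack):

	#include "lua_common.h"
	#include "lua_tensor.h"

	/* Hypothetical helper: copy a C float buffer into a new 1D tensor and
	 * leave it on the Lua stack, mirroring the pattern used in lua_kann.c */
	static gint
	lua_push_float_vector (lua_State *L, const float *src, gint len)
	{
		struct rspamd_lua_tensor *out;

		out = lua_newtensor (L, 1, &len, false, false);
		/* Same layout guarantee as relied upon in lua_kann.c */
		G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
		memcpy (out->data, src, len * sizeof (float));

		return 1; /* one value (the tensor userdata) pushed */
	}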