]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Allow to use lua_tensor in kann apply
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 21 Aug 2020 15:11:32 +0000 (16:11 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 21 Aug 2020 20:48:46 +0000 (21:48 +0100)
src/lua/lua_kann.c
src/lua/lua_tensor.c
src/lua/lua_tensor.h

index 33036fe0452c3b3386bc27df905164f2ea064d95..1827fe1acfa2d528a2cd83baa368e92a4539442d 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include "lua_common.h"
+#include "lua_tensor.h"
 #include "contrib/kann/kann.h"
 
 /***
@@ -1155,48 +1156,87 @@ lua_kann_apply1 (lua_State *L)
 {
        kann_t *k = lua_check_kann (L, 1);
 
-       if (k && lua_istable (L, 2)) {
-               gsize vec_len = rspamd_lua_table_size (L, 2);
-               float *vec = (float *)g_malloc (sizeof (float) * vec_len);
-               int i_out;
-               int n_in = kann_dim_in (k);
+       if (k) {
+               if (lua_istable (L, 2)) {
+                       gsize vec_len = rspamd_lua_table_size (L, 2);
+                       float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+                       int i_out;
+                       int n_in = kann_dim_in (k);
 
-               if (n_in <= 0) {
-                       return luaL_error (L, "invalid inputs count: %d", n_in);
-               }
+                       if (n_in <= 0) {
+                               return luaL_error (L, "invalid inputs count: %d", n_in);
+                       }
 
-               if (n_in != vec_len) {
-                       return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
-                                       (int)vec_len, n_in);
-               }
+                       if (n_in != vec_len) {
+                               return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+                                               (int) vec_len, n_in);
+                       }
 
-               for (gsize i = 0; i < vec_len; i ++) {
-                       lua_rawgeti (L, 2, i + 1);
-                       vec[i] = lua_tonumber (L, -1);
-                       lua_pop (L, 1);
-               }
+                       for (gsize i = 0; i < vec_len; i++) {
+                               lua_rawgeti (L, 2, i + 1);
+                               vec[i] = lua_tonumber (L, -1);
+                               lua_pop (L, 1);
+                       }
 
-               i_out = kann_find (k, KANN_F_OUT, 0);
+                       i_out = kann_find (k, KANN_F_OUT, 0);
+
+                       if (i_out <= 0) {
+                               g_free (vec);
+                               return luaL_error (L, "invalid ANN: output layer is missing or is "
+                                                                         "at the input pos");
+                       }
+
+                       kann_set_batch_size (k, 1);
+                       kann_feed_bind (k, KANN_F_IN, 0, &vec);
+                       kad_eval_at (k->n, k->v, i_out);
+
+                       gsize outlen = kad_len (k->v[i_out]);
+                       lua_createtable (L, outlen, 0);
+
+                       for (gsize i = 0; i < outlen; i++) {
+                               lua_pushnumber (L, k->v[i_out]->x[i]);
+                               lua_rawseti (L, -2, i + 1);
+                       }
 
-               if (i_out <= 0) {
                        g_free (vec);
-                       return luaL_error (L, "invalid ANN: output layer is missing or is "
-                                                "at the input pos");
                }
+               else if (lua_isuserdata (L, 2)) {
+                       struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
 
-               kann_set_batch_size (k, 1);
-               kann_feed_bind (k, KANN_F_IN, 0, &vec);
-               kad_eval_at (k->n, k->v, i_out);
+                       if (t && t->ndims == 1) {
+                               int i_out;
+                               int n_in = kann_dim_in (k);
 
-               gsize outlen = kad_len (k->v[i_out]);
-               lua_createtable (L, outlen, 0);
+                               if (n_in != t->dim[0]) {
+                                       return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+                                                       (int) t->dim[0], n_in);
+                               }
 
-               for (gsize i = 0; i < outlen; i ++) {
-                       lua_pushnumber (L, k->v[i_out]->x[i]);
-                       lua_rawseti (L, -2, i + 1);
-               }
+                               i_out = kann_find (k, KANN_F_OUT, 0);
+
+                               if (i_out <= 0) {
+                                       return luaL_error (L, "invalid ANN: output layer is missing or is "
+                                                                                 "at the input pos");
+                               }
 
-               g_free (vec);
+                               kann_set_batch_size (k, 1);
+                               kann_feed_bind (k, KANN_F_IN, 0, &t->data);
+                               kad_eval_at (k->n, k->v, i_out);
+
+                               gint outlen = kad_len (k->v[i_out]);
+                               struct rspamd_lua_tensor *out;
+                               out = lua_newtensor (L, 1, &outlen, false, false);
+                               /* Ensure that kann and tensor have the same understanding of floats */
+                               G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
+                               memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
+                       }
+                       else {
+                               return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+                       }
+               }
+               else {
+                       return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
+               }
        }
        else {
                return luaL_error (L, "invalid arguments: rspamd{kann} expected");
index 6e5bec7d8f1dad17e780c795ad69420d90434dc4..1506d45488390ec33c9465e8dab07a5cadce3146 100644 (file)
@@ -54,7 +54,7 @@ static luaL_reg rspamd_tensor_m[] = {
                {NULL, NULL},
 };
 
-static struct rspamd_lua_tensor *
+struct rspamd_lua_tensor *
 lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
 {
        struct rspamd_lua_tensor *res;
index e4c110011d4c0e8ff8ac64590a5d4cad1c79937a..e022f64b977773360c9b1d1437b035027c22edb6 100644 (file)
@@ -28,5 +28,7 @@ struct rspamd_lua_tensor {
 };
 
 struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);
+struct rspamd_lua_tensor *lua_newtensor (lua_State *L, int ndims,
+               const int *dim, bool zero_fill, bool own);
 
 #endif