]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add some missing functions to kann API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 29 Jun 2019 11:35:00 +0000 (12:35 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 29 Jun 2019 11:35:00 +0000 (12:35 +0100)
contrib/kann/kann.h
src/lua/lua_kann.c

index 1605e5ea593f56f189da2131ad9cf4d5cfbac946..7ec748561fd74d5d8a6dc6d2656cde1d15c51f0a 100644 (file)
@@ -210,6 +210,8 @@ kad_node_t *kann_new_bias(int n);
 kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col);
 kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len);
 
+kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM]);
+
 kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...);
 kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1);
 kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r);
index cec75acbb6d4a24a734a4c3503b401f82d84e22f..171c814549ec53b8a4d9805007e84f0dc81e9a14 100644 (file)
@@ -23,6 +23,7 @@
  */
 
 #define KANN_NODE_CLASS "rspamd{kann_node}"
+#define KANN_NETWORK_CLASS "rspamd{kann}"
 
 /* Simple macros to define behaviour */
 #define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
 #define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
 #define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
 
+#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
+#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
+
+
 /*
  * Forwarded declarations
  */
@@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = {
                {NULL, NULL},
 };
 
+/* Creation functions */
+KANN_NEW_DEF (leaf);
+KANN_NEW_DEF (scalar);
+KANN_NEW_DEF (weight);
+KANN_NEW_DEF (bias);
+KANN_NEW_DEF (weight_conv2d);
+KANN_NEW_DEF (weight_conv1d);
+KANN_NEW_DEF (kann);
+
+static luaL_reg rspamd_kann_new_f[] = {
+               KANN_NEW_INTERFACE (leaf),
+               KANN_NEW_INTERFACE (scalar),
+               KANN_NEW_INTERFACE (weight),
+               KANN_NEW_INTERFACE (bias),
+               KANN_NEW_INTERFACE (weight_conv2d),
+               KANN_NEW_INTERFACE (weight_conv1d),
+               KANN_NEW_INTERFACE (kann),
+               {NULL, NULL},
+};
+
 static int
 rspamd_kann_table_to_flags (lua_State *L, int table_pos)
 {
@@ -196,6 +221,12 @@ lua_load_kann (lua_State * L)
        luaL_register (L, NULL, rspamd_kann_loss_f);
        lua_settable (L, -3);
 
+       /* Create functions */
+       lua_pushstring (L, "new");
+       lua_newtable (L);
+       luaL_register (L, NULL, rspamd_kann_new_f);
+       lua_settable (L, -3);
+
        return 1;
 }
 
@@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos)
 void luaopen_kann (lua_State *L)
 {
        /* Metatables */
-       rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+       rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
+       lua_pop (L, 1); /* No need in metatable... */
+       rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */
        lua_pop (L, 1); /* No need in metatable... */
        rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
        lua_settop (L, 0);
@@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L)
        rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
 } while(0)
 
+#define PUSH_KAN_NETWORK(n) do { \
+       kann_t **pn; \
+       pn = lua_newuserdata (L, sizeof (kann_t *)); \
+       *pn = (n); \
+       rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
+} while(0)
+
 #define PROCESS_KAD_FLAGS(n, pos) do { \
        int fl = 0; \
        if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
@@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L)
        }
 
        return 1;
-}
\ No newline at end of file
+}
+
+/* Creation functions */
+static int
+lua_kann_new_scalar (lua_State *L)
+{
+       gint flag = luaL_checkinteger (L, 1);
+       double x = luaL_checknumber (L, 2);
+       kad_node_t *t;
+
+       t = kann_new_scalar (flag, x);
+
+       PROCESS_KAD_FLAGS (t, 3);
+       PUSH_KAD_NODE (t);
+
+       return 1;
+}
+
+static int
+lua_kann_new_weight (lua_State *L)
+{
+       gint nrow = luaL_checkinteger (L, 1);
+       gint ncol = luaL_checkinteger (L, 2);
+       kad_node_t *t;
+
+       t = kann_new_weight (nrow, ncol);
+
+       PROCESS_KAD_FLAGS (t, 3);
+       PUSH_KAD_NODE (t);
+
+       return 1;
+}
+
+static int
+lua_kann_new_bias (lua_State *L)
+{
+       gint n = luaL_checkinteger (L, 1);
+       kad_node_t *t;
+
+       t = kann_new_bias (n);
+
+       PROCESS_KAD_FLAGS (t, 2);
+       PUSH_KAD_NODE (t);
+
+       return 1;
+}
+
+static int
+lua_kann_new_weight_conv2d (lua_State *L)
+{
+       gint nout = luaL_checkinteger (L, 1);
+       gint nin = luaL_checkinteger (L, 2);
+       gint krow = luaL_checkinteger (L, 3);
+       gint kcol = luaL_checkinteger (L, 4);
+       kad_node_t *t;
+
+       t = kann_new_weight_conv2d (nout, nin, krow, kcol);
+
+       PROCESS_KAD_FLAGS (t, 5);
+       PUSH_KAD_NODE (t);
+
+       return 1;
+}
+
+static int
+lua_kann_new_weight_conv1d (lua_State *L)
+{
+       gint nout = luaL_checkinteger (L, 1);
+       gint nin = luaL_checkinteger (L, 2);
+       gint klen = luaL_checkinteger (L, 3);
+       kad_node_t *t;
+
+       t = kann_new_weight_conv1d (nout, nin, klen);
+
+       PROCESS_KAD_FLAGS (t, 4);
+       PUSH_KAD_NODE (t);
+
+       return 1;
+}
+
+static int
+lua_kann_new_leaf (lua_State *L)
+{
+       gint dim = luaL_checkinteger (L, 1), i, *ar;
+       kad_node_t *t;
+
+       if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
+               ar = g_malloc0 (sizeof (ar) * dim);
+
+               for (i = 0; i < dim; i ++) {
+                       lua_rawgeti (L, 2, i + 1);
+                       ar[i] = lua_tointeger (L, -1);
+                       lua_pop (L, 1);
+               }
+
+               t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
+
+               PROCESS_KAD_FLAGS (t, 3);
+               PUSH_KAD_NODE (t);
+
+               g_free (ar);
+       }
+       else {
+               return luaL_error (L, "invalid arguments for new.leaf, "
+                                               "dim and vector of elements are required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_new_kann (lua_State *L)
+{
+       kad_node_t *cost = lua_check_kann_node (L, 1);
+       kann_t *k;
+
+       if (cost) {
+               k = kann_new (cost, 0);
+
+               PUSH_KAN_NETWORK (k);
+       }
+       else {
+               return luaL_error (L, "invalid arguments for new.kann, "
+                                                         "cost node is required");
+       }
+
+       return 1;
+}