From c8c4280fee8cc896a98dccb26626b7853a18aa7d Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 29 Jun 2019 12:35:00 +0100 Subject: [PATCH] [Project] Add some missing functions to kann API --- contrib/kann/kann.h | 2 + src/lua/lua_kann.c | 171 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 2 deletions(-) diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h index 1605e5ea5..7ec748561 100644 --- a/contrib/kann/kann.h +++ b/contrib/kann/kann.h @@ -210,6 +210,8 @@ kad_node_t *kann_new_bias(int n); kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col); kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len); +kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM]); + kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...); kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1); kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r); diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index cec75acbb..171c81454 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -23,6 +23,7 @@ */ #define KANN_NODE_CLASS "rspamd{kann_node}" +#define KANN_NETWORK_CLASS "rspamd{kann}" /* Simple macros to define behaviour */ #define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L) @@ -34,6 +35,10 @@ #define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L) #define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name} +#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L) +#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name} + + /* * Forwarded declarations */ @@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = { {NULL, NULL}, }; +/* Creation functions */ +KANN_NEW_DEF (leaf); +KANN_NEW_DEF (scalar); +KANN_NEW_DEF (weight); +KANN_NEW_DEF (bias); +KANN_NEW_DEF (weight_conv2d); +KANN_NEW_DEF (weight_conv1d); +KANN_NEW_DEF (kann); + +static luaL_reg rspamd_kann_new_f[] = { + KANN_NEW_INTERFACE (leaf), + KANN_NEW_INTERFACE (scalar), + KANN_NEW_INTERFACE (weight), + KANN_NEW_INTERFACE (bias), + KANN_NEW_INTERFACE (weight_conv2d), + KANN_NEW_INTERFACE (weight_conv1d), + KANN_NEW_INTERFACE (kann), + {NULL, NULL}, +}; + static int rspamd_kann_table_to_flags (lua_State *L, int table_pos) { @@ -196,6 +221,12 @@ lua_load_kann (lua_State * L) luaL_register (L, NULL, rspamd_kann_loss_f); lua_settable (L, -3); + /* Create functions */ + lua_pushstring (L, "new"); + lua_newtable (L); + luaL_register (L, NULL, rspamd_kann_new_f); + lua_settable (L, -3); + return 1; } @@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos) void luaopen_kann (lua_State *L) { /* Metatables */ - rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); + rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */ + lua_pop (L, 1); /* No need in metatable... */ + rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */ lua_pop (L, 1); /* No need in metatable... */ rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann); lua_settop (L, 0); @@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L) rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \ } while(0) +#define PUSH_KAN_NETWORK(n) do { \ + kann_t **pn; \ + pn = lua_newuserdata (L, sizeof (kann_t *)); \ + *pn = (n); \ + rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \ +} while(0) + #define PROCESS_KAD_FLAGS(n, pos) do { \ int fl = 0; \ if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \ @@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L) } return 1; -} \ No newline at end of file +} + +/* Creation functions */ +static int +lua_kann_new_scalar (lua_State *L) +{ + gint flag = luaL_checkinteger (L, 1); + double x = luaL_checknumber (L, 2); + kad_node_t *t; + + t = kann_new_scalar (flag, x); + + PROCESS_KAD_FLAGS (t, 3); + PUSH_KAD_NODE (t); + + return 1; +} + +static int +lua_kann_new_weight (lua_State *L) +{ + gint nrow = luaL_checkinteger (L, 1); + gint ncol = luaL_checkinteger (L, 2); + kad_node_t *t; + + t = kann_new_weight (nrow, ncol); + + PROCESS_KAD_FLAGS (t, 3); + PUSH_KAD_NODE (t); + + return 1; +} + +static int +lua_kann_new_bias (lua_State *L) +{ + gint n = luaL_checkinteger (L, 1); + kad_node_t *t; + + t = kann_new_bias (n); + + PROCESS_KAD_FLAGS (t, 2); + PUSH_KAD_NODE (t); + + return 1; +} + +static int +lua_kann_new_weight_conv2d (lua_State *L) +{ + gint nout = luaL_checkinteger (L, 1); + gint nin = luaL_checkinteger (L, 2); + gint krow = luaL_checkinteger (L, 3); + gint kcol = luaL_checkinteger (L, 4); + kad_node_t *t; + + t = kann_new_weight_conv2d (nout, nin, krow, kcol); + + PROCESS_KAD_FLAGS (t, 5); + PUSH_KAD_NODE (t); + + return 1; +} + +static int +lua_kann_new_weight_conv1d (lua_State *L) +{ + gint nout = luaL_checkinteger (L, 1); + gint nin = luaL_checkinteger (L, 2); + gint klen = luaL_checkinteger (L, 3); + kad_node_t *t; + + t = kann_new_weight_conv1d (nout, nin, klen); + + PROCESS_KAD_FLAGS (t, 4); + PUSH_KAD_NODE (t); + + return 1; +} + +static int +lua_kann_new_leaf (lua_State *L) +{ + gint dim = luaL_checkinteger (L, 1), i, *ar; + kad_node_t *t; + + if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) { + ar = g_malloc0 (sizeof (ar) * dim); + + for (i = 0; i < dim; i ++) { + lua_rawgeti (L, 2, i + 1); + ar[i] = lua_tointeger (L, -1); + lua_pop (L, 1); + } + + t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar); + + PROCESS_KAD_FLAGS (t, 3); + PUSH_KAD_NODE (t); + + g_free (ar); + } + else { + return luaL_error (L, "invalid arguments for new.leaf, " + "dim and vector of elements are required"); + } + + return 1; +} + +static int +lua_kann_new_kann (lua_State *L) +{ + kad_node_t *cost = lua_check_kann_node (L, 1); + kann_t *k; + + if (cost) { + k = kann_new (cost, 0); + + PUSH_KAN_NETWORK (k); + } + else { + return luaL_error (L, "invalid arguments for new.kann, " + "cost node is required"); + } + + return 1; +} -- 2.39.5