Browse Source

[Project] Add some missing functions to kann API

tags/2.0
Vsevolod Stakhov 5 years ago
parent
commit
c8c4280fee
2 changed files with 171 additions and 2 deletions
  1. 2
    0
      contrib/kann/kann.h
  2. 169
    2
      src/lua/lua_kann.c

+ 2
- 0
contrib/kann/kann.h View File

@@ -210,6 +210,8 @@ kad_node_t *kann_new_bias(int n);
kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col);
kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len);

kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM]);

kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...);
kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1);
kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r);

+ 169
- 2
src/lua/lua_kann.c View File

@@ -23,6 +23,7 @@
*/

#define KANN_NODE_CLASS "rspamd{kann_node}"
#define KANN_NETWORK_CLASS "rspamd{kann}"

/* Simple macros to define behaviour */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
@@ -34,6 +35,10 @@
#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}

#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}


/*
* Forwarded declarations
*/
@@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = {
{NULL, NULL},
};

/* Creation functions */
KANN_NEW_DEF (leaf);
KANN_NEW_DEF (scalar);
KANN_NEW_DEF (weight);
KANN_NEW_DEF (bias);
KANN_NEW_DEF (weight_conv2d);
KANN_NEW_DEF (weight_conv1d);
KANN_NEW_DEF (kann);

static luaL_reg rspamd_kann_new_f[] = {
KANN_NEW_INTERFACE (leaf),
KANN_NEW_INTERFACE (scalar),
KANN_NEW_INTERFACE (weight),
KANN_NEW_INTERFACE (bias),
KANN_NEW_INTERFACE (weight_conv2d),
KANN_NEW_INTERFACE (weight_conv1d),
KANN_NEW_INTERFACE (kann),
{NULL, NULL},
};

static int
rspamd_kann_table_to_flags (lua_State *L, int table_pos)
{
@@ -196,6 +221,12 @@ lua_load_kann (lua_State * L)
luaL_register (L, NULL, rspamd_kann_loss_f);
lua_settable (L, -3);

/* Create functions */
lua_pushstring (L, "new");
lua_newtable (L);
luaL_register (L, NULL, rspamd_kann_new_f);
lua_settable (L, -3);

return 1;
}

@@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos)
void luaopen_kann (lua_State *L)
{
/* Metatables */
rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
lua_pop (L, 1); /* No need in metatable... */
rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */
lua_pop (L, 1); /* No need in metatable... */
rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
lua_settop (L, 0);
@@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L)
rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
} while(0)

#define PUSH_KAN_NETWORK(n) do { \
kann_t **pn; \
pn = lua_newuserdata (L, sizeof (kann_t *)); \
*pn = (n); \
rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
} while(0)

#define PROCESS_KAD_FLAGS(n, pos) do { \
int fl = 0; \
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
@@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L)
}

return 1;
}
}

/* Creation functions */
static int
lua_kann_new_scalar (lua_State *L)
{
gint flag = luaL_checkinteger (L, 1);
double x = luaL_checknumber (L, 2);
kad_node_t *t;

t = kann_new_scalar (flag, x);

PROCESS_KAD_FLAGS (t, 3);
PUSH_KAD_NODE (t);

return 1;
}

static int
lua_kann_new_weight (lua_State *L)
{
gint nrow = luaL_checkinteger (L, 1);
gint ncol = luaL_checkinteger (L, 2);
kad_node_t *t;

t = kann_new_weight (nrow, ncol);

PROCESS_KAD_FLAGS (t, 3);
PUSH_KAD_NODE (t);

return 1;
}

static int
lua_kann_new_bias (lua_State *L)
{
gint n = luaL_checkinteger (L, 1);
kad_node_t *t;

t = kann_new_bias (n);

PROCESS_KAD_FLAGS (t, 2);
PUSH_KAD_NODE (t);

return 1;
}

static int
lua_kann_new_weight_conv2d (lua_State *L)
{
gint nout = luaL_checkinteger (L, 1);
gint nin = luaL_checkinteger (L, 2);
gint krow = luaL_checkinteger (L, 3);
gint kcol = luaL_checkinteger (L, 4);
kad_node_t *t;

t = kann_new_weight_conv2d (nout, nin, krow, kcol);

PROCESS_KAD_FLAGS (t, 5);
PUSH_KAD_NODE (t);

return 1;
}

static int
lua_kann_new_weight_conv1d (lua_State *L)
{
gint nout = luaL_checkinteger (L, 1);
gint nin = luaL_checkinteger (L, 2);
gint klen = luaL_checkinteger (L, 3);
kad_node_t *t;

t = kann_new_weight_conv1d (nout, nin, klen);

PROCESS_KAD_FLAGS (t, 4);
PUSH_KAD_NODE (t);

return 1;
}

static int
lua_kann_new_leaf (lua_State *L)
{
gint dim = luaL_checkinteger (L, 1), i, *ar;
kad_node_t *t;

if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
ar = g_malloc0 (sizeof (ar) * dim);

for (i = 0; i < dim; i ++) {
lua_rawgeti (L, 2, i + 1);
ar[i] = lua_tointeger (L, -1);
lua_pop (L, 1);
}

t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);

PROCESS_KAD_FLAGS (t, 3);
PUSH_KAD_NODE (t);

g_free (ar);
}
else {
return luaL_error (L, "invalid arguments for new.leaf, "
"dim and vector of elements are required");
}

return 1;
}

static int
lua_kann_new_kann (lua_State *L)
{
kad_node_t *cost = lua_check_kann_node (L, 1);
kann_t *k;

if (cost) {
k = kann_new (cost, 0);

PUSH_KAN_NETWORK (k);
}
else {
return luaL_error (L, "invalid arguments for new.kann, "
"cost node is required");
}

return 1;
}

Loading…
Cancel
Save