|
|
@@ -23,6 +23,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#define KANN_NODE_CLASS "rspamd{kann_node}" |
|
|
|
#define KANN_NETWORK_CLASS "rspamd{kann}" |
|
|
|
|
|
|
|
/* Simple macros to define behaviour */ |
|
|
|
#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L) |
|
|
@@ -34,6 +35,10 @@ |
|
|
|
#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L) |
|
|
|
#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name} |
|
|
|
|
|
|
|
#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L) |
|
|
|
#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name} |
|
|
|
|
|
|
|
|
|
|
|
/* |
|
|
|
* Forwarded declarations |
|
|
|
*/ |
|
|
@@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = { |
|
|
|
{NULL, NULL}, |
|
|
|
}; |
|
|
|
|
|
|
|
/* Creation functions */ |
|
|
|
KANN_NEW_DEF (leaf); |
|
|
|
KANN_NEW_DEF (scalar); |
|
|
|
KANN_NEW_DEF (weight); |
|
|
|
KANN_NEW_DEF (bias); |
|
|
|
KANN_NEW_DEF (weight_conv2d); |
|
|
|
KANN_NEW_DEF (weight_conv1d); |
|
|
|
KANN_NEW_DEF (kann); |
|
|
|
|
|
|
|
static luaL_reg rspamd_kann_new_f[] = { |
|
|
|
KANN_NEW_INTERFACE (leaf), |
|
|
|
KANN_NEW_INTERFACE (scalar), |
|
|
|
KANN_NEW_INTERFACE (weight), |
|
|
|
KANN_NEW_INTERFACE (bias), |
|
|
|
KANN_NEW_INTERFACE (weight_conv2d), |
|
|
|
KANN_NEW_INTERFACE (weight_conv1d), |
|
|
|
KANN_NEW_INTERFACE (kann), |
|
|
|
{NULL, NULL}, |
|
|
|
}; |
|
|
|
|
|
|
|
static int |
|
|
|
rspamd_kann_table_to_flags (lua_State *L, int table_pos) |
|
|
|
{ |
|
|
@@ -196,6 +221,12 @@ lua_load_kann (lua_State * L) |
|
|
|
luaL_register (L, NULL, rspamd_kann_loss_f); |
|
|
|
lua_settable (L, -3); |
|
|
|
|
|
|
|
/* Create functions */ |
|
|
|
lua_pushstring (L, "new"); |
|
|
|
lua_newtable (L); |
|
|
|
luaL_register (L, NULL, rspamd_kann_new_f); |
|
|
|
lua_settable (L, -3); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
@@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos) |
|
|
|
void luaopen_kann (lua_State *L) |
|
|
|
{ |
|
|
|
/* Metatables */ |
|
|
|
rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); |
|
|
|
rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */ |
|
|
|
lua_pop (L, 1); /* No need in metatable... */ |
|
|
|
rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */ |
|
|
|
lua_pop (L, 1); /* No need in metatable... */ |
|
|
|
rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann); |
|
|
|
lua_settop (L, 0); |
|
|
@@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L) |
|
|
|
rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \ |
|
|
|
} while(0) |
|
|
|
|
|
|
|
#define PUSH_KAN_NETWORK(n) do { \ |
|
|
|
kann_t **pn; \ |
|
|
|
pn = lua_newuserdata (L, sizeof (kann_t *)); \ |
|
|
|
*pn = (n); \ |
|
|
|
rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \ |
|
|
|
} while(0) |
|
|
|
|
|
|
|
#define PROCESS_KAD_FLAGS(n, pos) do { \ |
|
|
|
int fl = 0; \ |
|
|
|
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \ |
|
|
@@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L) |
|
|
|
} |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/* Creation functions */ |
|
|
|
static int |
|
|
|
lua_kann_new_scalar (lua_State *L) |
|
|
|
{ |
|
|
|
gint flag = luaL_checkinteger (L, 1); |
|
|
|
double x = luaL_checknumber (L, 2); |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
t = kann_new_scalar (flag, x); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 3); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_weight (lua_State *L) |
|
|
|
{ |
|
|
|
gint nrow = luaL_checkinteger (L, 1); |
|
|
|
gint ncol = luaL_checkinteger (L, 2); |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
t = kann_new_weight (nrow, ncol); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 3); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_bias (lua_State *L) |
|
|
|
{ |
|
|
|
gint n = luaL_checkinteger (L, 1); |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
t = kann_new_bias (n); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 2); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_weight_conv2d (lua_State *L) |
|
|
|
{ |
|
|
|
gint nout = luaL_checkinteger (L, 1); |
|
|
|
gint nin = luaL_checkinteger (L, 2); |
|
|
|
gint krow = luaL_checkinteger (L, 3); |
|
|
|
gint kcol = luaL_checkinteger (L, 4); |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
t = kann_new_weight_conv2d (nout, nin, krow, kcol); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 5); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_weight_conv1d (lua_State *L) |
|
|
|
{ |
|
|
|
gint nout = luaL_checkinteger (L, 1); |
|
|
|
gint nin = luaL_checkinteger (L, 2); |
|
|
|
gint klen = luaL_checkinteger (L, 3); |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
t = kann_new_weight_conv1d (nout, nin, klen); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 4); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_leaf (lua_State *L) |
|
|
|
{ |
|
|
|
gint dim = luaL_checkinteger (L, 1), i, *ar; |
|
|
|
kad_node_t *t; |
|
|
|
|
|
|
|
if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) { |
|
|
|
ar = g_malloc0 (sizeof (ar) * dim); |
|
|
|
|
|
|
|
for (i = 0; i < dim; i ++) { |
|
|
|
lua_rawgeti (L, 2, i + 1); |
|
|
|
ar[i] = lua_tointeger (L, -1); |
|
|
|
lua_pop (L, 1); |
|
|
|
} |
|
|
|
|
|
|
|
t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar); |
|
|
|
|
|
|
|
PROCESS_KAD_FLAGS (t, 3); |
|
|
|
PUSH_KAD_NODE (t); |
|
|
|
|
|
|
|
g_free (ar); |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments for new.leaf, " |
|
|
|
"dim and vector of elements are required"); |
|
|
|
} |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
static int |
|
|
|
lua_kann_new_kann (lua_State *L) |
|
|
|
{ |
|
|
|
kad_node_t *cost = lua_check_kann_node (L, 1); |
|
|
|
kann_t *k; |
|
|
|
|
|
|
|
if (cost) { |
|
|
|
k = kann_new (cost, 0); |
|
|
|
|
|
|
|
PUSH_KAN_NETWORK (k); |
|
|
|
} |
|
|
|
else { |
|
|
|
return luaL_error (L, "invalid arguments for new.kann, " |
|
|
|
"cost node is required"); |
|
|
|
} |
|
|
|
|
|
|
|
return 1; |
|
|
|
} |