*/
#define KANN_NODE_CLASS "rspamd{kann_node}"
+#define KANN_NETWORK_CLASS "rspamd{kann}"
/* Simple macros to define behaviour */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
+#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
+#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
+
+
/*
* Forwarded declarations
*/
{NULL, NULL},
};
+/* Creation functions */
+KANN_NEW_DEF (leaf);
+KANN_NEW_DEF (scalar);
+KANN_NEW_DEF (weight);
+KANN_NEW_DEF (bias);
+KANN_NEW_DEF (weight_conv2d);
+KANN_NEW_DEF (weight_conv1d);
+KANN_NEW_DEF (kann);
+
+static luaL_reg rspamd_kann_new_f[] = {
+ KANN_NEW_INTERFACE (leaf),
+ KANN_NEW_INTERFACE (scalar),
+ KANN_NEW_INTERFACE (weight),
+ KANN_NEW_INTERFACE (bias),
+ KANN_NEW_INTERFACE (weight_conv2d),
+ KANN_NEW_INTERFACE (weight_conv1d),
+ KANN_NEW_INTERFACE (kann),
+ {NULL, NULL},
+};
+
static int
rspamd_kann_table_to_flags (lua_State *L, int table_pos)
{
luaL_register (L, NULL, rspamd_kann_loss_f);
lua_settable (L, -3);
+ /* Create functions */
+ lua_pushstring (L, "new");
+ lua_newtable (L);
+ luaL_register (L, NULL, rspamd_kann_new_f);
+ lua_settable (L, -3);
+
return 1;
}
void luaopen_kann (lua_State *L)
{
/* Metatables */
- rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+ rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
+ lua_pop (L, 1); /* No need in metatable... */
+ rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */
lua_pop (L, 1); /* No need in metatable... */
rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
lua_settop (L, 0);
rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
} while(0)
+#define PUSH_KAN_NETWORK(n) do { \
+ kann_t **pn; \
+ pn = lua_newuserdata (L, sizeof (kann_t *)); \
+ *pn = (n); \
+ rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
+} while(0)
+
#define PROCESS_KAD_FLAGS(n, pos) do { \
int fl = 0; \
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
}
return 1;
-}
\ No newline at end of file
+}
+
+/* Creation functions */
+static int
+lua_kann_new_scalar (lua_State *L)
+{
+ gint flag = luaL_checkinteger (L, 1);
+ double x = luaL_checknumber (L, 2);
+ kad_node_t *t;
+
+ t = kann_new_scalar (flag, x);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight (lua_State *L)
+{
+ gint nrow = luaL_checkinteger (L, 1);
+ gint ncol = luaL_checkinteger (L, 2);
+ kad_node_t *t;
+
+ t = kann_new_weight (nrow, ncol);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_bias (lua_State *L)
+{
+ gint n = luaL_checkinteger (L, 1);
+ kad_node_t *t;
+
+ t = kann_new_bias (n);
+
+ PROCESS_KAD_FLAGS (t, 2);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv2d (lua_State *L)
+{
+ gint nout = luaL_checkinteger (L, 1);
+ gint nin = luaL_checkinteger (L, 2);
+ gint krow = luaL_checkinteger (L, 3);
+ gint kcol = luaL_checkinteger (L, 4);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv2d (nout, nin, krow, kcol);
+
+ PROCESS_KAD_FLAGS (t, 5);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv1d (lua_State *L)
+{
+ gint nout = luaL_checkinteger (L, 1);
+ gint nin = luaL_checkinteger (L, 2);
+ gint klen = luaL_checkinteger (L, 3);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv1d (nout, nin, klen);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_leaf (lua_State *L)
+{
+ gint dim = luaL_checkinteger (L, 1), i, *ar;
+ kad_node_t *t;
+
+ if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
+ ar = g_malloc0 (sizeof (ar) * dim);
+
+ for (i = 0; i < dim; i ++) {
+ lua_rawgeti (L, 2, i + 1);
+ ar[i] = lua_tointeger (L, -1);
+ lua_pop (L, 1);
+ }
+
+ t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ g_free (ar);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for new.leaf, "
+ "dim and vector of elements are required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_new_kann (lua_State *L)
+{
+ kad_node_t *cost = lua_check_kann_node (L, 1);
+ kann_t *k;
+
+ if (cost) {
+ k = kann_new (cost, 0);
+
+ PUSH_KAN_NETWORK (k);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for new.kann, "
+ "cost node is required");
+ }
+
+ return 1;
+}