aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_kann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-06-29 12:35:00 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-06-29 12:35:00 +0100
commitc8c4280fee8cc896a98dccb26626b7853a18aa7d (patch)
tree40722874a5b5baf9aba34eec472dd297d69e2d8d /src/lua/lua_kann.c
parent648b027ccac1448659ede3224f66b356adfadd95 (diff)
downloadrspamd-c8c4280fee8cc896a98dccb26626b7853a18aa7d.tar.gz
rspamd-c8c4280fee8cc896a98dccb26626b7853a18aa7d.zip
[Project] Add some missing functions to kann API
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r--src/lua/lua_kann.c171
1 files changed, 169 insertions, 2 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index cec75acbb..171c81454 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -23,6 +23,7 @@
*/
#define KANN_NODE_CLASS "rspamd{kann_node}"
+#define KANN_NETWORK_CLASS "rspamd{kann}"
/* Simple macros to define behaviour */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
@@ -34,6 +35,10 @@
#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
+#define KANN_NEW_DEF(name) static int lua_kann_new_ ## name (lua_State *L)
+#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_ ## name}
+
+
/*
* Forwarded declarations
*/
@@ -115,6 +120,26 @@ static luaL_reg rspamd_kann_loss_f[] = {
{NULL, NULL},
};
+/* Creation functions */
+KANN_NEW_DEF (leaf);
+KANN_NEW_DEF (scalar);
+KANN_NEW_DEF (weight);
+KANN_NEW_DEF (bias);
+KANN_NEW_DEF (weight_conv2d);
+KANN_NEW_DEF (weight_conv1d);
+KANN_NEW_DEF (kann);
+
+static luaL_reg rspamd_kann_new_f[] = {
+ KANN_NEW_INTERFACE (leaf),
+ KANN_NEW_INTERFACE (scalar),
+ KANN_NEW_INTERFACE (weight),
+ KANN_NEW_INTERFACE (bias),
+ KANN_NEW_INTERFACE (weight_conv2d),
+ KANN_NEW_INTERFACE (weight_conv1d),
+ KANN_NEW_INTERFACE (kann),
+ {NULL, NULL},
+};
+
static int
rspamd_kann_table_to_flags (lua_State *L, int table_pos)
{
@@ -196,6 +221,12 @@ lua_load_kann (lua_State * L)
luaL_register (L, NULL, rspamd_kann_loss_f);
lua_settable (L, -3);
+ /* Create functions */
+ lua_pushstring (L, "new");
+ lua_newtable (L);
+ luaL_register (L, NULL, rspamd_kann_new_f);
+ lua_settable (L, -3);
+
return 1;
}
@@ -210,7 +241,9 @@ lua_check_kann_node (lua_State *L, int pos)
void luaopen_kann (lua_State *L)
{
/* Metatables */
- rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+ rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
+ lua_pop (L, 1); /* No need in metatable... */
+ rspamd_lua_new_class (L, KANN_NETWORK_CLASS, NULL); /* TODO: add methods */
lua_pop (L, 1); /* No need in metatable... */
rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
lua_settop (L, 0);
@@ -224,6 +257,13 @@ void luaopen_kann (lua_State *L)
rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
} while(0)
+#define PUSH_KAN_NETWORK(n) do { \
+ kann_t **pn; \
+ pn = lua_newuserdata (L, sizeof (kann_t *)); \
+ *pn = (n); \
+ rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
+} while(0)
+
#define PROCESS_KAD_FLAGS(n, pos) do { \
int fl = 0; \
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
@@ -561,4 +601,131 @@ lua_kann_loss_ce_multi_weighted (lua_State *L)
}
return 1;
-} \ No newline at end of file
+}
+
+/* Creation functions */
+static int
+lua_kann_new_scalar (lua_State *L)
+{
+ gint flag = luaL_checkinteger (L, 1);
+ double x = luaL_checknumber (L, 2);
+ kad_node_t *t;
+
+ t = kann_new_scalar (flag, x);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight (lua_State *L)
+{
+ gint nrow = luaL_checkinteger (L, 1);
+ gint ncol = luaL_checkinteger (L, 2);
+ kad_node_t *t;
+
+ t = kann_new_weight (nrow, ncol);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_bias (lua_State *L)
+{
+ gint n = luaL_checkinteger (L, 1);
+ kad_node_t *t;
+
+ t = kann_new_bias (n);
+
+ PROCESS_KAD_FLAGS (t, 2);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv2d (lua_State *L)
+{
+ gint nout = luaL_checkinteger (L, 1);
+ gint nin = luaL_checkinteger (L, 2);
+ gint krow = luaL_checkinteger (L, 3);
+ gint kcol = luaL_checkinteger (L, 4);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv2d (nout, nin, krow, kcol);
+
+ PROCESS_KAD_FLAGS (t, 5);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv1d (lua_State *L)
+{
+ gint nout = luaL_checkinteger (L, 1);
+ gint nin = luaL_checkinteger (L, 2);
+ gint klen = luaL_checkinteger (L, 3);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv1d (nout, nin, klen);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_leaf (lua_State *L)
+{
+ gint dim = luaL_checkinteger (L, 1), i, *ar;
+ kad_node_t *t;
+
+ if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
+ ar = g_malloc0 (sizeof (ar) * dim);
+
+ for (i = 0; i < dim; i ++) {
+ lua_rawgeti (L, 2, i + 1);
+ ar[i] = lua_tointeger (L, -1);
+ lua_pop (L, 1);
+ }
+
+ t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+
+ g_free (ar);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for new.leaf, "
+ "dim and vector of elements are required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_new_kann (lua_State *L)
+{
+ kad_node_t *cost = lua_check_kann_node (L, 1);
+ kann_t *k;
+
+ if (cost) {
+ k = kann_new (cost, 0);
+
+ PUSH_KAN_NETWORK (k);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for new.kann, "
+ "cost node is required");
+ }
+
+ return 1;
+}