]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Add function to get neural net layers
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 15 Oct 2016 11:24:01 +0000 (12:24 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 15 Oct 2016 12:34:42 +0000 (13:34 +0100)
src/lua/lua_fann.c

index d6dd921e776c8aa55a6c02f469ae90fb678b4f9e..39bcb146f0b5d58c4775c1fee09982fe9cc5e881 100644 (file)
@@ -45,6 +45,7 @@ LUA_FUNCTION_DEF (fann, save);
 LUA_FUNCTION_DEF (fann, data);
 LUA_FUNCTION_DEF (fann, get_inputs);
 LUA_FUNCTION_DEF (fann, get_outputs);
+LUA_FUNCTION_DEF (fann, get_layers);
 LUA_FUNCTION_DEF (fann, get_mse);
 LUA_FUNCTION_DEF (fann, dtor);
 
@@ -64,6 +65,7 @@ static const struct luaL_reg fannlib_m[] = {
                LUA_INTERFACE_DEF (fann, data),
                LUA_INTERFACE_DEF (fann, get_inputs),
                LUA_INTERFACE_DEF (fann, get_outputs),
+               LUA_INTERFACE_DEF (fann, get_layers),
                LUA_INTERFACE_DEF (fann, get_mse),
                {"__gc", lua_fann_dtor},
                {"__tostring", rspamd_lua_class_tostring},
@@ -518,6 +520,41 @@ lua_fann_get_mse (lua_State *L)
 #endif
 }
 
+/***
+ * @method rspamd_fann:get_layers()
+ * Returns array of neurons count for each layer
+ * @return {table/number} table with number ofr neurons in each layer
+ */
+static gint
+lua_fann_get_layers (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       guint nlayers, i, *layers;
+
+       if (f != NULL) {
+               nlayers = fann_get_num_layers (f);
+               layers = g_new (guint, nlayers);
+               fann_get_layer_array (f, layers);
+               lua_createtable (L, nlayers, 0);
+
+               for (i = 0; i < nlayers; i ++) {
+                       lua_pushnumber (L, layers[i]);
+                       lua_rawseti (L, -1, i + 1);
+               }
+
+               g_free (layers);
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+#endif
+}
+
 /***
  * @method rspamd_fann:save(fname)
  * Save fann to file named 'fname'