]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add extended version for fann creation function
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 3 Nov 2016 18:45:56 +0000 (18:45 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 14:13:07 +0000 (14:13 +0000)
src/lua/lua_fann.c

index 1d02fe3f0a5596ac7e2d75dc34e1a72e571e0b8e..4f4ffbecd1f18802bb67d3fdc13f3d887a1f5e5f 100644 (file)
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "lua_common.h"
-
+#include "util.h"
 #ifdef WITH_FANN
 #include <fann.h>
 #endif
@@ -33,6 +33,7 @@
  */
 LUA_FUNCTION_DEF (fann, is_enabled);
 LUA_FUNCTION_DEF (fann, create);
+LUA_FUNCTION_DEF (fann, create_full);
 LUA_FUNCTION_DEF (fann, load_file);
 LUA_FUNCTION_DEF (fann, load_data);
 
@@ -52,6 +53,7 @@ LUA_FUNCTION_DEF (fann, dtor);
 static const struct luaL_reg fannlib_f[] = {
                LUA_INTERFACE_DEF (fann, is_enabled),
                LUA_INTERFACE_DEF (fann, create),
+               LUA_INTERFACE_DEF (fann, create_full),
                LUA_INTERFACE_DEF (fann, load_file),
                {"load", lua_fann_load_file},
                LUA_INTERFACE_DEF (fann, load_data),
@@ -147,6 +149,8 @@ lua_fann_create (lua_State *L)
                else {
                        lua_pushnil (L);
                }
+
+               g_free (layers);
        }
        else {
                lua_pushnil (L);
@@ -156,6 +160,153 @@ lua_fann_create (lua_State *L)
 #endif
 }
 
+#ifdef WITH_FANN
+static enum fann_activationfunc_enum
+string_to_activation_func (const gchar *str)
+{
+       if (str == NULL) {
+               return FANN_SIGMOID_SYMMETRIC;
+       }
+       if (strcmp (str, "sigmoid") == 0) {
+               return FANN_SIGMOID;
+       }
+       else if (strcmp (str, "elliot") == 0) {
+               return FANN_ELLIOT;
+       }
+       else if (strcmp (str, "elliot_symmetric") == 0) {
+               return FANN_ELLIOT_SYMMETRIC;
+       }
+       else if (strcmp (str, "linear") == 0) {
+               return FANN_LINEAR;
+       }
+
+       return FANN_SIGMOID_SYMMETRIC;
+}
+
+static enum fann_train_enum
+string_to_learn_alg (const gchar *str)
+{
+       if (str == NULL) {
+               return FANN_TRAIN_INCREMENTAL;
+       }
+       if (strcmp (str, "rprop") == 0) {
+               return FANN_TRAIN_RPROP;
+       }
+       else if (strcmp (str, "qprop") == 0) {
+               return FANN_TRAIN_QUICKPROP;
+       }
+       else if (strcmp (str, "sarprop") == 0) {
+               return FANN_TRAIN_SARPROP;
+       }
+       else if (strcmp (str, "batch") == 0) {
+               return FANN_TRAIN_BATCH;
+       }
+
+       return FANN_TRAIN_INCREMENTAL;
+}
+#endif
+
+/***
+ * @function rspamd_fann.create_full(params)
+ * Creates new neural network with parameters:
+ * - `layers` {table/numbers}: table of layers in form: {N1, N2, N3 ... Nn} where N is number of neurons in a layer
+ * - `activation_hidden` {string}: activation function type for hidden layers (`tanh` by default)
+ * - `activation_output` {string}: activation function type for output layer (`tanh` by default)
+ * - `sparsed` {float}: create sparsed ANN, where number is a coefficient for sparsing
+ * - `learn` {string}: learning algorithm (quickprop, rprop or incremental)
+ * - `randomize` {boolean}: randomize weights (true by default)
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_create_full (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f, **pfann;
+       guint nlayers, *layers, i;
+       const gchar *activation_hidden = NULL, *activation_output, *learn_alg = NULL;
+       gdouble sparsed = 0.0;
+       gboolean randomize_ann = TRUE;
+       GError *err = NULL;
+
+       if (lua_type (L, 1) == LUA_TTABLE) {
+               lua_pushstring (L, "layers");
+               lua_gettable (L, 1);
+
+               if (lua_type (L, -1) != LUA_TTABLE) {
+                       return luaL_error (L, "bad layers attribute");
+               }
+
+               nlayers = rspamd_lua_table_size (L, -1);
+               if (nlayers < 2) {
+                       return luaL_error (L, "bad layers attribute");
+               }
+
+               layers = g_new0 (guint, nlayers);
+
+               for (i = 0; i < nlayers; i ++) {
+                       lua_rawgeti (L, -1, i + 1);
+                       layers[i] = luaL_checknumber (L, -1);
+                       lua_pop (L, 1);
+               }
+
+               lua_pop (L, 1); /* Table */
+
+               if (!rspamd_lua_parse_table_arguments (L, 1, &err,
+                               "sparsed=N;randomize=B;learn=S;activation_hidden=S;activation_output=S",
+                               &sparsed, &randomize_ann, &learn_alg, &activation_hidden, &activation_output)) {
+                       g_free (layers);
+
+                       if (err) {
+                               gint r;
+
+                               r = luaL_error (L, "invalid arguments: %s", err->message);
+                               g_error_free (err);
+                               return r;
+                       }
+                       else {
+                               return luaL_error (L, "invalid arguments");
+                       }
+               }
+
+               if (sparsed != 0.0) {
+                       f = fann_create_standard_array (nlayers, layers);
+               }
+               else {
+                       f = fann_create_sparse_array (sparsed, nlayers, layers);
+               }
+
+               if (f != NULL) {
+                       pfann = lua_newuserdata (L, sizeof (gpointer));
+                       *pfann = f;
+                       rspamd_lua_setclass (L, "rspamd{fann}", -1);
+               }
+               else {
+                       g_free (layers);
+                       return luaL_error (L, "cannot create fann");
+               }
+
+               fann_set_activation_function_hidden (f,
+                               string_to_activation_func (activation_hidden));
+               fann_set_activation_function_output (f,
+                               string_to_activation_func (activation_output));
+               fann_set_training_algorithm (f, string_to_learn_alg (learn_alg));
+
+               if (randomize_ann) {
+                       fann_randomize_weights (f, 0, 1);
+               }
+
+               g_free (layers);
+       }
+       else {
+               return luaL_error (L, "bad arguments");
+       }
+
+       return 1;
+#endif
+}
+
 /***
  * @function rspamd_fann.load(file)
  * Loads neural network from the file