From 120575cd75e899da79b13fadb11e03f7e18f39b3 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 3 Nov 2016 18:45:56 +0000 Subject: [PATCH] [Feature] Add extended version for fann creation function --- src/lua/lua_fann.c | 153 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 152 insertions(+), 1 deletion(-) diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index 1d02fe3f0..4f4ffbecd 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -14,7 +14,7 @@ * limitations under the License. */ #include "lua_common.h" - +#include "util.h" #ifdef WITH_FANN #include #endif @@ -33,6 +33,7 @@ */ LUA_FUNCTION_DEF (fann, is_enabled); LUA_FUNCTION_DEF (fann, create); +LUA_FUNCTION_DEF (fann, create_full); LUA_FUNCTION_DEF (fann, load_file); LUA_FUNCTION_DEF (fann, load_data); @@ -52,6 +53,7 @@ LUA_FUNCTION_DEF (fann, dtor); static const struct luaL_reg fannlib_f[] = { LUA_INTERFACE_DEF (fann, is_enabled), LUA_INTERFACE_DEF (fann, create), + LUA_INTERFACE_DEF (fann, create_full), LUA_INTERFACE_DEF (fann, load_file), {"load", lua_fann_load_file}, LUA_INTERFACE_DEF (fann, load_data), @@ -147,6 +149,8 @@ lua_fann_create (lua_State *L) else { lua_pushnil (L); } + + g_free (layers); } else { lua_pushnil (L); @@ -156,6 +160,153 @@ lua_fann_create (lua_State *L) #endif } +#ifdef WITH_FANN +static enum fann_activationfunc_enum +string_to_activation_func (const gchar *str) +{ + if (str == NULL) { + return FANN_SIGMOID_SYMMETRIC; + } + if (strcmp (str, "sigmoid") == 0) { + return FANN_SIGMOID; + } + else if (strcmp (str, "elliot") == 0) { + return FANN_ELLIOT; + } + else if (strcmp (str, "elliot_symmetric") == 0) { + return FANN_ELLIOT_SYMMETRIC; + } + else if (strcmp (str, "linear") == 0) { + return FANN_LINEAR; + } + + return FANN_SIGMOID_SYMMETRIC; +} + +static enum fann_train_enum +string_to_learn_alg (const gchar *str) +{ + if (str == NULL) { + return FANN_TRAIN_INCREMENTAL; + } + if (strcmp (str, "rprop") == 0) { + return FANN_TRAIN_RPROP; + } + else if (strcmp (str, "qprop") == 0) { + return FANN_TRAIN_QUICKPROP; + } + else if (strcmp (str, "sarprop") == 0) { + return FANN_TRAIN_SARPROP; + } + else if (strcmp (str, "batch") == 0) { + return FANN_TRAIN_BATCH; + } + + return FANN_TRAIN_INCREMENTAL; +} +#endif + +/*** + * @function rspamd_fann.create_full(params) + * Creates new neural network with parameters: + * - `layers` {table/numbers}: table of layers in form: {N1, N2, N3 ... Nn} where N is number of neurons in a layer + * - `activation_hidden` {string}: activation function type for hidden layers (`tanh` by default) + * - `activation_output` {string}: activation function type for output layer (`tanh` by default) + * - `sparsed` {float}: create sparsed ANN, where number is a coefficient for sparsing + * - `learn` {string}: learning algorithm (quickprop, rprop or incremental) + * - `randomize` {boolean}: randomize weights (true by default) + * @return {fann} fann object + */ +static gint +lua_fann_create_full (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + guint nlayers, *layers, i; + const gchar *activation_hidden = NULL, *activation_output, *learn_alg = NULL; + gdouble sparsed = 0.0; + gboolean randomize_ann = TRUE; + GError *err = NULL; + + if (lua_type (L, 1) == LUA_TTABLE) { + lua_pushstring (L, "layers"); + lua_gettable (L, 1); + + if (lua_type (L, -1) != LUA_TTABLE) { + return luaL_error (L, "bad layers attribute"); + } + + nlayers = rspamd_lua_table_size (L, -1); + if (nlayers < 2) { + return luaL_error (L, "bad layers attribute"); + } + + layers = g_new0 (guint, nlayers); + + for (i = 0; i < nlayers; i ++) { + lua_rawgeti (L, -1, i + 1); + layers[i] = luaL_checknumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); /* Table */ + + if (!rspamd_lua_parse_table_arguments (L, 1, &err, + "sparsed=N;randomize=B;learn=S;activation_hidden=S;activation_output=S", + &sparsed, &randomize_ann, &learn_alg, &activation_hidden, &activation_output)) { + g_free (layers); + + if (err) { + gint r; + + r = luaL_error (L, "invalid arguments: %s", err->message); + g_error_free (err); + return r; + } + else { + return luaL_error (L, "invalid arguments"); + } + } + + if (sparsed != 0.0) { + f = fann_create_standard_array (nlayers, layers); + } + else { + f = fann_create_sparse_array (sparsed, nlayers, layers); + } + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + g_free (layers); + return luaL_error (L, "cannot create fann"); + } + + fann_set_activation_function_hidden (f, + string_to_activation_func (activation_hidden)); + fann_set_activation_function_output (f, + string_to_activation_func (activation_output)); + fann_set_training_algorithm (f, string_to_learn_alg (learn_alg)); + + if (randomize_ann) { + fann_randomize_weights (f, 0, 1); + } + + g_free (layers); + } + else { + return luaL_error (L, "bad arguments"); + } + + return 1; +#endif +} + /*** * @function rspamd_fann.load(file) * Loads neural network from the file -- 2.39.5