aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_fann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-03 18:45:56 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-04 14:13:07 +0000
commit120575cd75e899da79b13fadb11e03f7e18f39b3 (patch)
treeac0bae55796c235cbe420b2d28d36084c3ffa20e /src/lua/lua_fann.c
parentb096690fbadc95e7d0c7a48d422cdbd0268e2c62 (diff)
downloadrspamd-120575cd75e899da79b13fadb11e03f7e18f39b3.tar.gz
rspamd-120575cd75e899da79b13fadb11e03f7e18f39b3.zip
[Feature] Add extended version for fann creation function
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r--src/lua/lua_fann.c153
1 files changed, 152 insertions, 1 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
index 1d02fe3f0..4f4ffbecd 100644
--- a/src/lua/lua_fann.c
+++ b/src/lua/lua_fann.c
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "lua_common.h"
-
+#include "util.h"
#ifdef WITH_FANN
#include <fann.h>
#endif
@@ -33,6 +33,7 @@
*/
LUA_FUNCTION_DEF (fann, is_enabled);
LUA_FUNCTION_DEF (fann, create);
+LUA_FUNCTION_DEF (fann, create_full);
LUA_FUNCTION_DEF (fann, load_file);
LUA_FUNCTION_DEF (fann, load_data);
@@ -52,6 +53,7 @@ LUA_FUNCTION_DEF (fann, dtor);
static const struct luaL_reg fannlib_f[] = {
LUA_INTERFACE_DEF (fann, is_enabled),
LUA_INTERFACE_DEF (fann, create),
+ LUA_INTERFACE_DEF (fann, create_full),
LUA_INTERFACE_DEF (fann, load_file),
{"load", lua_fann_load_file},
LUA_INTERFACE_DEF (fann, load_data),
@@ -147,6 +149,8 @@ lua_fann_create (lua_State *L)
else {
lua_pushnil (L);
}
+
+ g_free (layers);
}
else {
lua_pushnil (L);
@@ -156,6 +160,153 @@ lua_fann_create (lua_State *L)
#endif
}
+#ifdef WITH_FANN
+static enum fann_activationfunc_enum
+string_to_activation_func (const gchar *str)
+{
+ if (str == NULL) {
+ return FANN_SIGMOID_SYMMETRIC;
+ }
+ if (strcmp (str, "sigmoid") == 0) {
+ return FANN_SIGMOID;
+ }
+ else if (strcmp (str, "elliot") == 0) {
+ return FANN_ELLIOT;
+ }
+ else if (strcmp (str, "elliot_symmetric") == 0) {
+ return FANN_ELLIOT_SYMMETRIC;
+ }
+ else if (strcmp (str, "linear") == 0) {
+ return FANN_LINEAR;
+ }
+
+ return FANN_SIGMOID_SYMMETRIC;
+}
+
+static enum fann_train_enum
+string_to_learn_alg (const gchar *str)
+{
+ if (str == NULL) {
+ return FANN_TRAIN_INCREMENTAL;
+ }
+ if (strcmp (str, "rprop") == 0) {
+ return FANN_TRAIN_RPROP;
+ }
+ else if (strcmp (str, "qprop") == 0) {
+ return FANN_TRAIN_QUICKPROP;
+ }
+ else if (strcmp (str, "sarprop") == 0) {
+ return FANN_TRAIN_SARPROP;
+ }
+ else if (strcmp (str, "batch") == 0) {
+ return FANN_TRAIN_BATCH;
+ }
+
+ return FANN_TRAIN_INCREMENTAL;
+}
+#endif
+
+/***
+ * @function rspamd_fann.create_full(params)
+ * Creates new neural network with parameters:
+ * - `layers` {table/numbers}: table of layers in form: {N1, N2, N3 ... Nn} where N is number of neurons in a layer
+ * - `activation_hidden` {string}: activation function type for hidden layers (`tanh` by default)
+ * - `activation_output` {string}: activation function type for output layer (`tanh` by default)
+ * - `sparsed` {float}: create sparsed ANN, where number is a coefficient for sparsing
+ * - `learn` {string}: learning algorithm (quickprop, rprop or incremental)
+ * - `randomize` {boolean}: randomize weights (true by default)
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_create_full (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ guint nlayers, *layers, i;
+ const gchar *activation_hidden = NULL, *activation_output, *learn_alg = NULL;
+ gdouble sparsed = 0.0;
+ gboolean randomize_ann = TRUE;
+ GError *err = NULL;
+
+ if (lua_type (L, 1) == LUA_TTABLE) {
+ lua_pushstring (L, "layers");
+ lua_gettable (L, 1);
+
+ if (lua_type (L, -1) != LUA_TTABLE) {
+ return luaL_error (L, "bad layers attribute");
+ }
+
+ nlayers = rspamd_lua_table_size (L, -1);
+ if (nlayers < 2) {
+ return luaL_error (L, "bad layers attribute");
+ }
+
+ layers = g_new0 (guint, nlayers);
+
+ for (i = 0; i < nlayers; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ layers[i] = luaL_checknumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1); /* Table */
+
+ if (!rspamd_lua_parse_table_arguments (L, 1, &err,
+ "sparsed=N;randomize=B;learn=S;activation_hidden=S;activation_output=S",
+ &sparsed, &randomize_ann, &learn_alg, &activation_hidden, &activation_output)) {
+ g_free (layers);
+
+ if (err) {
+ gint r;
+
+ r = luaL_error (L, "invalid arguments: %s", err->message);
+ g_error_free (err);
+ return r;
+ }
+ else {
+ return luaL_error (L, "invalid arguments");
+ }
+ }
+
+ if (sparsed != 0.0) {
+ f = fann_create_standard_array (nlayers, layers);
+ }
+ else {
+ f = fann_create_sparse_array (sparsed, nlayers, layers);
+ }
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ g_free (layers);
+ return luaL_error (L, "cannot create fann");
+ }
+
+ fann_set_activation_function_hidden (f,
+ string_to_activation_func (activation_hidden));
+ fann_set_activation_function_output (f,
+ string_to_activation_func (activation_output));
+ fann_set_training_algorithm (f, string_to_learn_alg (learn_alg));
+
+ if (randomize_ann) {
+ fann_randomize_weights (f, 0, 1);
+ }
+
+ g_free (layers);
+ }
+ else {
+ return luaL_error (L, "bad arguments");
+ }
+
+ return 1;
+#endif
+}
+
/***
* @function rspamd_fann.load(file)
* Loads neural network from the file