From de306aff80b8177cce0c8c6748f56ed443039186 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 10 Oct 2016 17:37:26 +0100 Subject: [PATCH] [Fix] Slightly fix ANN routines --- src/lua/lua_fann.c | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index 51d42b573..d1f187510 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -132,6 +132,8 @@ lua_fann_create (lua_State *L) f = fann_create_standard_array (nlayers, layers); fann_set_activation_function_hidden (f, FANN_SIGMOID_SYMMETRIC); fann_set_activation_function_output (f, FANN_SIGMOID_SYMMETRIC); + fann_set_training_algorithm (f, FANN_TRAIN_INCREMENTAL); + fann_randomize_weights (f, 0, 1); if (f != NULL) { pfann = lua_newuserdata (L, sizeof (gpointer)); @@ -334,7 +336,7 @@ lua_fann_train (lua_State *L) #else struct fann *f = rspamd_lua_check_fann (L, 1); guint ninputs, noutputs, j; - float *cur_input, *cur_output; + fann_type *cur_input, *cur_output; gboolean ret = FALSE; if (f != NULL) { @@ -350,7 +352,7 @@ lua_fann_train (lua_State *L) noutputs, fann_get_num_output (f)); } else { - cur_input = g_malloc (ninputs * sizeof (gint)); + cur_input = g_malloc (ninputs * sizeof (fann_type)); for (j = 0; j < ninputs; j ++) { lua_rawgeti (L, 2, j + 1); @@ -358,7 +360,7 @@ lua_fann_train (lua_State *L) lua_pop (L, 1); } - cur_output = g_malloc (noutputs * sizeof (gint)); + cur_output = g_malloc (noutputs * sizeof (fann_type)); for (j = 0; j < noutputs; j++) { lua_rawgeti (L, 3, j + 1); @@ -396,7 +398,7 @@ lua_fann_test (lua_State *L) #else struct fann *f = rspamd_lua_check_fann (L, 1); guint ninputs, noutputs, i, tbl_idx = 2; - float *cur_input, *cur_output; + fann_type *cur_input, *cur_output; if (f != NULL) { /* First check sanity, call for table.getn for that */ @@ -415,7 +417,7 @@ lua_fann_test (lua_State *L) } } - cur_input = g_slice_alloc (ninputs * sizeof (gint)); + cur_input = g_slice_alloc (ninputs * sizeof (fann_type)); for (i = 0; i < ninputs; i++) { lua_rawgeti (L, tbl_idx, i + 1); @@ -432,7 +434,7 @@ lua_fann_test (lua_State *L) lua_rawseti (L, -2, i + 1); } - g_slice_free1 (ninputs * sizeof (gint), cur_input); + g_slice_free1 (ninputs * sizeof (fann_type), cur_input); } else { lua_pushnil (L); -- 2.39.5