]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Slightly fix ANN routines
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 16:37:26 +0000 (17:37 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 16:37:26 +0000 (17:37 +0100)
src/lua/lua_fann.c

index 51d42b573af987107d6fbd7a0654f7cfe4a79428..d1f187510857fbe5ee83917d4bd1a68646fa48ac 100644 (file)
@@ -132,6 +132,8 @@ lua_fann_create (lua_State *L)
                f = fann_create_standard_array (nlayers, layers);
                fann_set_activation_function_hidden (f, FANN_SIGMOID_SYMMETRIC);
                fann_set_activation_function_output (f, FANN_SIGMOID_SYMMETRIC);
+               fann_set_training_algorithm (f, FANN_TRAIN_INCREMENTAL);
+               fann_randomize_weights (f, 0, 1);
 
                if (f != NULL) {
                        pfann = lua_newuserdata (L, sizeof (gpointer));
@@ -334,7 +336,7 @@ lua_fann_train (lua_State *L)
 #else
        struct fann *f = rspamd_lua_check_fann (L, 1);
        guint ninputs, noutputs, j;
-       float *cur_input, *cur_output;
+       fann_type *cur_input, *cur_output;
        gboolean ret = FALSE;
 
        if (f != NULL) {
@@ -350,7 +352,7 @@ lua_fann_train (lua_State *L)
                                        noutputs, fann_get_num_output (f));
                }
                else {
-                       cur_input = g_malloc (ninputs * sizeof (gint));
+                       cur_input = g_malloc (ninputs * sizeof (fann_type));
 
                        for (j = 0; j < ninputs; j ++) {
                                lua_rawgeti (L, 2, j + 1);
@@ -358,7 +360,7 @@ lua_fann_train (lua_State *L)
                                lua_pop (L, 1);
                        }
 
-                       cur_output = g_malloc (noutputs * sizeof (gint));
+                       cur_output = g_malloc (noutputs * sizeof (fann_type));
 
                        for (j = 0; j < noutputs; j++) {
                                lua_rawgeti (L, 3, j + 1);
@@ -396,7 +398,7 @@ lua_fann_test (lua_State *L)
 #else
        struct fann *f = rspamd_lua_check_fann (L, 1);
        guint ninputs, noutputs, i, tbl_idx = 2;
-       float *cur_input, *cur_output;
+       fann_type *cur_input, *cur_output;
 
        if (f != NULL) {
                /* First check sanity, call for table.getn for that */
@@ -415,7 +417,7 @@ lua_fann_test (lua_State *L)
                        }
                }
 
-               cur_input = g_slice_alloc (ninputs * sizeof (gint));
+               cur_input = g_slice_alloc (ninputs * sizeof (fann_type));
 
                for (i = 0; i < ninputs; i++) {
                        lua_rawgeti (L, tbl_idx, i + 1);
@@ -432,7 +434,7 @@ lua_fann_test (lua_State *L)
                        lua_rawseti (L, -2, i + 1);
                }
 
-               g_slice_free1 (ninputs * sizeof (gint), cur_input);
+               g_slice_free1 (ninputs * sizeof (fann_type), cur_input);
        }
        else {
                lua_pushnil (L);