From: Vsevolod Stakhov Date: Wed, 6 Apr 2016 13:21:39 +0000 (+0100) Subject: [Fix] Fix fann train X-Git-Tag: 1.2.3~44 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=04cd8e84c78d3341f050fa5888c90d8edda450b0;p=rspamd.git [Fix] Fix fann train --- diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index b2bd5ba86..af6d0b590 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -175,9 +175,8 @@ lua_fann_load (lua_State *L) * Trains neural network with samples. Inputs and outputs should be tables of * equal size, each row in table should be N inputs and M outputs, e.g. * {0, 1, 1} -> {0} - * {1, 0, 0} -> {1} - * @param {table/table} inputs input samples - * @param {table/table} outputs output samples + * @param {table} inputs input samples + * @param {table} outputs output samples * @return {number} number of samples learned */ static gint @@ -187,79 +186,48 @@ lua_fann_train (lua_State *L) return 0; #else struct fann *f = rspamd_lua_check_fann (L, 1); - guint ninputs, noutputs, i, j, cur_len; + guint ninputs, noutputs, j; float *cur_input, *cur_output; - gint ret = 0; + gboolean ret = FALSE; if (f != NULL) { /* First check sanity, call for table.getn for that */ ninputs = rspamd_lua_table_size (L, 2); noutputs = rspamd_lua_table_size (L, 3); - if (ninputs != noutputs) { - msg_err ("bad number of inputs(%d) and output(%d) args for train", - ninputs, noutputs); + if (ninputs != fann_get_num_input (f) || + noutputs != fann_get_num_output (f)) { + msg_err ("bad number of inputs(%d, expected %d) and " + "output(%d, expected %d) args for train", + ninputs, fann_get_num_input (f), + noutputs, fann_get_num_output (f)); } else { - for (i = 0; i < ninputs; i ++) { - /* Push table with inputs */ - lua_rawgeti (L, 2, i + 1); - - cur_len = rspamd_lua_table_size (L, -1); - - if (cur_len != fann_get_num_input (f)) { - msg_err ( - "bad number of input samples: %d, %d expected", - cur_len, - fann_get_num_input (f)); - lua_pop (L, 1); - continue; - } - - cur_input = g_malloc (cur_len * sizeof (gint)); - - for (j = 0; j < cur_len; j ++) { - lua_rawgeti (L, -1, j + 1); - cur_input[i] = lua_tonumber (L, -1); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Inputs table */ - - /* Push table with outputs */ - lua_rawgeti (L, 3, i + 1); - - cur_len = rspamd_lua_table_size (L, -1); - - if (cur_len != fann_get_num_output (f)) { - msg_err ( - "bad number of output samples: %d, %d expected", - cur_len, - fann_get_num_output (f)); - lua_pop (L, 1); - g_free (cur_input); - continue; - } - - cur_output = g_malloc (cur_len * sizeof (gint)); - - for (j = 0; j < cur_len; j++) { - lua_rawgeti (L, -1, j + 1); - cur_output[i] = lua_tonumber (L, -1); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Outputs table */ - - fann_train (f, cur_input, cur_output); - g_free (cur_input); - g_free (cur_output); - ret ++; + cur_input = g_malloc (ninputs * sizeof (gint)); + + for (j = 0; j < ninputs; j ++) { + lua_rawgeti (L, 2, j + 1); + cur_input[j] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + cur_output = g_malloc (noutputs * sizeof (gint)); + + for (j = 0; j < noutputs; j++) { + lua_rawgeti (L, 3, j + 1); + cur_output[j] = lua_tonumber (L, -1); + lua_pop (L, 1); } + + fann_train (f, cur_input, cur_output); + g_free (cur_input); + g_free (cur_output); + + ret = TRUE; } } - lua_pushnumber (L, ret); + lua_pushboolean (L, ret); return 1; #endif