diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-04-06 14:21:39 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-04-06 14:21:39 +0100 |
commit | 04cd8e84c78d3341f050fa5888c90d8edda450b0 (patch) | |
tree | 13e2fa46419fccd54511dc857919e8832381e066 /src/lua/lua_fann.c | |
parent | 2451cbc744963740e08e6c37180bae65ff40e89d (diff) | |
download | rspamd-04cd8e84c78d3341f050fa5888c90d8edda450b0.tar.gz rspamd-04cd8e84c78d3341f050fa5888c90d8edda450b0.zip |
[Fix] Fix fann train
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r-- | src/lua/lua_fann.c | 94 |
1 files changed, 31 insertions, 63 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index b2bd5ba86..af6d0b590 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -175,9 +175,8 @@ lua_fann_load (lua_State *L) * Trains neural network with samples. Inputs and outputs should be tables of * equal size, each row in table should be N inputs and M outputs, e.g. * {0, 1, 1} -> {0} - * {1, 0, 0} -> {1} - * @param {table/table} inputs input samples - * @param {table/table} outputs output samples + * @param {table} inputs input samples + * @param {table} outputs output samples * @return {number} number of samples learned */ static gint @@ -187,79 +186,48 @@ lua_fann_train (lua_State *L) return 0; #else struct fann *f = rspamd_lua_check_fann (L, 1); - guint ninputs, noutputs, i, j, cur_len; + guint ninputs, noutputs, j; float *cur_input, *cur_output; - gint ret = 0; + gboolean ret = FALSE; if (f != NULL) { /* First check sanity, call for table.getn for that */ ninputs = rspamd_lua_table_size (L, 2); noutputs = rspamd_lua_table_size (L, 3); - if (ninputs != noutputs) { - msg_err ("bad number of inputs(%d) and output(%d) args for train", - ninputs, noutputs); + if (ninputs != fann_get_num_input (f) || + noutputs != fann_get_num_output (f)) { + msg_err ("bad number of inputs(%d, expected %d) and " + "output(%d, expected %d) args for train", + ninputs, fann_get_num_input (f), + noutputs, fann_get_num_output (f)); } else { - for (i = 0; i < ninputs; i ++) { - /* Push table with inputs */ - lua_rawgeti (L, 2, i + 1); - - cur_len = rspamd_lua_table_size (L, -1); - - if (cur_len != fann_get_num_input (f)) { - msg_err ( - "bad number of input samples: %d, %d expected", - cur_len, - fann_get_num_input (f)); - lua_pop (L, 1); - continue; - } - - cur_input = g_malloc (cur_len * sizeof (gint)); - - for (j = 0; j < cur_len; j ++) { - lua_rawgeti (L, -1, j + 1); - cur_input[i] = lua_tonumber (L, -1); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Inputs table */ - - /* Push table with outputs */ - lua_rawgeti (L, 3, i + 1); - - cur_len = rspamd_lua_table_size (L, -1); - - if (cur_len != fann_get_num_output (f)) { - msg_err ( - "bad number of output samples: %d, %d expected", - cur_len, - fann_get_num_output (f)); - lua_pop (L, 1); - g_free (cur_input); - continue; - } - - cur_output = g_malloc (cur_len * sizeof (gint)); - - for (j = 0; j < cur_len; j++) { - lua_rawgeti (L, -1, j + 1); - cur_output[i] = lua_tonumber (L, -1); - lua_pop (L, 1); - } - - lua_pop (L, 1); /* Outputs table */ - - fann_train (f, cur_input, cur_output); - g_free (cur_input); - g_free (cur_output); - ret ++; + cur_input = g_malloc (ninputs * sizeof (gint)); + + for (j = 0; j < ninputs; j ++) { + lua_rawgeti (L, 2, j + 1); + cur_input[j] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + cur_output = g_malloc (noutputs * sizeof (gint)); + + for (j = 0; j < noutputs; j++) { + lua_rawgeti (L, 3, j + 1); + cur_output[j] = lua_tonumber (L, -1); + lua_pop (L, 1); } + + fann_train (f, cur_input, cur_output); + g_free (cur_input); + g_free (cur_output); + + ret = TRUE; } } - lua_pushnumber (L, ret); + lua_pushboolean (L, ret); return 1; #endif |