summaryrefslogtreecommitdiffstats
path: root/src/lua/lua_fann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-04-06 14:21:39 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-04-06 14:21:39 +0100
commit04cd8e84c78d3341f050fa5888c90d8edda450b0 (patch)
tree13e2fa46419fccd54511dc857919e8832381e066 /src/lua/lua_fann.c
parent2451cbc744963740e08e6c37180bae65ff40e89d (diff)
downloadrspamd-04cd8e84c78d3341f050fa5888c90d8edda450b0.tar.gz
rspamd-04cd8e84c78d3341f050fa5888c90d8edda450b0.zip
[Fix] Fix fann train
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r--src/lua/lua_fann.c94
1 files changed, 31 insertions, 63 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
index b2bd5ba86..af6d0b590 100644
--- a/src/lua/lua_fann.c
+++ b/src/lua/lua_fann.c
@@ -175,9 +175,8 @@ lua_fann_load (lua_State *L)
* Trains neural network with samples. Inputs and outputs should be tables of
* equal size, each row in table should be N inputs and M outputs, e.g.
* {0, 1, 1} -> {0}
- * {1, 0, 0} -> {1}
- * @param {table/table} inputs input samples
- * @param {table/table} outputs output samples
+ * @param {table} inputs input samples
+ * @param {table} outputs output samples
* @return {number} number of samples learned
*/
static gint
@@ -187,79 +186,48 @@ lua_fann_train (lua_State *L)
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);
- guint ninputs, noutputs, i, j, cur_len;
+ guint ninputs, noutputs, j;
float *cur_input, *cur_output;
- gint ret = 0;
+ gboolean ret = FALSE;
if (f != NULL) {
/* First check sanity, call for table.getn for that */
ninputs = rspamd_lua_table_size (L, 2);
noutputs = rspamd_lua_table_size (L, 3);
- if (ninputs != noutputs) {
- msg_err ("bad number of inputs(%d) and output(%d) args for train",
- ninputs, noutputs);
+ if (ninputs != fann_get_num_input (f) ||
+ noutputs != fann_get_num_output (f)) {
+ msg_err ("bad number of inputs(%d, expected %d) and "
+ "output(%d, expected %d) args for train",
+ ninputs, fann_get_num_input (f),
+ noutputs, fann_get_num_output (f));
}
else {
- for (i = 0; i < ninputs; i ++) {
- /* Push table with inputs */
- lua_rawgeti (L, 2, i + 1);
-
- cur_len = rspamd_lua_table_size (L, -1);
-
- if (cur_len != fann_get_num_input (f)) {
- msg_err (
- "bad number of input samples: %d, %d expected",
- cur_len,
- fann_get_num_input (f));
- lua_pop (L, 1);
- continue;
- }
-
- cur_input = g_malloc (cur_len * sizeof (gint));
-
- for (j = 0; j < cur_len; j ++) {
- lua_rawgeti (L, -1, j + 1);
- cur_input[i] = lua_tonumber (L, -1);
- lua_pop (L, 1);
- }
-
- lua_pop (L, 1); /* Inputs table */
-
- /* Push table with outputs */
- lua_rawgeti (L, 3, i + 1);
-
- cur_len = rspamd_lua_table_size (L, -1);
-
- if (cur_len != fann_get_num_output (f)) {
- msg_err (
- "bad number of output samples: %d, %d expected",
- cur_len,
- fann_get_num_output (f));
- lua_pop (L, 1);
- g_free (cur_input);
- continue;
- }
-
- cur_output = g_malloc (cur_len * sizeof (gint));
-
- for (j = 0; j < cur_len; j++) {
- lua_rawgeti (L, -1, j + 1);
- cur_output[i] = lua_tonumber (L, -1);
- lua_pop (L, 1);
- }
-
- lua_pop (L, 1); /* Outputs table */
-
- fann_train (f, cur_input, cur_output);
- g_free (cur_input);
- g_free (cur_output);
- ret ++;
+ cur_input = g_malloc (ninputs * sizeof (gint));
+
+ for (j = 0; j < ninputs; j ++) {
+ lua_rawgeti (L, 2, j + 1);
+ cur_input[j] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ cur_output = g_malloc (noutputs * sizeof (gint));
+
+ for (j = 0; j < noutputs; j++) {
+ lua_rawgeti (L, 3, j + 1);
+ cur_output[j] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
}
+
+ fann_train (f, cur_input, cur_output);
+ g_free (cur_input);
+ g_free (cur_output);
+
+ ret = TRUE;
}
}
- lua_pushnumber (L, ret);
+ lua_pushboolean (L, ret);
return 1;
#endif