]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix fann train
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Apr 2016 13:21:39 +0000 (14:21 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Apr 2016 13:21:39 +0000 (14:21 +0100)
src/lua/lua_fann.c

index b2bd5ba86d123c8f04e18e2dcf0bb626dfe71dd4..af6d0b5903a309c5f361af6025f1ad446e173610 100644 (file)
@@ -175,9 +175,8 @@ lua_fann_load (lua_State *L)
  * Trains neural network with samples. Inputs and outputs should be tables of
  * equal size, each row in table should be N inputs and M outputs, e.g.
  *     {0, 1, 1} -> {0}
- *     {1, 0, 0} -> {1}
- * @param {table/table} inputs input samples
- * @param {table/table} outputs output samples
+ * @param {table} inputs input samples
+ * @param {table} outputs output samples
  * @return {number} number of samples learned
  */
 static gint
@@ -187,79 +186,48 @@ lua_fann_train (lua_State *L)
        return 0;
 #else
        struct fann *f = rspamd_lua_check_fann (L, 1);
-       guint ninputs, noutputs, i, j, cur_len;
+       guint ninputs, noutputs, j;
        float *cur_input, *cur_output;
-       gint ret = 0;
+       gboolean ret = FALSE;
 
        if (f != NULL) {
                /* First check sanity, call for table.getn for that */
                ninputs = rspamd_lua_table_size (L, 2);
                noutputs = rspamd_lua_table_size (L, 3);
 
-               if (ninputs != noutputs) {
-                       msg_err ("bad number of inputs(%d) and output(%d) args for train",
-                                       ninputs, noutputs);
+               if (ninputs != fann_get_num_input (f) ||
+                       noutputs != fann_get_num_output (f)) {
+                       msg_err ("bad number of inputs(%d, expected %d) and "
+                                       "output(%d, expected %d) args for train",
+                                       ninputs, fann_get_num_input (f),
+                                       noutputs, fann_get_num_output (f));
                }
                else {
-                       for (i = 0; i < ninputs; i ++) {
-                               /* Push table with inputs */
-                               lua_rawgeti (L, 2, i + 1);
-
-                               cur_len = rspamd_lua_table_size (L, -1);
-
-                               if (cur_len != fann_get_num_input (f)) {
-                                       msg_err (
-                                                       "bad number of input samples: %d, %d expected",
-                                                       cur_len,
-                                                       fann_get_num_input (f));
-                                       lua_pop (L, 1);
-                                       continue;
-                               }
-
-                               cur_input = g_malloc (cur_len * sizeof (gint));
-
-                               for (j = 0; j < cur_len; j ++) {
-                                       lua_rawgeti (L, -1, j + 1);
-                                       cur_input[i] = lua_tonumber (L, -1);
-                                       lua_pop (L, 1);
-                               }
-
-                               lua_pop (L, 1); /* Inputs table */
-
-                               /* Push table with outputs */
-                               lua_rawgeti (L, 3, i + 1);
-
-                               cur_len = rspamd_lua_table_size (L, -1);
-
-                               if (cur_len != fann_get_num_output (f)) {
-                                       msg_err (
-                                                       "bad number of output samples: %d, %d expected",
-                                                       cur_len,
-                                                       fann_get_num_output (f));
-                                       lua_pop (L, 1);
-                                       g_free (cur_input);
-                                       continue;
-                               }
-
-                               cur_output = g_malloc (cur_len * sizeof (gint));
-
-                               for (j = 0; j < cur_len; j++) {
-                                       lua_rawgeti (L, -1, j + 1);
-                                       cur_output[i] = lua_tonumber (L, -1);
-                                       lua_pop (L, 1);
-                               }
-
-                               lua_pop (L, 1); /* Outputs table */
-
-                               fann_train (f, cur_input, cur_output);
-                               g_free (cur_input);
-                               g_free (cur_output);
-                               ret ++;
+                       cur_input = g_malloc (ninputs * sizeof (gint));
+
+                       for (j = 0; j < ninputs; j ++) {
+                               lua_rawgeti (L, 2, j + 1);
+                               cur_input[j] = lua_tonumber (L, -1);
+                               lua_pop (L, 1);
+                       }
+
+                       cur_output = g_malloc (noutputs * sizeof (gint));
+
+                       for (j = 0; j < noutputs; j++) {
+                               lua_rawgeti (L, 3, j + 1);
+                               cur_output[j] = lua_tonumber (L, -1);
+                               lua_pop (L, 1);
                        }
+
+                       fann_train (f, cur_input, cur_output);
+                       g_free (cur_input);
+                       g_free (cur_output);
+
+                       ret = TRUE;
                }
        }
 
-       lua_pushnumber (L, ret);
+       lua_pushboolean (L, ret);
 
        return 1;
 #endif