]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Implement FANN threaded learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 14:12:42 +0000 (14:12 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 4 Nov 2016 14:13:07 +0000 (14:13 +0000)
src/libutil/util.c
src/lua/lua_fann.c

index f785f05754ccff7ec551e2b1f53d1bc26d0f010e..2538fb57ad8f8ea541d24c82130bd03e10f38990 100644 (file)
@@ -1662,7 +1662,7 @@ rspamd_create_thread (const gchar *name,
        td->data = data;
 
        rspamd_snprintf (td->name, r + sizeof ("4294967296"), "%s-%d", name, id);
-#if ((GLIB_MAJOR_VERSION == 2) && (GLIB_MINOR_VERSION > 30))
+#if ((GLIB_MAJOR_VERSION == 2) && (GLIB_MINOR_VERSION > 32))
        new = g_thread_try_new (td->name, rspamd_thread_func, td, err);
 #else
        new = g_thread_create (rspamd_thread_func, td, TRUE, err);
index 4f4ffbecd1f18802bb67d3fdc13f3d887a1f5e5f..fa1a27ae86e25ee88975ee204fdc89d8aa1260e5 100644 (file)
@@ -41,6 +41,7 @@ LUA_FUNCTION_DEF (fann, load_data);
  * Fann methods
  */
 LUA_FUNCTION_DEF (fann, train);
+LUA_FUNCTION_DEF (fann, train_threaded);
 LUA_FUNCTION_DEF (fann, test);
 LUA_FUNCTION_DEF (fann, save);
 LUA_FUNCTION_DEF (fann, data);
@@ -62,6 +63,7 @@ static const struct luaL_reg fannlib_f[] = {
 
 static const struct luaL_reg fannlib_m[] = {
                LUA_INTERFACE_DEF (fann, train),
+               LUA_INTERFACE_DEF (fann, train_threaded),
                LUA_INTERFACE_DEF (fann, test),
                LUA_INTERFACE_DEF (fann, save),
                LUA_INTERFACE_DEF (fann, data),
@@ -537,6 +539,224 @@ lua_fann_train (lua_State *L)
 #endif
 }
 
+#ifdef WITH_FANN
+struct lua_fann_train_cbdata {
+       lua_State *L;
+       gint pair[2];
+       struct fann_train_data *train;
+       struct fann *f;
+       gint cbref;
+       gdouble desired_mse;
+       guint max_epochs;
+       GThread *t;
+       struct event io;
+};
+
+struct lua_fann_train_reply {
+       gint errcode;
+       float mse;
+       gchar errmsg[128];
+};
+
+static void
+lua_fann_push_train_result (struct lua_fann_train_cbdata *cbdata,
+               gint errcode, float mse, const gchar *errmsg)
+{
+       lua_rawgeti (cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref);
+       lua_pushnumber (cbdata->L, errcode);
+       lua_pushstring (cbdata->L, errmsg);
+       lua_pushnumber (cbdata->L, mse);
+
+       if (lua_pcall (cbdata->L, 3, 0, 0) != 0) {
+               msg_err ("call to train callback failed: %s", lua_tostring (cbdata->L, -1));
+               lua_pop (cbdata->L, 1);
+       }
+}
+
+static void
+lua_fann_thread_notify (gint fd, short what, gpointer ud)
+{
+       struct lua_fann_train_cbdata *cbdata = ud;
+       struct lua_fann_train_reply rep;
+
+       if (read (cbdata->pair[0], &rep, sizeof (rep)) == -1) {
+               if (errno == EAGAIN || errno == EINTR) {
+                       event_add (&cbdata->io, NULL);
+                       return;
+               }
+
+               lua_fann_push_train_result (cbdata, errno, 0.0, strerror (errno));
+       }
+       else {
+               lua_fann_push_train_result (cbdata, rep.errcode, rep.mse, rep.errmsg);
+       }
+
+       write (cbdata->pair[0], "", 1);
+       g_thread_join (cbdata->t);
+       close (cbdata->pair[0]);
+       close (cbdata->pair[1]);
+
+       fann_destroy_train (cbdata->train);
+       luaL_unref (cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref);
+       g_slice_free1 (sizeof (*cbdata), cbdata);
+}
+
+static void *
+lua_fann_train_thread (void *ud)
+{
+       struct lua_fann_train_cbdata *cbdata = ud;
+       struct lua_fann_train_reply rep;
+       gchar repbuf[1];
+
+       msg_info ("start learning ANN, %d epochs are possible",
+                       cbdata->max_epochs);
+       rspamd_socket_blocking (cbdata->pair[1]);
+       fann_train_on_data (cbdata->f, cbdata->train, cbdata->max_epochs, 0,
+                       cbdata->desired_mse);
+       rep.errcode = 0;
+       rspamd_strlcpy (rep.errmsg, "OK", sizeof (rep.errmsg));
+       rep.mse = fann_get_MSE (cbdata->f);
+
+       if (write (cbdata->pair[1], &rep, sizeof (rep)) == -1) {
+               msg_err ("cannot write to socketpair: %s", strerror (errno));
+
+               return NULL;
+       }
+
+       if (read (cbdata->pair[1], repbuf, sizeof (repbuf)) == -1) {
+               msg_err ("cannot read from socketpair: %s", strerror (errno));
+
+               return NULL;
+       }
+
+       return NULL;
+}
+#endif
+/**
+ * @method rspamd_fann:train_threaded(inputs, outputs, callback, event_base, {params})
+ * Trains neural network with batch of samples. Inputs and outputs should be tables of
+ * equal size, each row in table should be N inputs and M outputs, e.g.
+ *     {{0, 1, 1}, ...} -> {{0}, {1} ...}
+ * @param {table} inputs input samples
+ * @param {table} outputs output samples
+ * @param {callback} function that is called when train is completed
+ */
+static gint
+lua_fann_train_threaded (lua_State *L)
+{
+#ifndef WITH_FANN
+       return 0;
+#else
+       struct fann *f = rspamd_lua_check_fann (L, 1);
+       guint ninputs, noutputs, ndata, i, j;
+       struct lua_fann_train_cbdata *cbdata;
+       struct event_base *ev_base = lua_check_ev_base (L, 5);
+       GError *err = NULL;
+       const guint max_epochs_default = 1000;
+       const gdouble desired_mse_default = 0.0001;
+
+       if (f != NULL && lua_type (L, 2) == LUA_TTABLE &&
+                       lua_type (L, 3) == LUA_TTABLE && lua_type (L, 4) == LUA_TFUNCTION &&
+                       ev_base != NULL) {
+               /* First check sanity, call for table.getn for that */
+               ndata = rspamd_lua_table_size (L, 2);
+               ninputs = fann_get_num_input (f);
+               noutputs = fann_get_num_output (f);
+               cbdata = g_slice_alloc0 (sizeof (*cbdata));
+               cbdata->L = L;
+               cbdata->f = f;
+               cbdata->train = fann_create_train (ndata, ninputs, noutputs);
+               lua_pushvalue (L, 4);
+               cbdata->cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+
+               if (rspamd_socketpair (cbdata->pair) == -1) {
+                       msg_err ("cannot open socketpair: %s", strerror (errno));
+                       cbdata->pair[0] = -1;
+                       cbdata->pair[1] = -1;
+                       goto err;
+               }
+
+               for (i = 0; i < ndata; i ++) {
+                       lua_rawgeti (L, 2, i + 1);
+
+                       if (rspamd_lua_table_size (L, -1) != ninputs) {
+                               msg_err ("invalid number of inputs: %d, %d expected",
+                                               rspamd_lua_table_size (L, -1), ninputs);
+                               goto err;
+                       }
+
+                       for (j = 0; j < ninputs; j ++) {
+                               lua_rawgeti (L, -1, j + 1);
+                               cbdata->train->input[i][j] = lua_tonumber (L, -1);
+                               lua_pop (L, 1);
+                       }
+
+                       lua_pop (L, 1);
+                       lua_rawgeti (L, 3, i + 1);
+
+                       if (rspamd_lua_table_size (L, -1) != noutputs) {
+                               msg_err ("invalid number of outputs: %d, %d expected",
+                                               rspamd_lua_table_size (L, -1), noutputs);
+                               goto err;
+                       }
+
+                       for (j = 0; j < noutputs; j++) {
+                               lua_rawgeti (L, -1, j + 1);
+                               cbdata->train->output[i][j] = lua_tonumber (L, -1);
+                               lua_pop (L, 1);
+                       }
+               }
+
+               cbdata->max_epochs = max_epochs_default;
+               cbdata->desired_mse = desired_mse_default;
+
+               if (lua_type (L, 5) == LUA_TTABLE) {
+                       rspamd_lua_parse_table_arguments (L, 5, NULL,
+                                       "max_epochs=I;desired_mse=N",
+                                       &cbdata->max_epochs, &cbdata->desired_mse);
+               }
+
+               /* Now we can call training in a separate thread */
+               rspamd_socket_nonblocking (cbdata->pair[0]);
+               event_set (&cbdata->io, cbdata->pair[0], EV_READ, lua_fann_thread_notify,
+                               cbdata);
+               event_base_set (ev_base, &cbdata->io);
+               /* TODO: add timeout */
+               event_add (&cbdata->io, NULL);
+               cbdata->t = rspamd_create_thread ("fann train", lua_fann_train_thread,
+                               cbdata, &err);
+
+               if (cbdata->t == NULL) {
+                       msg_err ("cannot create training thread: %e", err);
+
+                       if (err) {
+                               g_error_free (err);
+                       }
+
+                       goto err;
+               }
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 0;
+
+err:
+       if (cbdata->pair[0] != -1) {
+               close (cbdata->pair[0]);
+       }
+       if (cbdata->pair[1] != -1) {
+               close (cbdata->pair[1]);
+       }
+
+       fann_destroy_train (cbdata->train);
+       luaL_unref (L, LUA_REGISTRYINDEX, cbdata->cbref);
+       g_slice_free1 (sizeof (*cbdata), cbdata);
+       return luaL_error (L, "invalid arguments");
+#endif
+}
+
 /**
  * @method rspamd_fann:test(inputs)
  * Tests neural network with samples. Inputs is a single sample of input data.