summaryrefslogtreecommitdiffstats
path: root/src/lua/lua_fann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-04 14:12:42 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-11-04 14:13:07 +0000
commit00cce0119a907a2414ccf38a7e4221ba75740795 (patch)
tree0df3fdf2819a72193b2dd24cbab85a5d251717e7 /src/lua/lua_fann.c
parent120575cd75e899da79b13fadb11e03f7e18f39b3 (diff)
downloadrspamd-00cce0119a907a2414ccf38a7e4221ba75740795.tar.gz
rspamd-00cce0119a907a2414ccf38a7e4221ba75740795.zip
[Feature] Implement FANN threaded learning
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r--src/lua/lua_fann.c220
1 files changed, 220 insertions, 0 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
index 4f4ffbecd..fa1a27ae8 100644
--- a/src/lua/lua_fann.c
+++ b/src/lua/lua_fann.c
@@ -41,6 +41,7 @@ LUA_FUNCTION_DEF (fann, load_data);
* Fann methods
*/
LUA_FUNCTION_DEF (fann, train);
+LUA_FUNCTION_DEF (fann, train_threaded);
LUA_FUNCTION_DEF (fann, test);
LUA_FUNCTION_DEF (fann, save);
LUA_FUNCTION_DEF (fann, data);
@@ -62,6 +63,7 @@ static const struct luaL_reg fannlib_f[] = {
static const struct luaL_reg fannlib_m[] = {
LUA_INTERFACE_DEF (fann, train),
+ LUA_INTERFACE_DEF (fann, train_threaded),
LUA_INTERFACE_DEF (fann, test),
LUA_INTERFACE_DEF (fann, save),
LUA_INTERFACE_DEF (fann, data),
@@ -537,6 +539,224 @@ lua_fann_train (lua_State *L)
#endif
}
+#ifdef WITH_FANN
+/*
+ * State shared between lua_fann_train_threaded(), the training thread
+ * (lua_fann_train_thread) and the event callback (lua_fann_thread_notify).
+ * It is allocated in lua_fann_train_threaded() and released either there on
+ * error or in lua_fann_thread_notify() once the result has been delivered.
+ */
+struct lua_fann_train_cbdata {
+ /* Lua state used to invoke the completion callback */
+ lua_State *L;
+ /* Socket pair: pair[0] is used by the event callback, pair[1] by the training thread */
+ gint pair[2];
+ /* Training samples filled from the Lua input/output tables */
+ struct fann_train_data *train;
+ /* Network being trained */
+ struct fann *f;
+ /* Registry reference (LUA_REGISTRYINDEX) to the Lua completion callback */
+ gint cbref;
+ /* Target error passed to fann_train_on_data() */
+ gdouble desired_mse;
+ /* Epoch limit passed to fann_train_on_data() */
+ guint max_epochs;
+ /* Handle of the training thread, joined in lua_fann_thread_notify() */
+ GThread *t;
+ /* Read event registered on pair[0] */
+ struct event io;
+};
+
+/*
+ * Fixed-size message written by lua_fann_train_thread() to pair[1] and read
+ * by lua_fann_thread_notify() from pair[0] when training has finished.
+ */
+struct lua_fann_train_reply {
+ /* Result code; the training thread sets it to 0 */
+ gint errcode;
+ /* Value of fann_get_MSE() after training */
+ float mse;
+ /* Status text; the training thread sets it to "OK" */
+ gchar errmsg[128];
+};
+
+/*
+ * Invokes the Lua callback stored in cbdata->cbref as
+ * callback(errcode, errmsg, mse). If the callback raises an error, the error
+ * is logged and its message is removed from the Lua stack.
+ */
+static void
+lua_fann_push_train_result (struct lua_fann_train_cbdata *cbdata,
+		gint errcode, float mse, const gchar *errmsg)
+{
+	lua_State *L = cbdata->L;
+	gint ret;
+
+	/* The function goes first, followed by its three arguments */
+	lua_rawgeti (L, LUA_REGISTRYINDEX, cbdata->cbref);
+	lua_pushnumber (L, errcode);
+	lua_pushstring (L, errmsg);
+	lua_pushnumber (L, mse);
+	ret = lua_pcall (L, 3, 0, 0);
+
+	if (ret == 0) {
+		return;
+	}
+
+	msg_err ("call to train callback failed: %s", lua_tostring (L, -1));
+	lua_pop (L, 1);
+}
+
+/*
+ * Event callback fired when pair[0] becomes readable, i.e. when the training
+ * thread has posted its reply.
+ *
+ * Reads the reply, reports it to the Lua callback, sends a one-byte
+ * acknowledgement so the thread can leave its final read, joins the thread
+ * and releases every resource owned by cbdata.
+ *
+ * @param fd unused, the descriptor is taken from cbdata
+ * @param what unused event mask
+ * @param ud pointer to struct lua_fann_train_cbdata
+ */
+static void
+lua_fann_thread_notify (gint fd, short what, gpointer ud)
+{
+	struct lua_fann_train_cbdata *cbdata = ud;
+	struct lua_fann_train_reply rep;
+	gssize r;
+	gint serrno;
+
+	r = read (cbdata->pair[0], &rep, sizeof (rep));
+
+	if (r == -1) {
+		/* Save errno before any other call can overwrite it */
+		serrno = errno;
+
+		if (serrno == EAGAIN || serrno == EINTR) {
+			/* The event is not persistent: re-arm it and wait again */
+			event_add (&cbdata->io, NULL);
+			return;
+		}
+
+		lua_fann_push_train_result (cbdata, serrno, 0.0, strerror (serrno));
+	}
+	else if ((gsize)r != sizeof (rep)) {
+		/* EOF or a truncated reply: rep is not fully initialised, do not use it */
+		lua_fann_push_train_result (cbdata, EIO, 0.0,
+				"short read from training thread");
+	}
+	else {
+		/* Never hand an unterminated buffer to lua_pushstring */
+		rep.errmsg[sizeof (rep.errmsg) - 1] = '\0';
+		lua_fann_push_train_result (cbdata, rep.errcode, rep.mse, rep.errmsg);
+	}
+
+	/* Acknowledge the reply so the training thread can finish */
+	if (write (cbdata->pair[0], "", 1) == -1) {
+		msg_err ("cannot write to socketpair: %s", strerror (errno));
+	}
+
+	g_thread_join (cbdata->t);
+	close (cbdata->pair[0]);
+	close (cbdata->pair[1]);
+
+	fann_destroy_train (cbdata->train);
+	luaL_unref (cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref);
+	g_slice_free1 (sizeof (*cbdata), cbdata);
+}
+
+/*
+ * Thread entry point: trains cbdata->f on cbdata->train, then posts a
+ * struct lua_fann_train_reply to pair[1] and waits for a one-byte
+ * acknowledgement before exiting.
+ *
+ * @param ud pointer to struct lua_fann_train_cbdata
+ * @return always NULL
+ */
+static void *
+lua_fann_train_thread (void *ud)
+{
+	struct lua_fann_train_cbdata *cbdata = ud;
+	struct lua_fann_train_reply rep;
+	gchar repbuf[1];
+
+	/*
+	 * The whole structure is written to the socket below, so clear it first
+	 * to avoid sending uninitialised stack bytes (padding and the unused
+	 * tail of errmsg).
+	 */
+	memset (&rep, 0, sizeof (rep));
+
+	msg_info ("start learning ANN, %d epochs are possible",
+			cbdata->max_epochs);
+	/* This end of the pair is used with plain blocking read/write */
+	rspamd_socket_blocking (cbdata->pair[1]);
+	fann_train_on_data (cbdata->f, cbdata->train, cbdata->max_epochs, 0,
+			cbdata->desired_mse);
+	rep.errcode = 0;
+	rspamd_strlcpy (rep.errmsg, "OK", sizeof (rep.errmsg));
+	rep.mse = fann_get_MSE (cbdata->f);
+
+	if (write (cbdata->pair[1], &rep, sizeof (rep)) == -1) {
+		msg_err ("cannot write to socketpair: %s", strerror (errno));
+
+		return NULL;
+	}
+
+	/* Wait until the event callback has consumed the reply */
+	if (read (cbdata->pair[1], repbuf, sizeof (repbuf)) == -1) {
+		msg_err ("cannot read from socketpair: %s", strerror (errno));
+
+		return NULL;
+	}
+
+	return NULL;
+}
+#endif
+/**
+ * @method rspamd_fann:train_threaded(inputs, outputs, callback, event_base, {params})
+ * Trains neural network with batch of samples in a separate thread. Inputs and
+ * outputs should be tables of equal size, each row in table should be N inputs
+ * and M outputs, e.g.
+ * {{0, 1, 1}, ...} -> {{0}, {1} ...}
+ * @param {table} inputs input samples
+ * @param {table} outputs output samples
+ * @param {function} callback function called as `callback(errcode, errmsg, mse)` when training is completed
+ * @param {ev_base} event_base event base used to wait for the training thread
+ * @param {table} params optional table with `max_epochs` (default 1000) and `desired_mse` (default 0.0001)
+ */
+static gint
+lua_fann_train_threaded (lua_State *L)
+{
+#ifndef WITH_FANN
+	return 0;
+#else
+	/* Stack layout: 1 = fann, 2 = inputs, 3 = outputs, 4 = callback, 5 = ev_base, 6 = params */
+	struct fann *f = rspamd_lua_check_fann (L, 1);
+	guint ninputs, noutputs, ndata, i, j;
+	struct lua_fann_train_cbdata *cbdata;
+	struct event_base *ev_base = lua_check_ev_base (L, 5);
+	GError *err = NULL;
+	const guint max_epochs_default = 1000;
+	const gdouble desired_mse_default = 0.0001;
+
+	if (f == NULL || lua_type (L, 2) != LUA_TTABLE ||
+			lua_type (L, 3) != LUA_TTABLE ||
+			lua_type (L, 4) != LUA_TFUNCTION || ev_base == NULL) {
+		return luaL_error (L, "invalid arguments");
+	}
+
+	/* First check sanity, call for table.getn for that */
+	ndata = rspamd_lua_table_size (L, 2);
+	ninputs = fann_get_num_input (f);
+	noutputs = fann_get_num_output (f);
+	cbdata = g_slice_alloc0 (sizeof (*cbdata));
+	cbdata->L = L;
+	/*
+	 * NOTE(review): f is only borrowed from the object at stack index 1 and
+	 * nothing visible here keeps that object alive while the thread runs —
+	 * confirm that callers hold a reference until the callback fires.
+	 */
+	cbdata->f = f;
+	/* Mark descriptors as not opened so the error path never closes fd 0 */
+	cbdata->pair[0] = -1;
+	cbdata->pair[1] = -1;
+	lua_pushvalue (L, 4);
+	cbdata->cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+	cbdata->train = fann_create_train (ndata, ninputs, noutputs);
+
+	if (cbdata->train == NULL) {
+		msg_err ("cannot allocate training data");
+		goto err;
+	}
+
+	if (rspamd_socketpair (cbdata->pair) == -1) {
+		msg_err ("cannot open socketpair: %s", strerror (errno));
+		cbdata->pair[0] = -1;
+		cbdata->pair[1] = -1;
+		goto err;
+	}
+
+	for (i = 0; i < ndata; i ++) {
+		lua_rawgeti (L, 2, i + 1);
+
+		if (lua_type (L, -1) != LUA_TTABLE) {
+			msg_err ("input sample is not a table");
+			goto err;
+		}
+
+		if (rspamd_lua_table_size (L, -1) != ninputs) {
+			msg_err ("invalid number of inputs: %d, %d expected",
+					rspamd_lua_table_size (L, -1), ninputs);
+			goto err;
+		}
+
+		for (j = 0; j < ninputs; j ++) {
+			lua_rawgeti (L, -1, j + 1);
+			cbdata->train->input[i][j] = lua_tonumber (L, -1);
+			lua_pop (L, 1);
+		}
+
+		lua_pop (L, 1);
+		lua_rawgeti (L, 3, i + 1);
+
+		if (lua_type (L, -1) != LUA_TTABLE) {
+			msg_err ("output sample is not a table");
+			goto err;
+		}
+
+		if (rspamd_lua_table_size (L, -1) != noutputs) {
+			msg_err ("invalid number of outputs: %d, %d expected",
+					rspamd_lua_table_size (L, -1), noutputs);
+			goto err;
+		}
+
+		for (j = 0; j < noutputs; j++) {
+			lua_rawgeti (L, -1, j + 1);
+			cbdata->train->output[i][j] = lua_tonumber (L, -1);
+			lua_pop (L, 1);
+		}
+
+		/* Drop the output row, otherwise the stack grows with every sample */
+		lua_pop (L, 1);
+	}
+
+	cbdata->max_epochs = max_epochs_default;
+	cbdata->desired_mse = desired_mse_default;
+
+	/* Optional params table follows the event base, i.e. it is argument 6 */
+	if (lua_type (L, 6) == LUA_TTABLE) {
+		/*
+		 * NOTE(review): verify that the "I" and "N" specifiers expect
+		 * pointers of the same width as guint and gdouble, and that missing
+		 * keys leave the defaults above untouched.
+		 */
+		rspamd_lua_parse_table_arguments (L, 6, NULL,
+				"max_epochs=I;desired_mse=N",
+				&cbdata->max_epochs, &cbdata->desired_mse);
+	}
+
+	/* Now we can call training in a separate thread */
+	rspamd_socket_nonblocking (cbdata->pair[0]);
+	event_set (&cbdata->io, cbdata->pair[0], EV_READ, lua_fann_thread_notify,
+			cbdata);
+	event_base_set (ev_base, &cbdata->io);
+	/* TODO: add timeout */
+	event_add (&cbdata->io, NULL);
+	cbdata->t = rspamd_create_thread ("fann train", lua_fann_train_thread,
+			cbdata, &err);
+
+	if (cbdata->t == NULL) {
+		msg_err ("cannot create training thread: %e", err);
+
+		if (err) {
+			g_error_free (err);
+		}
+
+		/* The event still points to cbdata, remove it before freeing */
+		event_del (&cbdata->io);
+		goto err;
+	}
+
+	return 0;
+
+err:
+	if (cbdata->pair[0] != -1) {
+		close (cbdata->pair[0]);
+	}
+	if (cbdata->pair[1] != -1) {
+		close (cbdata->pair[1]);
+	}
+
+	if (cbdata->train != NULL) {
+		fann_destroy_train (cbdata->train);
+	}
+
+	luaL_unref (L, LUA_REGISTRYINDEX, cbdata->cbref);
+	g_slice_free1 (sizeof (*cbdata), cbdata);
+	/* luaL_error does not return, so values left on the Lua stack are discarded */
+	return luaL_error (L, "invalid arguments");
+#endif
+}
+
/**
* @method rspamd_fann:test(inputs)
* Tests neural network with samples. Inputs is a single sample of input data.