diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-11-04 14:12:42 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-11-04 14:13:07 +0000 |
commit | 00cce0119a907a2414ccf38a7e4221ba75740795 (patch) | |
tree | 0df3fdf2819a72193b2dd24cbab85a5d251717e7 /src/lua/lua_fann.c | |
parent | 120575cd75e899da79b13fadb11e03f7e18f39b3 (diff) | |
download | rspamd-00cce0119a907a2414ccf38a7e4221ba75740795.tar.gz rspamd-00cce0119a907a2414ccf38a7e4221ba75740795.zip |
[Feature] Implement FANN threaded learning
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r-- | src/lua/lua_fann.c | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index 4f4ffbecd..fa1a27ae8 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -41,6 +41,7 @@ LUA_FUNCTION_DEF (fann, load_data); * Fann methods */ LUA_FUNCTION_DEF (fann, train); +LUA_FUNCTION_DEF (fann, train_threaded); LUA_FUNCTION_DEF (fann, test); LUA_FUNCTION_DEF (fann, save); LUA_FUNCTION_DEF (fann, data); @@ -62,6 +63,7 @@ static const struct luaL_reg fannlib_f[] = { static const struct luaL_reg fannlib_m[] = { LUA_INTERFACE_DEF (fann, train), + LUA_INTERFACE_DEF (fann, train_threaded), LUA_INTERFACE_DEF (fann, test), LUA_INTERFACE_DEF (fann, save), LUA_INTERFACE_DEF (fann, data), @@ -537,6 +539,224 @@ lua_fann_train (lua_State *L) #endif } +#ifdef WITH_FANN +struct lua_fann_train_cbdata { + lua_State *L; + gint pair[2]; + struct fann_train_data *train; + struct fann *f; + gint cbref; + gdouble desired_mse; + guint max_epochs; + GThread *t; + struct event io; +}; + +struct lua_fann_train_reply { + gint errcode; + float mse; + gchar errmsg[128]; +}; + +static void +lua_fann_push_train_result (struct lua_fann_train_cbdata *cbdata, + gint errcode, float mse, const gchar *errmsg) +{ + lua_rawgeti (cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref); + lua_pushnumber (cbdata->L, errcode); + lua_pushstring (cbdata->L, errmsg); + lua_pushnumber (cbdata->L, mse); + + if (lua_pcall (cbdata->L, 3, 0, 0) != 0) { + msg_err ("call to train callback failed: %s", lua_tostring (cbdata->L, -1)); + lua_pop (cbdata->L, 1); + } +} + +static void +lua_fann_thread_notify (gint fd, short what, gpointer ud) +{ + struct lua_fann_train_cbdata *cbdata = ud; + struct lua_fann_train_reply rep; + + if (read (cbdata->pair[0], &rep, sizeof (rep)) == -1) { + if (errno == EAGAIN || errno == EINTR) { + event_add (&cbdata->io, NULL); + return; + } + + lua_fann_push_train_result (cbdata, errno, 0.0, strerror (errno)); + } + else { + lua_fann_push_train_result (cbdata, rep.errcode, rep.mse, rep.errmsg); + } + + write (cbdata->pair[0], "", 1); + g_thread_join (cbdata->t); + close (cbdata->pair[0]); + close (cbdata->pair[1]); + + fann_destroy_train (cbdata->train); + luaL_unref (cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref); + g_slice_free1 (sizeof (*cbdata), cbdata); +} + +static void * +lua_fann_train_thread (void *ud) +{ + struct lua_fann_train_cbdata *cbdata = ud; + struct lua_fann_train_reply rep; + gchar repbuf[1]; + + msg_info ("start learning ANN, %d epochs are possible", + cbdata->max_epochs); + rspamd_socket_blocking (cbdata->pair[1]); + fann_train_on_data (cbdata->f, cbdata->train, cbdata->max_epochs, 0, + cbdata->desired_mse); + rep.errcode = 0; + rspamd_strlcpy (rep.errmsg, "OK", sizeof (rep.errmsg)); + rep.mse = fann_get_MSE (cbdata->f); + + if (write (cbdata->pair[1], &rep, sizeof (rep)) == -1) { + msg_err ("cannot write to socketpair: %s", strerror (errno)); + + return NULL; + } + + if (read (cbdata->pair[1], repbuf, sizeof (repbuf)) == -1) { + msg_err ("cannot read from socketpair: %s", strerror (errno)); + + return NULL; + } + + return NULL; +} +#endif +/** + * @method rspamd_fann:train_threaded(inputs, outputs, callback, event_base, {params}) + * Trains neural network with batch of samples. Inputs and outputs should be tables of + * equal size, each row in table should be N inputs and M outputs, e.g. + * {{0, 1, 1}, ...} -> {{0}, {1} ...} + * @param {table} inputs input samples + * @param {table} outputs output samples + * @param {callback} function that is called when train is completed + */ +static gint +lua_fann_train_threaded (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + guint ninputs, noutputs, ndata, i, j; + struct lua_fann_train_cbdata *cbdata; + struct event_base *ev_base = lua_check_ev_base (L, 5); + GError *err = NULL; + const guint max_epochs_default = 1000; + const gdouble desired_mse_default = 0.0001; + + if (f != NULL && lua_type (L, 2) == LUA_TTABLE && + lua_type (L, 3) == LUA_TTABLE && lua_type (L, 4) == LUA_TFUNCTION && + ev_base != NULL) { + /* First check sanity, call for table.getn for that */ + ndata = rspamd_lua_table_size (L, 2); + ninputs = fann_get_num_input (f); + noutputs = fann_get_num_output (f); + cbdata = g_slice_alloc0 (sizeof (*cbdata)); + cbdata->L = L; + cbdata->f = f; + cbdata->train = fann_create_train (ndata, ninputs, noutputs); + lua_pushvalue (L, 4); + cbdata->cbref = luaL_ref (L, LUA_REGISTRYINDEX); + + if (rspamd_socketpair (cbdata->pair) == -1) { + msg_err ("cannot open socketpair: %s", strerror (errno)); + cbdata->pair[0] = -1; + cbdata->pair[1] = -1; + goto err; + } + + for (i = 0; i < ndata; i ++) { + lua_rawgeti (L, 2, i + 1); + + if (rspamd_lua_table_size (L, -1) != ninputs) { + msg_err ("invalid number of inputs: %d, %d expected", + rspamd_lua_table_size (L, -1), ninputs); + goto err; + } + + for (j = 0; j < ninputs; j ++) { + lua_rawgeti (L, -1, j + 1); + cbdata->train->input[i][j] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); + lua_rawgeti (L, 3, i + 1); + + if (rspamd_lua_table_size (L, -1) != noutputs) { + msg_err ("invalid number of outputs: %d, %d expected", + rspamd_lua_table_size (L, -1), noutputs); + goto err; + } + + for (j = 0; j < noutputs; j++) { + lua_rawgeti (L, -1, j + 1); + cbdata->train->output[i][j] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + } + + cbdata->max_epochs = max_epochs_default; + cbdata->desired_mse = desired_mse_default; + + if (lua_type (L, 5) == LUA_TTABLE) { + rspamd_lua_parse_table_arguments (L, 5, NULL, + "max_epochs=I;desired_mse=N", + &cbdata->max_epochs, &cbdata->desired_mse); + } + + /* Now we can call training in a separate thread */ + rspamd_socket_nonblocking (cbdata->pair[0]); + event_set (&cbdata->io, cbdata->pair[0], EV_READ, lua_fann_thread_notify, + cbdata); + event_base_set (ev_base, &cbdata->io); + /* TODO: add timeout */ + event_add (&cbdata->io, NULL); + cbdata->t = rspamd_create_thread ("fann train", lua_fann_train_thread, + cbdata, &err); + + if (cbdata->t == NULL) { + msg_err ("cannot create training thread: %e", err); + + if (err) { + g_error_free (err); + } + + goto err; + } + } + else { + return luaL_error (L, "invalid arguments"); + } + + return 0; + +err: + if (cbdata->pair[0] != -1) { + close (cbdata->pair[0]); + } + if (cbdata->pair[1] != -1) { + close (cbdata->pair[1]); + } + + fann_destroy_train (cbdata->train); + luaL_unref (L, LUA_REGISTRYINDEX, cbdata->cbref); + g_slice_free1 (sizeof (*cbdata), cbdata); + return luaL_error (L, "invalid arguments"); +#endif +} + /** * @method rspamd_fann:test(inputs) * Tests neural network with samples. Inputs is a single sample of input data. |