diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-10-08 13:42:56 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-10-08 16:35:42 +0100 |
commit | f98bb456ab710bc3fa99cdbe5383f39c00dcd480 (patch) | |
tree | c1df8bbe44fcdc08fbbc6694acf4fc0770ab2f18 | |
parent | b7be42fa1e48fb1afa050fd44976843d6defd466 (diff) | |
download | rspamd-f98bb456ab710bc3fa99cdbe5383f39c00dcd480.tar.gz rspamd-f98bb456ab710bc3fa99cdbe5383f39c00dcd480.zip |
[Feature] Add neural net serialization/deserialization
-rw-r--r-- | src/lua/lua_fann.c | 142 |
1 files changed, 139 insertions, 3 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c index 6df5d9470..3d15c6417 100644 --- a/src/lua/lua_fann.c +++ b/src/lua/lua_fann.c @@ -19,6 +19,8 @@ #include <fann.h> #endif +#include "unix-std.h" + /*** * @module rspamd_fann * This module enables [fann](http://libfann.github.io) interaction in rspamd @@ -31,7 +33,8 @@ */ LUA_FUNCTION_DEF (fann, is_enabled); LUA_FUNCTION_DEF (fann, create); -LUA_FUNCTION_DEF (fann, load); +LUA_FUNCTION_DEF (fann, load_file); +LUA_FUNCTION_DEF (fann, load_data); /* * Fann methods @@ -39,6 +42,7 @@ LUA_FUNCTION_DEF (fann, load); LUA_FUNCTION_DEF (fann, train); LUA_FUNCTION_DEF (fann, test); LUA_FUNCTION_DEF (fann, save); +LUA_FUNCTION_DEF (fann, data); LUA_FUNCTION_DEF (fann, get_inputs); LUA_FUNCTION_DEF (fann, get_outputs); LUA_FUNCTION_DEF (fann, dtor); @@ -46,7 +50,9 @@ LUA_FUNCTION_DEF (fann, dtor); static const struct luaL_reg fannlib_f[] = { LUA_INTERFACE_DEF (fann, is_enabled), LUA_INTERFACE_DEF (fann, create), - LUA_INTERFACE_DEF (fann, load), + LUA_INTERFACE_DEF (fann, load_file), + {"load", lua_fann_load_file}, + LUA_INTERFACE_DEF (fann, load_data), {NULL, NULL} }; @@ -54,6 +60,7 @@ static const struct luaL_reg fannlib_m[] = { LUA_INTERFACE_DEF (fann, train), LUA_INTERFACE_DEF (fann, test), LUA_INTERFACE_DEF (fann, save), + LUA_INTERFACE_DEF (fann, data), LUA_INTERFACE_DEF (fann, get_inputs), LUA_INTERFACE_DEF (fann, get_outputs), {"__gc", lua_fann_dtor}, @@ -141,7 +148,7 @@ lua_fann_create (lua_State *L) * @return {fann} fann object */ static gint -lua_fann_load (lua_State *L) +lua_fann_load_file (lua_State *L) { #ifndef WITH_FANN return 0; @@ -171,6 +178,135 @@ lua_fann_load (lua_State *L) #endif } +/*** + * @function rspamd_fann.load_data(data) + * Loads neural network from the data + * @param {string} file filename where fann is stored + * @return {fann} fann object + */ +static gint +lua_fann_load_data (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + gint fd; + struct rspamd_lua_text *t; + gchar fpath[PATH_MAX]; + + if (lua_type (L, 1) == LUA_TUSERDATA) { + t = lua_check_text (L, 1); + + if (!t) { + return luaL_error (L, "text required"); + } + } + else { + t = g_alloca (sizeof (*t)); + t->start = lua_tolstring (L, 1, (gsize *)&t->len); + t->flags = 0; + } + + /* We need to save data to file because of libfann stupidity */ + rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath)); + fd = mkstemp (fpath); + + if (fd == -1) { + msg_warn ("cannot create tempfile: %s", strerror (errno)); + lua_pushnil (L); + } + else { + if (write (fd, t->start, t->len) == -1) { + msg_warn ("cannot write tempfile: %s", strerror (errno)); + lua_pushnil (L); + unlink (fpath); + close (fd); + + return 1; + } + + f = fann_create_from_file (fpath); + unlink (fpath); + close (fd); + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + lua_pushnil (L); + } + } + + return 1; +#endif +} + +/*** + * @function rspamd_fann:data() + * Returns serialized neural network + * @return {rspamd_text} fann data + */ +static gint +lua_fann_data (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + gint fd; + struct rspamd_lua_text *res; + gchar fpath[PATH_MAX]; + gpointer map; + gsize sz; + + if (f == NULL) { + return luaL_error (L, "invalid arguments"); + } + + /* We need to save data to file because of libfann stupidity */ + rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath)); + fd = mkstemp (fpath); + + if (fd == -1) { + msg_warn ("cannot create tempfile: %s", strerror (errno)); + lua_pushnil (L); + } + else { + if (fann_save (f, fpath) == -1) { + msg_warn ("cannot write tempfile: %s", strerror (errno)); + lua_pushnil (L); + unlink (fpath); + close (fd); + + return 1; + } + + + (void)lseek (fd, 0, SEEK_SET); + map = rspamd_file_xmap (fpath, PROT_READ, &sz); + unlink (fpath); + close (fd); + + if (map != NULL) { + res = lua_newuserdata (L, sizeof (*res)); + res->len = sz; + res->start = map; + res->flags = RSPAMD_TEXT_FLAG_OWN|RSPAMD_TEXT_FLAG_MMAPED; + rspamd_lua_setclass (L, "rspamd{text}", -1); + } + else { + lua_pushnil (L); + } + + } + + return 1; +#endif +} + /** * @method rspamd_fann:train(inputs, outputs) |