aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-10-08 13:42:56 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-10-08 16:35:42 +0100
commitf98bb456ab710bc3fa99cdbe5383f39c00dcd480 (patch)
treec1df8bbe44fcdc08fbbc6694acf4fc0770ab2f18
parentb7be42fa1e48fb1afa050fd44976843d6defd466 (diff)
downloadrspamd-f98bb456ab710bc3fa99cdbe5383f39c00dcd480.tar.gz
rspamd-f98bb456ab710bc3fa99cdbe5383f39c00dcd480.zip
[Feature] Add neural net serialization/deserialization
-rw-r--r--src/lua/lua_fann.c142
1 files changed, 139 insertions, 3 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
index 6df5d9470..3d15c6417 100644
--- a/src/lua/lua_fann.c
+++ b/src/lua/lua_fann.c
@@ -19,6 +19,8 @@
#include <fann.h>
#endif
+#include "unix-std.h"
+
/***
* @module rspamd_fann
* This module enables [fann](http://libfann.github.io) interaction in rspamd
@@ -31,7 +33,8 @@
*/
LUA_FUNCTION_DEF (fann, is_enabled);
LUA_FUNCTION_DEF (fann, create);
-LUA_FUNCTION_DEF (fann, load);
+LUA_FUNCTION_DEF (fann, load_file);
+LUA_FUNCTION_DEF (fann, load_data);
/*
* Fann methods
@@ -39,6 +42,7 @@ LUA_FUNCTION_DEF (fann, load);
LUA_FUNCTION_DEF (fann, train);
LUA_FUNCTION_DEF (fann, test);
LUA_FUNCTION_DEF (fann, save);
+LUA_FUNCTION_DEF (fann, data);
LUA_FUNCTION_DEF (fann, get_inputs);
LUA_FUNCTION_DEF (fann, get_outputs);
LUA_FUNCTION_DEF (fann, dtor);
@@ -46,7 +50,9 @@ LUA_FUNCTION_DEF (fann, dtor);
static const struct luaL_reg fannlib_f[] = {
LUA_INTERFACE_DEF (fann, is_enabled),
LUA_INTERFACE_DEF (fann, create),
- LUA_INTERFACE_DEF (fann, load),
+ LUA_INTERFACE_DEF (fann, load_file),
+ {"load", lua_fann_load_file},
+ LUA_INTERFACE_DEF (fann, load_data),
{NULL, NULL}
};
@@ -54,6 +60,7 @@ static const struct luaL_reg fannlib_m[] = {
LUA_INTERFACE_DEF (fann, train),
LUA_INTERFACE_DEF (fann, test),
LUA_INTERFACE_DEF (fann, save),
+ LUA_INTERFACE_DEF (fann, data),
LUA_INTERFACE_DEF (fann, get_inputs),
LUA_INTERFACE_DEF (fann, get_outputs),
{"__gc", lua_fann_dtor},
@@ -141,7 +148,7 @@ lua_fann_create (lua_State *L)
* @return {fann} fann object
*/
static gint
-lua_fann_load (lua_State *L)
+lua_fann_load_file (lua_State *L)
{
#ifndef WITH_FANN
return 0;
@@ -171,6 +178,135 @@ lua_fann_load (lua_State *L)
#endif
}
+/***
+ * @function rspamd_fann.load_data(data)
+ * Loads neural network from the data
+ * @param {string} file filename where fann is stored
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_load_data (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ gint fd;
+ struct rspamd_lua_text *t;
+ gchar fpath[PATH_MAX];
+
+ if (lua_type (L, 1) == LUA_TUSERDATA) {
+ t = lua_check_text (L, 1);
+
+ if (!t) {
+ return luaL_error (L, "text required");
+ }
+ }
+ else {
+ t = g_alloca (sizeof (*t));
+ t->start = lua_tolstring (L, 1, (gsize *)&t->len);
+ t->flags = 0;
+ }
+
+ /* We need to save data to file because of libfann stupidity */
+ rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath));
+ fd = mkstemp (fpath);
+
+ if (fd == -1) {
+ msg_warn ("cannot create tempfile: %s", strerror (errno));
+ lua_pushnil (L);
+ }
+ else {
+ if (write (fd, t->start, t->len) == -1) {
+ msg_warn ("cannot write tempfile: %s", strerror (errno));
+ lua_pushnil (L);
+ unlink (fpath);
+ close (fd);
+
+ return 1;
+ }
+
+ f = fann_create_from_file (fpath);
+ unlink (fpath);
+ close (fd);
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @function rspamd_fann:data()
+ * Returns serialized neural network
+ * @return {rspamd_text} fann data
+ */
+static gint
+lua_fann_data (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ gint fd;
+ struct rspamd_lua_text *res;
+ gchar fpath[PATH_MAX];
+ gpointer map;
+ gsize sz;
+
+ if (f == NULL) {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ /* We need to save data to file because of libfann stupidity */
+ rspamd_strlcpy (fpath, "/tmp/rspamd-fannXXXXXXXXXX", sizeof (fpath));
+ fd = mkstemp (fpath);
+
+ if (fd == -1) {
+ msg_warn ("cannot create tempfile: %s", strerror (errno));
+ lua_pushnil (L);
+ }
+ else {
+ if (fann_save (f, fpath) == -1) {
+ msg_warn ("cannot write tempfile: %s", strerror (errno));
+ lua_pushnil (L);
+ unlink (fpath);
+ close (fd);
+
+ return 1;
+ }
+
+
+ (void)lseek (fd, 0, SEEK_SET);
+ map = rspamd_file_xmap (fpath, PROT_READ, &sz);
+ unlink (fpath);
+ close (fd);
+
+ if (map != NULL) {
+ res = lua_newuserdata (L, sizeof (*res));
+ res->len = sz;
+ res->start = map;
+ res->flags = RSPAMD_TEXT_FLAG_OWN|RSPAMD_TEXT_FLAG_MMAPED;
+ rspamd_lua_setclass (L, "rspamd{text}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ }
+
+ return 1;
+#endif
+}
+
/**
* @method rspamd_fann:train(inputs, outputs)