summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--doc/Makefile10
-rw-r--r--src/CMakeLists.txt2
-rw-r--r--src/lua/CMakeLists.txt5
-rw-r--r--src/lua/lua_common.c19
-rw-r--r--src/lua/lua_common.h6
-rw-r--r--src/lua/lua_fann.c435
6 files changed, 470 insertions, 7 deletions
diff --git a/doc/Makefile b/doc/Makefile
index f24fc2551..02b3258f8 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -8,15 +8,15 @@ all: man
man: rspamd.8 rspamc.1 rspamadm.1
rspamd.8: rspamd.8.md
- $(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md
+ $(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md
rspamc.1: rspamc.1.md
$(PANDOC) -s -f markdown -t man -o rspamc.1 rspamc.1.md
rspamadm.1: rspamadm.1.md
$(PANDOC) -s -f markdown -t man -o rspamadm.1 rspamadm.1.md
-
+
lua-doc: lua_regexp lua_ip lua_config lua_task lua_ucl lua_http lua_trie \
lua_dns lua_redis lua_upstream lua_expression lua_mimepart lua_logger lua_url \
- lua_tcp lua_mempool lua_html lua_util
+ lua_tcp lua_mempool lua_html lua_util lua_fann
lua_regexp: ../src/lua/lua_regexp.c
$(LUADOC) < ../src/lua/lua_regexp.c > markdown/lua/regexp.md
@@ -53,4 +53,6 @@ lua_mempool: ../src/lua/lua_mempool.c
lua_html: ../src/lua/lua_html.c
$(LUADOC) < ../src/lua/lua_html.c > markdown/lua/html.md
lua_util: ../src/lua/lua_util.c
- $(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md \ No newline at end of file
+ $(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md
+lua_fann: ../src/lua/lua_fann.c
+ $(LUADOC) < ../src/lua/lua_fann.c > markdown/lua/fann.md
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 385de9eb8..a37256c41 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -90,7 +90,7 @@ SET(PLUGINSSRC plugins/surbl.c
plugins/fuzzy_check.c
plugins/spf.c
plugins/dkim_check.c
- libserver/rspamd_control.c)
+ libserver/rspamd_control.c lua/lua_fann.c)
SET(MODULES_LIST surbl regexp chartable fuzzy_check spf dkim)
SET(WORKERS_LIST normal controller smtp_proxy fuzzy lua http_proxy)
diff --git a/src/lua/CMakeLists.txt b/src/lua/CMakeLists.txt
index ad526d534..4d74e7752 100644
--- a/src/lua/CMakeLists.txt
+++ b/src/lua/CMakeLists.txt
@@ -23,6 +23,7 @@ SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_url.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_util.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_tcp.c
- ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c)
+ ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c)
-SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE) \ No newline at end of file
+SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c
index 0cbec0408..e475ace83 100644
--- a/src/lua/lua_common.c
+++ b/src/lua/lua_common.c
@@ -247,6 +247,7 @@ rspamd_lua_init ()
luaopen_util (L);
luaopen_tcp (L);
luaopen_html (L);
+ luaopen_fann (L);
rspamd_lua_add_preload (L, "ucl", luaopen_ucl);
@@ -944,3 +945,21 @@ rspamd_lua_traceback (lua_State *L)
return 1;
}
+
+guint
+rspamd_lua_table_size (lua_State *L, gint tbl_pos)
+{
+ guint tbl_size = 0;
+
+ if (!lua_istable (L, tbl_pos)) {
+ return 0;
+ }
+
+#if LUA_VERSION_NUM >= 502
+ tbl_size = lua_rawlen (L, tbl_pos);
+#else
+ tbl_size = lua_objlen (L, tbl_pos);
+#endif
+
+ return tbl_size;
+}
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index f51aee731..b41d33811 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -222,6 +222,7 @@ void luaopen_text (lua_State *L);
void luaopen_util (lua_State * L);
void luaopen_tcp (lua_State * L);
void luaopen_html (lua_State * L);
+void luaopen_fann (lua_State *L);
gint rspamd_lua_call_filter (const gchar *function, struct rspamd_task *task);
gint rspamd_lua_call_chain_filter (const gchar *function,
@@ -289,5 +290,10 @@ gboolean rspamd_lua_parse_table_arguments (lua_State *L, gint pos,
gint rspamd_lua_traceback (lua_State *L);
+
+/**
+ * Returns size of table at position `tbl_pos`
+ */
+guint rspamd_lua_table_size (lua_State *L, gint tbl_pos);
#endif /* WITH_LUA */
#endif /* RSPAMD_LUA_H */
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
new file mode 100644
index 000000000..90c037d35
--- /dev/null
+++ b/src/lua/lua_fann.c
@@ -0,0 +1,435 @@
+/*
+ * Copyright (c) 2015, Vsevolod Stakhov
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "lua_common.h"
+
+#ifdef WITH_FANN
+#include <fann.h>
+#endif
+
+/***
+ * @module rspamd_fann
+ * This module enables [fann](http://libfann.github.io) interaction in rspamd
+ * Please note, that this module works merely if you have `ENABLE_FANN=ON` option
+ * definition when building rspamd
+ */
+
+/*
+ * Fann functions
+ */
+LUA_FUNCTION_DEF (fann, is_enabled);
+LUA_FUNCTION_DEF (fann, create);
+LUA_FUNCTION_DEF (fann, load);
+
+/*
+ * Fann methods
+ */
+LUA_FUNCTION_DEF (fann, train);
+LUA_FUNCTION_DEF (fann, test);
+LUA_FUNCTION_DEF (fann, save);
+LUA_FUNCTION_DEF (fann, get_inputs);
+LUA_FUNCTION_DEF (fann, get_outputs);
+LUA_FUNCTION_DEF (fann, dtor);
+
+static const struct luaL_reg fannlib_f[] = {
+ LUA_INTERFACE_DEF (fann, is_enabled),
+ LUA_INTERFACE_DEF (fann, create),
+ LUA_INTERFACE_DEF (fann, load),
+ {NULL, NULL}
+};
+
+static const struct luaL_reg fannlib_m[] = {
+ LUA_INTERFACE_DEF (fann, train),
+ LUA_INTERFACE_DEF (fann, test),
+ LUA_INTERFACE_DEF (fann, save),
+ LUA_INTERFACE_DEF (fann, get_inputs),
+ LUA_INTERFACE_DEF (fann, get_outputs),
+ {"__gc", lua_fann_dtor},
+ {"__tostring", rspamd_lua_class_tostring},
+ {NULL, NULL}
+};
+
+#ifdef WITH_FANN
+struct fann *
+rspamd_lua_check_fann (lua_State *L, gint pos)
+{
+ void *ud = luaL_checkudata (L, pos, "rspamd{fann}");
+ luaL_argcheck (L, ud != NULL, pos, "'fann' expected");
+ return ud ? *((struct fann **) ud) : NULL;
+}
+#endif
+
+/***
+ * @function rspamd_fann.is_enabled()
+ * Checks if fann is enabled for this rspamd build
+ * @return {boolean} true if fann is enabled
+ */
+static gint
+lua_fann_is_enabled (lua_State *L)
+{
+#ifdef WITH_FANN
+ lua_pushboolean (L, true);
+#else
+ lua_pushboolean (L, false);
+#endif
+ return 1;
+}
+
+/***
+ * @function rspamd_fann.create(nlayers, [layer1, ... layern])
+ * Creates new neural network with `nlayers` that contains `layer1`...`layern`
+ * neurons in each layer
+ * @param {number} nlayers number of layers
+ * @param {number} layerI number of neurons in each layer
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_create (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ guint nlayers, *layers, i;
+
+ nlayers = luaL_checknumber (L, 1);
+
+ if (nlayers > 0) {
+ layers = g_malloc (nlayers * sizeof (layers[0]));
+
+ for (i = 0; i < nlayers; i ++) {
+ layers[i] = luaL_checknumber (L, i + 2);
+ }
+
+ f = fann_create_standard_array (nlayers, layers);
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @function rspamd_fann.load(file)
+ * Loads neural network from the file
+ * @param {string} file filename where fann is stored
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_load (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ const gchar *fname;
+
+ fname = luaL_checkstring (L, 1);
+
+ if (fname != NULL) {
+ f = fann_create_from_file (fname);
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+
+/**
+ * @method rspamd_fann:train(inputs, outputs)
+ * Trains neural network with samples. Inputs and outputs should be tables of
+ * equal size, each row in table should be N inputs and M outputs, e.g.
+ * {0, 1, 1} -> {0}
+ * {1, 0, 0} -> {1}
+ * @param {table/table} inputs input samples
+ * @param {table/table} outputs output samples
+ * @return {number} number of samples learned
+ */
+static gint
+lua_fann_train (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ guint ninputs, noutputs, i, j, cur_len;
+ float *cur_input, *cur_output;
+ gint ret = 0;
+
+ if (f != NULL) {
+ /* First check sanity, call for table.getn for that */
+ ninputs = rspamd_lua_table_size (L, 2);
+ noutputs = rspamd_lua_table_size (L, 3);
+
+ if (ninputs != noutputs) {
+ msg_err ("bad number of inputs(%d) and output(%d) args for train",
+ ninputs, noutputs);
+ }
+ else {
+ for (i = 0; i < ninputs; i ++) {
+ /* Push table with inputs */
+ lua_rawgeti (L, 2, i + 1);
+
+ cur_len = rspamd_lua_table_size (L, -1);
+
+ if (cur_len != fann_get_num_input (f)) {
+ msg_err (
+ "bad number of input samples: %d, %d expected",
+ cur_len,
+ fann_get_num_input (f));
+ lua_pop (L, 1);
+ continue;
+ }
+
+ cur_input = g_malloc (cur_len * sizeof (gint));
+
+ for (j = 0; j < cur_len; j ++) {
+ lua_rawgeti (L, -1, j + 1);
+ cur_input[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1); /* Inputs table */
+
+ /* Push table with outputs */
+ lua_rawgeti (L, 3, i + 1);
+
+ cur_len = rspamd_lua_table_size (L, -1);
+
+ if (cur_len != fann_get_num_output (f)) {
+ msg_err (
+ "bad number of output samples: %d, %d expected",
+ cur_len,
+ fann_get_num_output (f));
+ lua_pop (L, 1);
+ g_free (cur_input);
+ continue;
+ }
+
+ cur_output = g_malloc (cur_len * sizeof (gint));
+
+ for (j = 0; j < cur_len; j++) {
+ lua_rawgeti (L, -1, j + 1);
+ cur_output[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1); /* Outputs table */
+
+ fann_train (f, cur_input, cur_output);
+ g_free (cur_input);
+ g_free (cur_output);
+ ret ++;
+ }
+ }
+ }
+
+ lua_pushnumber (L, ret);
+
+ return 1;
+#endif
+}
+
+/**
+ * @method rspamd_fann:test(inputs)
+ * Tests neural network with samples. Inputs is a single sample of input data.
+ * The function returns table of results, e.g.:
+ * {0, 1, 1} -> {0}
+ * @param {table} inputs input sample
+ * @return {table/number} outputs values
+ */
+static gint
+lua_fann_test (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ guint ninputs, noutputs, i;
+ float *cur_input, *cur_output;
+
+ if (f != NULL) {
+ /* First check sanity, call for table.getn for that */
+ ninputs = rspamd_lua_table_size (L, 2);
+ cur_input = g_malloc (ninputs * sizeof (gint));
+
+ for (i = 0; i < ninputs; i++) {
+ lua_rawgeti (L, 2, i + 1);
+ cur_input[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ cur_output = fann_run (f, cur_input);
+ noutputs = fann_get_num_output (f);
+ lua_createtable (L, noutputs, 0);
+
+ for (i = 0; i < noutputs; i ++) {
+ lua_pushnumber (L, cur_output[i]);
+ lua_rawseti (L, -2, i + 1);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_inputs()
+ * Returns number of inputs for neural network
+ * @return {number} number of inputs
+ */
+static gint
+lua_fann_get_inputs (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f != NULL) {
+ lua_pushnumber (L, fann_get_num_input (f));
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_outputs()
+ * Returns number of outputs for neural network
+ * @return {number} number of outputs
+ */
+static gint
+lua_fann_get_outputs (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f != NULL) {
+ lua_pushnumber (L, fann_get_num_output (f));
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:save(fname)
+ * Save fann to file named 'fname'
+ * @param {string} fname filename to save fann into
+ * @return {boolean} true if ann has been saved
+ */
+static gint
+lua_fann_save (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ const gchar *fname = luaL_checkstring (L, 2);
+
+ if (f != NULL && fname != NULL) {
+ if (fann_save (f, fname) == 0) {
+ lua_pushboolean (L, true);
+ }
+ else {
+ msg_err ("cannot save ANN to %s: %s", fname, strerror (errno));
+ lua_pushboolean (L, false);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+static gint
+lua_fann_dtor (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f) {
+ fann_destroy (f);
+ }
+
+ return 0;
+#endif
+}
+
+static gint
+lua_load_fann (lua_State * L)
+{
+ lua_newtable (L);
+ luaL_register (L, NULL, fannlib_f);
+
+ return 1;
+}
+
+void
+luaopen_fann (lua_State * L)
+{
+ rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m);
+ lua_pop (L, 1);
+
+ rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann);
+}