diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/lua/CMakeLists.txt | 5 | ||||
-rw-r--r-- | src/lua/lua_common.c | 19 | ||||
-rw-r--r-- | src/lua/lua_common.h | 6 | ||||
-rw-r--r-- | src/lua/lua_fann.c | 435 |
5 files changed, 464 insertions, 3 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 385de9eb8..a37256c41 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,7 +90,7 @@ SET(PLUGINSSRC plugins/surbl.c plugins/fuzzy_check.c plugins/spf.c plugins/dkim_check.c - libserver/rspamd_control.c) + libserver/rspamd_control.c lua/lua_fann.c) SET(MODULES_LIST surbl regexp chartable fuzzy_check spf dkim) SET(WORKERS_LIST normal controller smtp_proxy fuzzy lua http_proxy) diff --git a/src/lua/CMakeLists.txt b/src/lua/CMakeLists.txt index ad526d534..4d74e7752 100644 --- a/src/lua/CMakeLists.txt +++ b/src/lua/CMakeLists.txt @@ -23,6 +23,7 @@ SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_url.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_util.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_tcp.c - ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c) + ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c) -SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file +SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE) diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index 0cbec0408..e475ace83 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -247,6 +247,7 @@ rspamd_lua_init () luaopen_util (L); luaopen_tcp (L); luaopen_html (L); + luaopen_fann (L); rspamd_lua_add_preload (L, "ucl", luaopen_ucl); @@ -944,3 +945,21 @@ rspamd_lua_traceback (lua_State *L) return 1; } + +guint +rspamd_lua_table_size (lua_State *L, gint tbl_pos) +{ + guint tbl_size = 0; + + if (!lua_istable (L, tbl_pos)) { + return 0; + } + +#if LUA_VERSION_NUM >= 502 + tbl_size = lua_rawlen (L, tbl_pos); +#else + tbl_size = lua_objlen (L, tbl_pos); +#endif + + return tbl_size; +} diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index f51aee731..b41d33811 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -222,6 +222,7 @@ void luaopen_text (lua_State *L); void luaopen_util (lua_State * L); void luaopen_tcp (lua_State * L); void luaopen_html (lua_State * L); +void luaopen_fann (lua_State *L); gint rspamd_lua_call_filter (const gchar *function, struct rspamd_task *task); gint rspamd_lua_call_chain_filter (const gchar *function, @@ -289,5 +290,10 @@ gboolean rspamd_lua_parse_table_arguments (lua_State *L, gint pos, gint rspamd_lua_traceback (lua_State *L); + +/** + * Returns size of table at position `tbl_pos` + */ +guint rspamd_lua_table_size (lua_State *L, gint tbl_pos); #endif /* WITH_LUA */ #endif /* RSPAMD_LUA_H */ diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c new file mode 100644 index 000000000..90c037d35 --- /dev/null +++ b/src/lua/lua_fann.c @@ -0,0 +1,435 @@ +/* + * Copyright (c) 2015, Vsevolod Stakhov + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "lua_common.h" + +#ifdef WITH_FANN +#include <fann.h> +#endif + +/*** + * @module rspamd_fann + * This module enables [fann](http://libfann.github.io) interaction in rspamd + * Please note, that this module works merely if you have `ENABLE_FANN=ON` option + * definition when building rspamd + */ + +/* + * Fann functions + */ +LUA_FUNCTION_DEF (fann, is_enabled); +LUA_FUNCTION_DEF (fann, create); +LUA_FUNCTION_DEF (fann, load); + +/* + * Fann methods + */ +LUA_FUNCTION_DEF (fann, train); +LUA_FUNCTION_DEF (fann, test); +LUA_FUNCTION_DEF (fann, save); +LUA_FUNCTION_DEF (fann, get_inputs); +LUA_FUNCTION_DEF (fann, get_outputs); +LUA_FUNCTION_DEF (fann, dtor); + +static const struct luaL_reg fannlib_f[] = { + LUA_INTERFACE_DEF (fann, is_enabled), + LUA_INTERFACE_DEF (fann, create), + LUA_INTERFACE_DEF (fann, load), + {NULL, NULL} +}; + +static const struct luaL_reg fannlib_m[] = { + LUA_INTERFACE_DEF (fann, train), + LUA_INTERFACE_DEF (fann, test), + LUA_INTERFACE_DEF (fann, save), + LUA_INTERFACE_DEF (fann, get_inputs), + LUA_INTERFACE_DEF (fann, get_outputs), + {"__gc", lua_fann_dtor}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL} +}; + +#ifdef WITH_FANN +struct fann * +rspamd_lua_check_fann (lua_State *L, gint pos) +{ + void *ud = luaL_checkudata (L, pos, "rspamd{fann}"); + luaL_argcheck (L, ud != NULL, pos, "'fann' expected"); + return ud ? *((struct fann **) ud) : NULL; +} +#endif + +/*** + * @function rspamd_fann.is_enabled() + * Checks if fann is enabled for this rspamd build + * @return {boolean} true if fann is enabled + */ +static gint +lua_fann_is_enabled (lua_State *L) +{ +#ifdef WITH_FANN + lua_pushboolean (L, true); +#else + lua_pushboolean (L, false); +#endif + return 1; +} + +/*** + * @function rspamd_fann.create(nlayers, [layer1, ... layern]) + * Creates new neural network with `nlayers` that contains `layer1`...`layern` + * neurons in each layer + * @param {number} nlayers number of layers + * @param {number} layerI number of neurons in each layer + * @return {fann} fann object + */ +static gint +lua_fann_create (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + guint nlayers, *layers, i; + + nlayers = luaL_checknumber (L, 1); + + if (nlayers > 0) { + layers = g_malloc (nlayers * sizeof (layers[0])); + + for (i = 0; i < nlayers; i ++) { + layers[i] = luaL_checknumber (L, i + 2); + } + + f = fann_create_standard_array (nlayers, layers); + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + lua_pushnil (L); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @function rspamd_fann.load(file) + * Loads neural network from the file + * @param {string} file filename where fann is stored + * @return {fann} fann object + */ +static gint +lua_fann_load (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + const gchar *fname; + + fname = luaL_checkstring (L, 1); + + if (fname != NULL) { + f = fann_create_from_file (fname); + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + lua_pushnil (L); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + + +/** + * @method rspamd_fann:train(inputs, outputs) + * Trains neural network with samples. Inputs and outputs should be tables of + * equal size, each row in table should be N inputs and M outputs, e.g. + * {0, 1, 1} -> {0} + * {1, 0, 0} -> {1} + * @param {table/table} inputs input samples + * @param {table/table} outputs output samples + * @return {number} number of samples learned + */ +static gint +lua_fann_train (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + guint ninputs, noutputs, i, j, cur_len; + float *cur_input, *cur_output; + gint ret = 0; + + if (f != NULL) { + /* First check sanity, call for table.getn for that */ + ninputs = rspamd_lua_table_size (L, 2); + noutputs = rspamd_lua_table_size (L, 3); + + if (ninputs != noutputs) { + msg_err ("bad number of inputs(%d) and output(%d) args for train", + ninputs, noutputs); + } + else { + for (i = 0; i < ninputs; i ++) { + /* Push table with inputs */ + lua_rawgeti (L, 2, i + 1); + + cur_len = rspamd_lua_table_size (L, -1); + + if (cur_len != fann_get_num_input (f)) { + msg_err ( + "bad number of input samples: %d, %d expected", + cur_len, + fann_get_num_input (f)); + lua_pop (L, 1); + continue; + } + + cur_input = g_malloc (cur_len * sizeof (gint)); + + for (j = 0; j < cur_len; j ++) { + lua_rawgeti (L, -1, j + 1); + cur_input[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); /* Inputs table */ + + /* Push table with outputs */ + lua_rawgeti (L, 3, i + 1); + + cur_len = rspamd_lua_table_size (L, -1); + + if (cur_len != fann_get_num_output (f)) { + msg_err ( + "bad number of output samples: %d, %d expected", + cur_len, + fann_get_num_output (f)); + lua_pop (L, 1); + g_free (cur_input); + continue; + } + + cur_output = g_malloc (cur_len * sizeof (gint)); + + for (j = 0; j < cur_len; j++) { + lua_rawgeti (L, -1, j + 1); + cur_output[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); /* Outputs table */ + + fann_train (f, cur_input, cur_output); + g_free (cur_input); + g_free (cur_output); + ret ++; + } + } + } + + lua_pushnumber (L, ret); + + return 1; +#endif +} + +/** + * @method rspamd_fann:test(inputs) + * Tests neural network with samples. Inputs is a single sample of input data. + * The function returns table of results, e.g.: + * {0, 1, 1} -> {0} + * @param {table} inputs input sample + * @return {table/number} outputs values + */ +static gint +lua_fann_test (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + guint ninputs, noutputs, i; + float *cur_input, *cur_output; + + if (f != NULL) { + /* First check sanity, call for table.getn for that */ + ninputs = rspamd_lua_table_size (L, 2); + cur_input = g_malloc (ninputs * sizeof (gint)); + + for (i = 0; i < ninputs; i++) { + lua_rawgeti (L, 2, i + 1); + cur_input[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + cur_output = fann_run (f, cur_input); + noutputs = fann_get_num_output (f); + lua_createtable (L, noutputs, 0); + + for (i = 0; i < noutputs; i ++) { + lua_pushnumber (L, cur_output[i]); + lua_rawseti (L, -2, i + 1); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:get_inputs() + * Returns number of inputs for neural network + * @return {number} number of inputs + */ +static gint +lua_fann_get_inputs (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f != NULL) { + lua_pushnumber (L, fann_get_num_input (f)); + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:get_outputs() + * Returns number of outputs for neural network + * @return {number} number of outputs + */ +static gint +lua_fann_get_outputs (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f != NULL) { + lua_pushnumber (L, fann_get_num_output (f)); + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:save(fname) + * Save fann to file named 'fname' + * @param {string} fname filename to save fann into + * @return {boolean} true if ann has been saved + */ +static gint +lua_fann_save (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + const gchar *fname = luaL_checkstring (L, 2); + + if (f != NULL && fname != NULL) { + if (fann_save (f, fname) == 0) { + lua_pushboolean (L, true); + } + else { + msg_err ("cannot save ANN to %s: %s", fname, strerror (errno)); + lua_pushboolean (L, false); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +static gint +lua_fann_dtor (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f) { + fann_destroy (f); + } + + return 0; +#endif +} + +static gint +lua_load_fann (lua_State * L) +{ + lua_newtable (L); + luaL_register (L, NULL, fannlib_f); + + return 1; +} + +void +luaopen_fann (lua_State * L) +{ + rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m); + lua_pop (L, 1); + + rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann); +} |