diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-12-21 14:46:13 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2015-12-21 14:46:13 +0000 |
commit | 9a66359e75da95ccac88039d156776d90368641b (patch) | |
tree | 28d40fa5ff5c230a9bfc004413f772188413f7d6 /src/lua/lua_fann.c | |
parent | c2ff857a70752a83e9d19184d140826fbed9f467 (diff) | |
download | rspamd-9a66359e75da95ccac88039d156776d90368641b.tar.gz rspamd-9a66359e75da95ccac88039d156776d90368641b.zip |
Add lua_fann module
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r-- | src/lua/lua_fann.c | 435 |
1 files changed, 435 insertions, 0 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c new file mode 100644 index 000000000..90c037d35 --- /dev/null +++ b/src/lua/lua_fann.c @@ -0,0 +1,435 @@ +/* + * Copyright (c) 2015, Vsevolod Stakhov + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "lua_common.h" + +#ifdef WITH_FANN +#include <fann.h> +#endif + +/*** + * @module rspamd_fann + * This module enables [fann](http://libfann.github.io) interaction in rspamd + * Please note, that this module works merely if you have `ENABLE_FANN=ON` option + * definition when building rspamd + */ + +/* + * Fann functions + */ +LUA_FUNCTION_DEF (fann, is_enabled); +LUA_FUNCTION_DEF (fann, create); +LUA_FUNCTION_DEF (fann, load); + +/* + * Fann methods + */ +LUA_FUNCTION_DEF (fann, train); +LUA_FUNCTION_DEF (fann, test); +LUA_FUNCTION_DEF (fann, save); +LUA_FUNCTION_DEF (fann, get_inputs); +LUA_FUNCTION_DEF (fann, get_outputs); +LUA_FUNCTION_DEF (fann, dtor); + +static const struct luaL_reg fannlib_f[] = { + LUA_INTERFACE_DEF (fann, is_enabled), + LUA_INTERFACE_DEF (fann, create), + LUA_INTERFACE_DEF (fann, load), + {NULL, NULL} +}; + +static const struct luaL_reg fannlib_m[] = { + LUA_INTERFACE_DEF (fann, train), + LUA_INTERFACE_DEF (fann, test), + LUA_INTERFACE_DEF (fann, save), + LUA_INTERFACE_DEF (fann, get_inputs), + LUA_INTERFACE_DEF (fann, get_outputs), + {"__gc", lua_fann_dtor}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL} +}; + +#ifdef WITH_FANN +struct fann * +rspamd_lua_check_fann (lua_State *L, gint pos) +{ + void *ud = luaL_checkudata (L, pos, "rspamd{fann}"); + luaL_argcheck (L, ud != NULL, pos, "'fann' expected"); + return ud ? *((struct fann **) ud) : NULL; +} +#endif + +/*** + * @function rspamd_fann.is_enabled() + * Checks if fann is enabled for this rspamd build + * @return {boolean} true if fann is enabled + */ +static gint +lua_fann_is_enabled (lua_State *L) +{ +#ifdef WITH_FANN + lua_pushboolean (L, true); +#else + lua_pushboolean (L, false); +#endif + return 1; +} + +/*** + * @function rspamd_fann.create(nlayers, [layer1, ... layern]) + * Creates new neural network with `nlayers` that contains `layer1`...`layern` + * neurons in each layer + * @param {number} nlayers number of layers + * @param {number} layerI number of neurons in each layer + * @return {fann} fann object + */ +static gint +lua_fann_create (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + guint nlayers, *layers, i; + + nlayers = luaL_checknumber (L, 1); + + if (nlayers > 0) { + layers = g_malloc (nlayers * sizeof (layers[0])); + + for (i = 0; i < nlayers; i ++) { + layers[i] = luaL_checknumber (L, i + 2); + } + + f = fann_create_standard_array (nlayers, layers); + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + lua_pushnil (L); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @function rspamd_fann.load(file) + * Loads neural network from the file + * @param {string} file filename where fann is stored + * @return {fann} fann object + */ +static gint +lua_fann_load (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f, **pfann; + const gchar *fname; + + fname = luaL_checkstring (L, 1); + + if (fname != NULL) { + f = fann_create_from_file (fname); + + if (f != NULL) { + pfann = lua_newuserdata (L, sizeof (gpointer)); + *pfann = f; + rspamd_lua_setclass (L, "rspamd{fann}", -1); + } + else { + lua_pushnil (L); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + + +/** + * @method rspamd_fann:train(inputs, outputs) + * Trains neural network with samples. Inputs and outputs should be tables of + * equal size, each row in table should be N inputs and M outputs, e.g. + * {0, 1, 1} -> {0} + * {1, 0, 0} -> {1} + * @param {table/table} inputs input samples + * @param {table/table} outputs output samples + * @return {number} number of samples learned + */ +static gint +lua_fann_train (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + guint ninputs, noutputs, i, j, cur_len; + float *cur_input, *cur_output; + gint ret = 0; + + if (f != NULL) { + /* First check sanity, call for table.getn for that */ + ninputs = rspamd_lua_table_size (L, 2); + noutputs = rspamd_lua_table_size (L, 3); + + if (ninputs != noutputs) { + msg_err ("bad number of inputs(%d) and output(%d) args for train", + ninputs, noutputs); + } + else { + for (i = 0; i < ninputs; i ++) { + /* Push table with inputs */ + lua_rawgeti (L, 2, i + 1); + + cur_len = rspamd_lua_table_size (L, -1); + + if (cur_len != fann_get_num_input (f)) { + msg_err ( + "bad number of input samples: %d, %d expected", + cur_len, + fann_get_num_input (f)); + lua_pop (L, 1); + continue; + } + + cur_input = g_malloc (cur_len * sizeof (gint)); + + for (j = 0; j < cur_len; j ++) { + lua_rawgeti (L, -1, j + 1); + cur_input[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); /* Inputs table */ + + /* Push table with outputs */ + lua_rawgeti (L, 3, i + 1); + + cur_len = rspamd_lua_table_size (L, -1); + + if (cur_len != fann_get_num_output (f)) { + msg_err ( + "bad number of output samples: %d, %d expected", + cur_len, + fann_get_num_output (f)); + lua_pop (L, 1); + g_free (cur_input); + continue; + } + + cur_output = g_malloc (cur_len * sizeof (gint)); + + for (j = 0; j < cur_len; j++) { + lua_rawgeti (L, -1, j + 1); + cur_output[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + lua_pop (L, 1); /* Outputs table */ + + fann_train (f, cur_input, cur_output); + g_free (cur_input); + g_free (cur_output); + ret ++; + } + } + } + + lua_pushnumber (L, ret); + + return 1; +#endif +} + +/** + * @method rspamd_fann:test(inputs) + * Tests neural network with samples. Inputs is a single sample of input data. + * The function returns table of results, e.g.: + * {0, 1, 1} -> {0} + * @param {table} inputs input sample + * @return {table/number} outputs values + */ +static gint +lua_fann_test (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + guint ninputs, noutputs, i; + float *cur_input, *cur_output; + + if (f != NULL) { + /* First check sanity, call for table.getn for that */ + ninputs = rspamd_lua_table_size (L, 2); + cur_input = g_malloc (ninputs * sizeof (gint)); + + for (i = 0; i < ninputs; i++) { + lua_rawgeti (L, 2, i + 1); + cur_input[i] = lua_tonumber (L, -1); + lua_pop (L, 1); + } + + cur_output = fann_run (f, cur_input); + noutputs = fann_get_num_output (f); + lua_createtable (L, noutputs, 0); + + for (i = 0; i < noutputs; i ++) { + lua_pushnumber (L, cur_output[i]); + lua_rawseti (L, -2, i + 1); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:get_inputs() + * Returns number of inputs for neural network + * @return {number} number of inputs + */ +static gint +lua_fann_get_inputs (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f != NULL) { + lua_pushnumber (L, fann_get_num_input (f)); + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:get_outputs() + * Returns number of outputs for neural network + * @return {number} number of outputs + */ +static gint +lua_fann_get_outputs (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f != NULL) { + lua_pushnumber (L, fann_get_num_output (f)); + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +/*** + * @method rspamd_fann:save(fname) + * Save fann to file named 'fname' + * @param {string} fname filename to save fann into + * @return {boolean} true if ann has been saved + */ +static gint +lua_fann_save (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + const gchar *fname = luaL_checkstring (L, 2); + + if (f != NULL && fname != NULL) { + if (fann_save (f, fname) == 0) { + lua_pushboolean (L, true); + } + else { + msg_err ("cannot save ANN to %s: %s", fname, strerror (errno)); + lua_pushboolean (L, false); + } + } + else { + lua_pushnil (L); + } + + return 1; +#endif +} + +static gint +lua_fann_dtor (lua_State *L) +{ +#ifndef WITH_FANN + return 0; +#else + struct fann *f = rspamd_lua_check_fann (L, 1); + + if (f) { + fann_destroy (f); + } + + return 0; +#endif +} + +static gint +lua_load_fann (lua_State * L) +{ + lua_newtable (L); + luaL_register (L, NULL, fannlib_f); + + return 1; +} + +void +luaopen_fann (lua_State * L) +{ + rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m); + lua_pop (L, 1); + + rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann); +} |