/*
 * Copyright (c) 2015, Vsevolod Stakhov
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *       * Redistributions of source code must retain the above copyright
 *         notice, this list of conditions and the following disclaimer.
 *       * Redistributions in binary form must reproduce the above copyright
 *         notice, this list of conditions and the following disclaimer in the
 *         documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
#include "lua_common.h"

#ifdef WITH_FANN
#include <fann.h>
#endif

/***
 * @module rspamd_fann
 * This module enables [fann](http://libfann.github.io) interaction in rspamd
 * Please note that this module works only if rspamd was built with the
 * `ENABLE_FANN=ON` option
 */

/*
 * Fann functions
 */
LUA_FUNCTION_DEF (fann, is_enabled);
LUA_FUNCTION_DEF (fann, create);
LUA_FUNCTION_DEF (fann, load);

/*
 * Fann methods
 */
LUA_FUNCTION_DEF (fann, train);
LUA_FUNCTION_DEF (fann, test);
LUA_FUNCTION_DEF (fann, save);
LUA_FUNCTION_DEF (fann, get_inputs);
LUA_FUNCTION_DEF (fann, get_outputs);
LUA_FUNCTION_DEF (fann, dtor);

static const struct luaL_reg fannlib_f[] = {
	LUA_INTERFACE_DEF (fann, is_enabled),
	LUA_INTERFACE_DEF (fann, create),
	LUA_INTERFACE_DEF (fann, load),
	{NULL, NULL}
};

static const struct luaL_reg fannlib_m[] = {
	LUA_INTERFACE_DEF (fann, train),
	LUA_INTERFACE_DEF (fann, test),
	LUA_INTERFACE_DEF (fann, save),
	LUA_INTERFACE_DEF (fann, get_inputs),
	LUA_INTERFACE_DEF (fann, get_outputs),
	{"__gc", lua_fann_dtor},
	{"__tostring", rspamd_lua_class_tostring},
	{NULL, NULL}
};

#ifdef WITH_FANN
/*
 * Returns the fann network stored in the `rspamd{fann}` userdata at stack
 * position `pos`; raises a Lua argument error if the value has another type.
 */
struct fann *
rspamd_lua_check_fann (lua_State *L, gint pos)
{
	void *ud = luaL_checkudata (L, pos, "rspamd{fann}");
	luaL_argcheck (L, ud != NULL, pos, "'fann' expected");
	return ud ? *((struct fann **) ud) : NULL;
}

/*
 * Copies elements 1..len of the Lua table at stack index `pos` into `buf`.
 * Elements that lua_tonumber cannot convert are stored as 0.
 * A relative `pos` such as -1 is fine: lua_rawgeti resolves it before pushing,
 * and the pushed value is popped again on every iteration.
 * The caller must already have checked that the table holds `len` elements.
 */
static void
lua_fann_fill_vector (lua_State *L, gint pos, fann_type *buf, guint len)
{
	guint i;

	for (i = 0; i < len; i ++) {
		lua_rawgeti (L, pos, i + 1);
		buf[i] = lua_tonumber (L, -1);
		lua_pop (L, 1);
	}
}
#endif

/***
 * @function rspamd_fann.is_enabled()
 * Checks if fann is enabled for this rspamd build
 * @return {boolean} true if fann is enabled
 */
static gint
lua_fann_is_enabled (lua_State *L)
{
#ifdef WITH_FANN
	lua_pushboolean (L, true);
#else
	lua_pushboolean (L, false);
#endif
	return 1;
}

/***
 * @function rspamd_fann.create(nlayers, [layer1, ... layern])
 * Creates new neural network with `nlayers` that contains `layer1`...`layern`
 * neurons in each layer
 * @param {number} nlayers number of layers
 * @param {number} layerI number of neurons in each layer
 * @return {fann} fann object or nil if `nlayers` is invalid or creation failed
 */
static gint
lua_fann_create (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f, **pfann;
	lua_Number nlayers_arg;
	guint nlayers, *layers, i;

	nlayers_arg = luaL_checknumber (L, 1);

	/*
	 * Reject zero, negative, NaN and out-of-range counts before converting:
	 * converting such a double to an unsigned integer is undefined behaviour.
	 */
	if (!(nlayers_arg >= 1) || nlayers_arg > G_MAXUINT) {
		lua_pushnil (L);

		return 1;
	}

	nlayers = nlayers_arg;

	/*
	 * Validate every layer size before allocating anything: luaL_checknumber
	 * raises a Lua error (longjmp) on a bad argument, which would otherwise
	 * leak the `layers` array.
	 */
	for (i = 0; i < nlayers; i ++) {
		luaL_checknumber (L, i + 2);
	}

	layers = g_new (guint, nlayers);

	for (i = 0; i < nlayers; i ++) {
		layers[i] = lua_tonumber (L, i + 2);
	}

	f = fann_create_standard_array (nlayers, layers);
	/* fann takes the sizes as a const array and keeps no reference to it */
	g_free (layers);

	if (f != NULL) {
		pfann = lua_newuserdata (L, sizeof (*pfann));
		*pfann = f;
		rspamd_lua_setclass (L, "rspamd{fann}", -1);
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/***
 * @function rspamd_fann.load(file)
 * Loads neural network from the file
 * @param {string} file filename where fann is stored
 * @return {fann} fann object or nil if loading failed
 */
static gint
lua_fann_load (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f, **pfann;
	const gchar *fname;

	fname = luaL_checkstring (L, 1);

	if (fname != NULL) {
		f = fann_create_from_file (fname);

		if (f != NULL) {
			pfann = lua_newuserdata (L, sizeof (*pfann));
			*pfann = f;
			rspamd_lua_setclass (L, "rspamd{fann}", -1);
		}
		else {
			lua_pushnil (L);
		}
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/***
 * @method rspamd_fann:train(inputs, outputs)
 * Trains neural network with samples. Inputs and outputs should be tables of
 * equal size, each row in table should be N inputs and M outputs, e.g.
 * {0, 1, 1} -> {0}
 * {1, 0, 0} -> {1}
 * Rows whose length does not match the network's inputs/outputs are skipped
 * with an error message.
 * @param {table/table} inputs input samples
 * @param {table/table} outputs output samples
 * @return {number} number of samples learned
 */
static gint
lua_fann_train (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);
	guint ninputs, noutputs, num_input, num_output, cur_len, i;
	fann_type *cur_input, *cur_output;
	gint ret = 0;

	if (f != NULL) {
		/* First check sanity: both tables must have the same number of rows */
		ninputs = rspamd_lua_table_size (L, 2);
		noutputs = rspamd_lua_table_size (L, 3);

		if (ninputs != noutputs) {
			msg_err ("bad number of inputs(%d) and output(%d) args for train",
					ninputs, noutputs);
		}
		else {
			num_input = fann_get_num_input (f);
			num_output = fann_get_num_output (f);
			/* Every accepted row has the same length, so reuse the buffers */
			cur_input = g_new (fann_type, num_input);
			cur_output = g_new (fann_type, num_output);

			for (i = 0; i < ninputs; i ++) {
				/* Push table with inputs */
				lua_rawgeti (L, 2, i + 1);
				cur_len = rspamd_lua_table_size (L, -1);

				if (cur_len != num_input) {
					msg_err (
							"bad number of input samples: %d, %d expected",
							cur_len,
							num_input);
					lua_pop (L, 1);
					continue;
				}

				lua_fann_fill_vector (L, -1, cur_input, cur_len);
				lua_pop (L, 1); /* Inputs table */

				/* Push table with outputs */
				lua_rawgeti (L, 3, i + 1);
				cur_len = rspamd_lua_table_size (L, -1);

				if (cur_len != num_output) {
					msg_err (
							"bad number of output samples: %d, %d expected",
							cur_len,
							num_output);
					lua_pop (L, 1);
					continue;
				}

				lua_fann_fill_vector (L, -1, cur_output, cur_len);
				lua_pop (L, 1); /* Outputs table */

				fann_train (f, cur_input, cur_output);
				ret ++;
			}

			g_free (cur_input);
			g_free (cur_output);
		}
	}

	lua_pushnumber (L, ret);

	return 1;
#endif
}

/***
 * @method rspamd_fann:test(inputs)
 * Tests neural network with samples. Inputs is a single sample of input data.
 * The function returns table of results, e.g.:
 * {0, 1, 1} -> {0}
 * @param {table} inputs input sample; its length must equal the number of
 * network inputs
 * @return {table} output values, or nil if the input length is wrong
 */
static gint
lua_fann_test (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);
	guint ninputs, noutputs, i;
	fann_type *cur_input, *cur_output;

	if (f != NULL) {
		ninputs = rspamd_lua_table_size (L, 2);

		/*
		 * fann_run reads fann_get_num_input() values from the buffer, so a
		 * shorter table would cause a heap over-read.
		 */
		if (ninputs != fann_get_num_input (f)) {
			msg_err ("bad number of input samples: %d, %d expected",
					ninputs, fann_get_num_input (f));
			lua_pushnil (L);

			return 1;
		}

		cur_input = g_new (fann_type, ninputs);
		lua_fann_fill_vector (L, 2, cur_input, ninputs);

		/* fann_run copies the inputs into the network before returning */
		cur_output = fann_run (f, cur_input);
		g_free (cur_input);

		/* The output array belongs to the network and is not freed here */
		noutputs = fann_get_num_output (f);
		lua_createtable (L, noutputs, 0);

		for (i = 0; i < noutputs; i ++) {
			lua_pushnumber (L, cur_output[i]);
			lua_rawseti (L, -2, i + 1);
		}
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/***
 * @method rspamd_fann:get_inputs()
 * Returns number of inputs for neural network
 * @return {number} number of inputs
 */
static gint
lua_fann_get_inputs (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);

	if (f != NULL) {
		lua_pushnumber (L, fann_get_num_input (f));
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/***
 * @method rspamd_fann:get_outputs()
 * Returns number of outputs for neural network
 * @return {number} number of outputs
 */
static gint
lua_fann_get_outputs (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);

	if (f != NULL) {
		lua_pushnumber (L, fann_get_num_output (f));
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/***
 * @method rspamd_fann:save(fname)
 * Save fann to file named 'fname'
 * @param {string} fname filename to save fann into
 * @return {boolean} true if ann has been saved
 */
static gint
lua_fann_save (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);
	const gchar *fname = luaL_checkstring (L, 2);

	if (f != NULL && fname != NULL) {
		if (fann_save (f, fname) == 0) {
			lua_pushboolean (L, true);
		}
		else {
			msg_err ("cannot save ANN to %s: %s", fname, strerror (errno));
			lua_pushboolean (L, false);
		}
	}
	else {
		lua_pushnil (L);
	}

	return 1;
#endif
}

/* __gc metamethod: releases the fann network owned by the userdata */
static gint
lua_fann_dtor (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);

	if (f) {
		fann_destroy (f);
	}

	return 0;
#endif
}

/* Preload loader for `rspamd_fann`: returns a table of module functions */
static gint
lua_load_fann (lua_State * L)
{
	lua_newtable (L);
	luaL_register (L, NULL, fannlib_f);

	return 1;
}

/* Registers the `rspamd{fann}` class and the `rspamd_fann` preload entry */
void
luaopen_fann (lua_State * L)
{
	rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m);
	lua_pop (L, 1);

	rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann);
}