aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_fann.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2015-12-21 14:46:13 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2015-12-21 14:46:13 +0000
commit9a66359e75da95ccac88039d156776d90368641b (patch)
tree28d40fa5ff5c230a9bfc004413f772188413f7d6 /src/lua/lua_fann.c
parentc2ff857a70752a83e9d19184d140826fbed9f467 (diff)
downloadrspamd-9a66359e75da95ccac88039d156776d90368641b.tar.gz
rspamd-9a66359e75da95ccac88039d156776d90368641b.zip
Add lua_fann module
Diffstat (limited to 'src/lua/lua_fann.c')
-rw-r--r--src/lua/lua_fann.c435
1 files changed, 435 insertions, 0 deletions
diff --git a/src/lua/lua_fann.c b/src/lua/lua_fann.c
new file mode 100644
index 000000000..90c037d35
--- /dev/null
+++ b/src/lua/lua_fann.c
@@ -0,0 +1,435 @@
+/*
+ * Copyright (c) 2015, Vsevolod Stakhov
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "lua_common.h"
+
+#ifdef WITH_FANN
+#include <fann.h>
+#endif
+
+/***
+ * @module rspamd_fann
+ * This module enables [fann](http://libfann.github.io) interaction in rspamd
+ * Please note, that this module works merely if you have `ENABLE_FANN=ON` option
+ * definition when building rspamd
+ */
+
+/*
+ * Fann functions
+ */
+LUA_FUNCTION_DEF (fann, is_enabled);
+LUA_FUNCTION_DEF (fann, create);
+LUA_FUNCTION_DEF (fann, load);
+
+/*
+ * Fann methods
+ */
+LUA_FUNCTION_DEF (fann, train);
+LUA_FUNCTION_DEF (fann, test);
+LUA_FUNCTION_DEF (fann, save);
+LUA_FUNCTION_DEF (fann, get_inputs);
+LUA_FUNCTION_DEF (fann, get_outputs);
+LUA_FUNCTION_DEF (fann, dtor);
+
+static const struct luaL_reg fannlib_f[] = {
+ LUA_INTERFACE_DEF (fann, is_enabled),
+ LUA_INTERFACE_DEF (fann, create),
+ LUA_INTERFACE_DEF (fann, load),
+ {NULL, NULL}
+};
+
+static const struct luaL_reg fannlib_m[] = {
+ LUA_INTERFACE_DEF (fann, train),
+ LUA_INTERFACE_DEF (fann, test),
+ LUA_INTERFACE_DEF (fann, save),
+ LUA_INTERFACE_DEF (fann, get_inputs),
+ LUA_INTERFACE_DEF (fann, get_outputs),
+ {"__gc", lua_fann_dtor},
+ {"__tostring", rspamd_lua_class_tostring},
+ {NULL, NULL}
+};
+
+#ifdef WITH_FANN
+struct fann *
+rspamd_lua_check_fann (lua_State *L, gint pos)
+{
+ void *ud = luaL_checkudata (L, pos, "rspamd{fann}");
+ luaL_argcheck (L, ud != NULL, pos, "'fann' expected");
+ return ud ? *((struct fann **) ud) : NULL;
+}
+#endif
+
+/***
+ * @function rspamd_fann.is_enabled()
+ * Checks if fann is enabled for this rspamd build
+ * @return {boolean} true if fann is enabled
+ */
+static gint
+lua_fann_is_enabled (lua_State *L)
+{
+#ifdef WITH_FANN
+ lua_pushboolean (L, true);
+#else
+ lua_pushboolean (L, false);
+#endif
+ return 1;
+}
+
+/***
+ * @function rspamd_fann.create(nlayers, [layer1, ... layern])
+ * Creates new neural network with `nlayers` that contains `layer1`...`layern`
+ * neurons in each layer
+ * @param {number} nlayers number of layers
+ * @param {number} layerI number of neurons in each layer
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_create (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ guint nlayers, *layers, i;
+
+ nlayers = luaL_checknumber (L, 1);
+
+ if (nlayers > 0) {
+ layers = g_malloc (nlayers * sizeof (layers[0]));
+
+ for (i = 0; i < nlayers; i ++) {
+ layers[i] = luaL_checknumber (L, i + 2);
+ }
+
+ f = fann_create_standard_array (nlayers, layers);
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @function rspamd_fann.load(file)
+ * Loads neural network from the file
+ * @param {string} file filename where fann is stored
+ * @return {fann} fann object
+ */
+static gint
+lua_fann_load (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f, **pfann;
+ const gchar *fname;
+
+ fname = luaL_checkstring (L, 1);
+
+ if (fname != NULL) {
+ f = fann_create_from_file (fname);
+
+ if (f != NULL) {
+ pfann = lua_newuserdata (L, sizeof (gpointer));
+ *pfann = f;
+ rspamd_lua_setclass (L, "rspamd{fann}", -1);
+ }
+ else {
+ lua_pushnil (L);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+
+/**
+ * @method rspamd_fann:train(inputs, outputs)
+ * Trains neural network with samples. Inputs and outputs should be tables of
+ * equal size, each row in table should be N inputs and M outputs, e.g.
+ * {0, 1, 1} -> {0}
+ * {1, 0, 0} -> {1}
+ * @param {table/table} inputs input samples
+ * @param {table/table} outputs output samples
+ * @return {number} number of samples learned
+ */
+static gint
+lua_fann_train (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ guint ninputs, noutputs, i, j, cur_len;
+ float *cur_input, *cur_output;
+ gint ret = 0;
+
+ if (f != NULL) {
+ /* First check sanity, call for table.getn for that */
+ ninputs = rspamd_lua_table_size (L, 2);
+ noutputs = rspamd_lua_table_size (L, 3);
+
+ if (ninputs != noutputs) {
+ msg_err ("bad number of inputs(%d) and output(%d) args for train",
+ ninputs, noutputs);
+ }
+ else {
+ for (i = 0; i < ninputs; i ++) {
+ /* Push table with inputs */
+ lua_rawgeti (L, 2, i + 1);
+
+ cur_len = rspamd_lua_table_size (L, -1);
+
+ if (cur_len != fann_get_num_input (f)) {
+ msg_err (
+ "bad number of input samples: %d, %d expected",
+ cur_len,
+ fann_get_num_input (f));
+ lua_pop (L, 1);
+ continue;
+ }
+
+ cur_input = g_malloc (cur_len * sizeof (gint));
+
+ for (j = 0; j < cur_len; j ++) {
+ lua_rawgeti (L, -1, j + 1);
+ cur_input[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1); /* Inputs table */
+
+ /* Push table with outputs */
+ lua_rawgeti (L, 3, i + 1);
+
+ cur_len = rspamd_lua_table_size (L, -1);
+
+ if (cur_len != fann_get_num_output (f)) {
+ msg_err (
+ "bad number of output samples: %d, %d expected",
+ cur_len,
+ fann_get_num_output (f));
+ lua_pop (L, 1);
+ g_free (cur_input);
+ continue;
+ }
+
+ cur_output = g_malloc (cur_len * sizeof (gint));
+
+ for (j = 0; j < cur_len; j++) {
+ lua_rawgeti (L, -1, j + 1);
+ cur_output[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1); /* Outputs table */
+
+ fann_train (f, cur_input, cur_output);
+ g_free (cur_input);
+ g_free (cur_output);
+ ret ++;
+ }
+ }
+ }
+
+ lua_pushnumber (L, ret);
+
+ return 1;
+#endif
+}
+
+/**
+ * @method rspamd_fann:test(inputs)
+ * Tests neural network with samples. Inputs is a single sample of input data.
+ * The function returns table of results, e.g.:
+ * {0, 1, 1} -> {0}
+ * @param {table} inputs input sample
+ * @return {table/number} outputs values
+ */
+static gint
+lua_fann_test (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ guint ninputs, noutputs, i;
+ float *cur_input, *cur_output;
+
+ if (f != NULL) {
+ /* First check sanity, call for table.getn for that */
+ ninputs = rspamd_lua_table_size (L, 2);
+ cur_input = g_malloc (ninputs * sizeof (gint));
+
+ for (i = 0; i < ninputs; i++) {
+ lua_rawgeti (L, 2, i + 1);
+ cur_input[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ cur_output = fann_run (f, cur_input);
+ noutputs = fann_get_num_output (f);
+ lua_createtable (L, noutputs, 0);
+
+ for (i = 0; i < noutputs; i ++) {
+ lua_pushnumber (L, cur_output[i]);
+ lua_rawseti (L, -2, i + 1);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_inputs()
+ * Returns number of inputs for neural network
+ * @return {number} number of inputs
+ */
+static gint
+lua_fann_get_inputs (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f != NULL) {
+ lua_pushnumber (L, fann_get_num_input (f));
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:get_outputs()
+ * Returns number of outputs for neural network
+ * @return {number} number of outputs
+ */
+static gint
+lua_fann_get_outputs (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f != NULL) {
+ lua_pushnumber (L, fann_get_num_output (f));
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+/***
+ * @method rspamd_fann:save(fname)
+ * Save fann to file named 'fname'
+ * @param {string} fname filename to save fann into
+ * @return {boolean} true if ann has been saved
+ */
+static gint
+lua_fann_save (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+ const gchar *fname = luaL_checkstring (L, 2);
+
+ if (f != NULL && fname != NULL) {
+ if (fann_save (f, fname) == 0) {
+ lua_pushboolean (L, true);
+ }
+ else {
+ msg_err ("cannot save ANN to %s: %s", fname, strerror (errno));
+ lua_pushboolean (L, false);
+ }
+ }
+ else {
+ lua_pushnil (L);
+ }
+
+ return 1;
+#endif
+}
+
+static gint
+lua_fann_dtor (lua_State *L)
+{
+#ifndef WITH_FANN
+ return 0;
+#else
+ struct fann *f = rspamd_lua_check_fann (L, 1);
+
+ if (f) {
+ fann_destroy (f);
+ }
+
+ return 0;
+#endif
+}
+
+static gint
+lua_load_fann (lua_State * L)
+{
+ lua_newtable (L);
+ luaL_register (L, NULL, fannlib_f);
+
+ return 1;
+}
+
+void
+luaopen_fann (lua_State * L)
+{
+ rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m);
+ lua_pop (L, 1);
+
+ rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann);
+}