Przeglądaj źródła

Add lua_fann module

tags/1.1.0
Vsevolod Stakhov 8 lat temu
rodzic
commit
9a66359e75
6 zmienionych plików z 470 dodań i 7 usunięć
  1. 6
    4
      doc/Makefile
  2. 1
    1
      src/CMakeLists.txt
  3. 3
    2
      src/lua/CMakeLists.txt
  4. 19
    0
      src/lua/lua_common.c
  5. 6
    0
      src/lua/lua_common.h
  6. 435
    0
      src/lua/lua_fann.c

+ 6
- 4
doc/Makefile Wyświetl plik

@@ -8,15 +8,15 @@ all: man
man: rspamd.8 rspamc.1 rspamadm.1

rspamd.8: rspamd.8.md
$(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md
$(PANDOC) -s -f markdown -t man -o rspamd.8 rspamd.8.md
rspamc.1: rspamc.1.md
$(PANDOC) -s -f markdown -t man -o rspamc.1 rspamc.1.md
rspamadm.1: rspamadm.1.md
$(PANDOC) -s -f markdown -t man -o rspamadm.1 rspamadm.1.md
lua-doc: lua_regexp lua_ip lua_config lua_task lua_ucl lua_http lua_trie \
lua_dns lua_redis lua_upstream lua_expression lua_mimepart lua_logger lua_url \
lua_tcp lua_mempool lua_html lua_util
lua_tcp lua_mempool lua_html lua_util lua_fann

lua_regexp: ../src/lua/lua_regexp.c
$(LUADOC) < ../src/lua/lua_regexp.c > markdown/lua/regexp.md
@@ -53,4 +53,6 @@ lua_mempool: ../src/lua/lua_mempool.c
lua_html: ../src/lua/lua_html.c
$(LUADOC) < ../src/lua/lua_html.c > markdown/lua/html.md
lua_util: ../src/lua/lua_util.c
$(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md
$(LUADOC) < ../src/lua/lua_util.c > markdown/lua/util.md
lua_fann: ../src/lua/lua_fann.c
$(LUADOC) < ../src/lua/lua_fann.c > markdown/lua/fann.md

+ 1
- 1
src/CMakeLists.txt Wyświetl plik

@@ -90,7 +90,7 @@ SET(PLUGINSSRC plugins/surbl.c
plugins/fuzzy_check.c
plugins/spf.c
plugins/dkim_check.c
libserver/rspamd_control.c)
libserver/rspamd_control.c lua/lua_fann.c)

SET(MODULES_LIST surbl regexp chartable fuzzy_check spf dkim)
SET(WORKERS_LIST normal controller smtp_proxy fuzzy lua http_proxy)

+ 3
- 2
src/lua/CMakeLists.txt Wyświetl plik

@@ -23,6 +23,7 @@ SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_url.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_util.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_tcp.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c)
${CMAKE_CURRENT_SOURCE_DIR}/lua_html.c
${CMAKE_CURRENT_SOURCE_DIR}/lua_fann.c)

SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)

+ 19
- 0
src/lua/lua_common.c Wyświetl plik

@@ -247,6 +247,7 @@ rspamd_lua_init ()
luaopen_util (L);
luaopen_tcp (L);
luaopen_html (L);
luaopen_fann (L);

rspamd_lua_add_preload (L, "ucl", luaopen_ucl);

@@ -944,3 +945,21 @@ rspamd_lua_traceback (lua_State *L)

return 1;
}

guint
rspamd_lua_table_size (lua_State *L, gint tbl_pos)
{
guint tbl_size = 0;

if (!lua_istable (L, tbl_pos)) {
return 0;
}

#if LUA_VERSION_NUM >= 502
tbl_size = lua_rawlen (L, tbl_pos);
#else
tbl_size = lua_objlen (L, tbl_pos);
#endif

return tbl_size;
}

+ 6
- 0
src/lua/lua_common.h Wyświetl plik

@@ -222,6 +222,7 @@ void luaopen_text (lua_State *L);
void luaopen_util (lua_State * L);
void luaopen_tcp (lua_State * L);
void luaopen_html (lua_State * L);
void luaopen_fann (lua_State *L);

gint rspamd_lua_call_filter (const gchar *function, struct rspamd_task *task);
gint rspamd_lua_call_chain_filter (const gchar *function,
@@ -289,5 +290,10 @@ gboolean rspamd_lua_parse_table_arguments (lua_State *L, gint pos,


gint rspamd_lua_traceback (lua_State *L);

/**
* Returns size of table at position `tbl_pos`
*/
guint rspamd_lua_table_size (lua_State *L, gint tbl_pos);
#endif /* WITH_LUA */
#endif /* RSPAMD_LUA_H */

+ 435
- 0
src/lua/lua_fann.c Wyświetl plik

@@ -0,0 +1,435 @@
/*
* Copyright (c) 2015, Vsevolod Stakhov
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#include "lua_common.h"

#ifdef WITH_FANN
#include <fann.h>
#endif

/***
* @module rspamd_fann
* This module enables [fann](http://libfann.github.io) interaction in rspamd
* Please note, that this module works merely if you have `ENABLE_FANN=ON` option
* definition when building rspamd
*/

/*
* Fann functions
*/
LUA_FUNCTION_DEF (fann, is_enabled);
LUA_FUNCTION_DEF (fann, create);
LUA_FUNCTION_DEF (fann, load);

/*
* Fann methods
*/
LUA_FUNCTION_DEF (fann, train);
LUA_FUNCTION_DEF (fann, test);
LUA_FUNCTION_DEF (fann, save);
LUA_FUNCTION_DEF (fann, get_inputs);
LUA_FUNCTION_DEF (fann, get_outputs);
LUA_FUNCTION_DEF (fann, dtor);

static const struct luaL_reg fannlib_f[] = {
LUA_INTERFACE_DEF (fann, is_enabled),
LUA_INTERFACE_DEF (fann, create),
LUA_INTERFACE_DEF (fann, load),
{NULL, NULL}
};

static const struct luaL_reg fannlib_m[] = {
LUA_INTERFACE_DEF (fann, train),
LUA_INTERFACE_DEF (fann, test),
LUA_INTERFACE_DEF (fann, save),
LUA_INTERFACE_DEF (fann, get_inputs),
LUA_INTERFACE_DEF (fann, get_outputs),
{"__gc", lua_fann_dtor},
{"__tostring", rspamd_lua_class_tostring},
{NULL, NULL}
};

#ifdef WITH_FANN
struct fann *
rspamd_lua_check_fann (lua_State *L, gint pos)
{
void *ud = luaL_checkudata (L, pos, "rspamd{fann}");
luaL_argcheck (L, ud != NULL, pos, "'fann' expected");
return ud ? *((struct fann **) ud) : NULL;
}
#endif

/***
* @function rspamd_fann.is_enabled()
* Checks if fann is enabled for this rspamd build
* @return {boolean} true if fann is enabled
*/
static gint
lua_fann_is_enabled (lua_State *L)
{
#ifdef WITH_FANN
lua_pushboolean (L, true);
#else
lua_pushboolean (L, false);
#endif
return 1;
}

/***
* @function rspamd_fann.create(nlayers, [layer1, ... layern])
* Creates new neural network with `nlayers` that contains `layer1`...`layern`
* neurons in each layer
* @param {number} nlayers number of layers
* @param {number} layerI number of neurons in each layer
* @return {fann} fann object
*/
static gint
lua_fann_create (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f, **pfann;
guint nlayers, *layers, i;

nlayers = luaL_checknumber (L, 1);

if (nlayers > 0) {
layers = g_malloc (nlayers * sizeof (layers[0]));

for (i = 0; i < nlayers; i ++) {
layers[i] = luaL_checknumber (L, i + 2);
}

f = fann_create_standard_array (nlayers, layers);

if (f != NULL) {
pfann = lua_newuserdata (L, sizeof (gpointer));
*pfann = f;
rspamd_lua_setclass (L, "rspamd{fann}", -1);
}
else {
lua_pushnil (L);
}
}
else {
lua_pushnil (L);
}

return 1;
#endif
}

/***
* @function rspamd_fann.load(file)
* Loads neural network from the file
* @param {string} file filename where fann is stored
* @return {fann} fann object
*/
static gint
lua_fann_load (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f, **pfann;
const gchar *fname;

fname = luaL_checkstring (L, 1);

if (fname != NULL) {
f = fann_create_from_file (fname);

if (f != NULL) {
pfann = lua_newuserdata (L, sizeof (gpointer));
*pfann = f;
rspamd_lua_setclass (L, "rspamd{fann}", -1);
}
else {
lua_pushnil (L);
}
}
else {
lua_pushnil (L);
}

return 1;
#endif
}


/**
* @method rspamd_fann:train(inputs, outputs)
* Trains neural network with samples. Inputs and outputs should be tables of
* equal size, each row in table should be N inputs and M outputs, e.g.
* {0, 1, 1} -> {0}
* {1, 0, 0} -> {1}
* @param {table/table} inputs input samples
* @param {table/table} outputs output samples
* @return {number} number of samples learned
*/
static gint
lua_fann_train (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);
guint ninputs, noutputs, i, j, cur_len;
float *cur_input, *cur_output;
gint ret = 0;

if (f != NULL) {
/* First check sanity, call for table.getn for that */
ninputs = rspamd_lua_table_size (L, 2);
noutputs = rspamd_lua_table_size (L, 3);

if (ninputs != noutputs) {
msg_err ("bad number of inputs(%d) and output(%d) args for train",
ninputs, noutputs);
}
else {
for (i = 0; i < ninputs; i ++) {
/* Push table with inputs */
lua_rawgeti (L, 2, i + 1);

cur_len = rspamd_lua_table_size (L, -1);

if (cur_len != fann_get_num_input (f)) {
msg_err (
"bad number of input samples: %d, %d expected",
cur_len,
fann_get_num_input (f));
lua_pop (L, 1);
continue;
}

cur_input = g_malloc (cur_len * sizeof (gint));

for (j = 0; j < cur_len; j ++) {
lua_rawgeti (L, -1, j + 1);
cur_input[i] = lua_tonumber (L, -1);
lua_pop (L, 1);
}

lua_pop (L, 1); /* Inputs table */

/* Push table with outputs */
lua_rawgeti (L, 3, i + 1);

cur_len = rspamd_lua_table_size (L, -1);

if (cur_len != fann_get_num_output (f)) {
msg_err (
"bad number of output samples: %d, %d expected",
cur_len,
fann_get_num_output (f));
lua_pop (L, 1);
g_free (cur_input);
continue;
}

cur_output = g_malloc (cur_len * sizeof (gint));

for (j = 0; j < cur_len; j++) {
lua_rawgeti (L, -1, j + 1);
cur_output[i] = lua_tonumber (L, -1);
lua_pop (L, 1);
}

lua_pop (L, 1); /* Outputs table */

fann_train (f, cur_input, cur_output);
g_free (cur_input);
g_free (cur_output);
ret ++;
}
}
}

lua_pushnumber (L, ret);

return 1;
#endif
}

/**
* @method rspamd_fann:test(inputs)
* Tests neural network with samples. Inputs is a single sample of input data.
* The function returns table of results, e.g.:
* {0, 1, 1} -> {0}
* @param {table} inputs input sample
* @return {table/number} outputs values
*/
static gint
lua_fann_test (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);
guint ninputs, noutputs, i;
float *cur_input, *cur_output;

if (f != NULL) {
/* First check sanity, call for table.getn for that */
ninputs = rspamd_lua_table_size (L, 2);
cur_input = g_malloc (ninputs * sizeof (gint));

for (i = 0; i < ninputs; i++) {
lua_rawgeti (L, 2, i + 1);
cur_input[i] = lua_tonumber (L, -1);
lua_pop (L, 1);
}

cur_output = fann_run (f, cur_input);
noutputs = fann_get_num_output (f);
lua_createtable (L, noutputs, 0);

for (i = 0; i < noutputs; i ++) {
lua_pushnumber (L, cur_output[i]);
lua_rawseti (L, -2, i + 1);
}
}
else {
lua_pushnil (L);
}

return 1;
#endif
}

/***
* @method rspamd_fann:get_inputs()
* Returns number of inputs for neural network
* @return {number} number of inputs
*/
static gint
lua_fann_get_inputs (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);

if (f != NULL) {
lua_pushnumber (L, fann_get_num_input (f));
}
else {
lua_pushnil (L);
}

return 1;
#endif
}

/***
* @method rspamd_fann:get_outputs()
* Returns number of outputs for neural network
* @return {number} number of outputs
*/
static gint
lua_fann_get_outputs (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);

if (f != NULL) {
lua_pushnumber (L, fann_get_num_output (f));
}
else {
lua_pushnil (L);
}

return 1;
#endif
}

/***
* @method rspamd_fann:save(fname)
* Save fann to file named 'fname'
* @param {string} fname filename to save fann into
* @return {boolean} true if ann has been saved
*/
static gint
lua_fann_save (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);
const gchar *fname = luaL_checkstring (L, 2);

if (f != NULL && fname != NULL) {
if (fann_save (f, fname) == 0) {
lua_pushboolean (L, true);
}
else {
msg_err ("cannot save ANN to %s: %s", fname, strerror (errno));
lua_pushboolean (L, false);
}
}
else {
lua_pushnil (L);
}

return 1;
#endif
}

static gint
lua_fann_dtor (lua_State *L)
{
#ifndef WITH_FANN
return 0;
#else
struct fann *f = rspamd_lua_check_fann (L, 1);

if (f) {
fann_destroy (f);
}

return 0;
#endif
}

static gint
lua_load_fann (lua_State * L)
{
lua_newtable (L);
luaL_register (L, NULL, fannlib_f);

return 1;
}

void
luaopen_fann (lua_State * L)
{
rspamd_lua_new_class (L, "rspamd{fann}", fannlib_m);
lua_pop (L, 1);

rspamd_lua_add_preload (L, "rspamd_fann", lua_load_fann);
}

Ładowanie…
Anuluj
Zapisz