aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-09-02 20:02:50 +0100
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-09-02 20:02:50 +0100
commit 905820d0d3f92b0861189441f03a908b0ea14a55 (patch)
tree d33377d956fdc343a74a2f05a038e3acf22c8003
parent 7d236456e6bc91fe10eaa597bc1dfe3d141a0ccd (diff)
download rspamd-905820d0d3f92b0861189441f03a908b0ea14a55.tar.gz
rspamd-905820d0d3f92b0861189441f03a908b0ea14a55.zip
[Rework] Start moving of fann redis to torch
-rw-r--r-- src/lua/lua_config.c 2
-rw-r--r-- src/lua/lua_task.c 17
-rw-r--r-- src/plugins/lua/fann_redis.lua 129
3 files changed, 108 insertions, 40 deletions
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c
index 21cc4bbe7..6dfd45e12 100644
--- a/src/lua/lua_config.c
+++ b/src/lua/lua_config.c
@@ -694,7 +694,7 @@ static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, add_example),
LUA_INTERFACE_DEF (config, set_peak_cb),
LUA_INTERFACE_DEF (config, get_cpu_flags),
- LUA_INTERFACE_DEF (config, get_cpu_flags),
+ LUA_INTERFACE_DEF (config, has_torch),
{"__tostring", rspamd_lua_class_tostring},
{"__newindex", lua_config_newindex},
{NULL, NULL}
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index 61ef6d075..9d139044f 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -3134,6 +3134,7 @@ struct tokens_foreach_cbdata {
struct rspamd_task *task;
lua_State *L;
gint idx;
+ gboolean normalize;
};
static void
@@ -3150,7 +3151,12 @@ tokens_foreach_cb (gint id, const gchar *sym, gint flags, gpointer ud)
mres = cbd->task->result;
if (mres && (s = g_hash_table_lookup (mres->symbols, sym)) != NULL) {
- lua_pushnumber (cbd->L, tanh (s->score));
+ if (cbd->normalize) {
+ lua_pushnumber (cbd->L, tanh (s->score));
+ }
+ else {
+ lua_pushnumber (cbd->L, s->score);
+ }
}
else {
lua_pushnumber (cbd->L, 0.0);
@@ -3168,6 +3174,15 @@ lua_task_get_symbols_tokens (lua_State *L)
cbd.task = task;
cbd.L = L;
cbd.idx = 1;
+ cbd.normalize = TRUE;
+
+ if (lua_type (L, 2) == LUA_TBOOLEAN) {
+ cbd.normalize = lua_toboolean (L, 2);
+ }
+ else {
+ cbd.normalize = TRUE;
+ }
+
lua_createtable (L, rspamd_symbols_cache_symbols_count (task->cfg->cache), 0);
rspamd_symbols_cache_foreach (task->cfg->cache, tokens_foreach_cb, &cbd);
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 531a740d9..b0cbdefab 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -27,6 +27,15 @@ local rspamd_util = require "rspamd_util"
local rspamd_redis = require "lua_redis"
local fun = require "fun"
local meta_functions = require "meta_functions"
+local use_torch = false
+local torch
+local nn
+
+if rspamd_config:has_torch() then
+ use_torch = true
+ torch = require "torch"
+ nn = require "nn"
+end
-- Module vars
local default_options = {
@@ -298,11 +307,15 @@ local function gen_fann_prefix(rule, id)
local cksum = rspamd_config:get_symbols_cksum():hex()
-- We also need to count metatokens:
local n = meta_functions.rspamd_count_metatokens()
+ local tprefix = ''
+ if use_torch then
+ tprefix = 't';
+ end
if id then
- return string.format('%s%s%d%s', rule.prefix, cksum, n, id),
+ return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id),
rule.prefix .. id
else
- return string.format('%s%s%d', rule.prefix, cksum, n), nil
+ return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
end
end
@@ -311,20 +324,36 @@ local function is_fann_valid(rule, prefix, ann)
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
- if n ~= ann:get_inputs() then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', prefix, ann:get_inputs(), n)
- return false
- end
- local layers = ann:get_layers()
+ if torch then
+ local nlayers = #ann
+ if nlayers ~= rule.nlayers then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+ prefix, nlayers)
+ return false
+ end
- if not layers or #layers ~= rule.nlayers then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
- prefix, #layers)
- return false
- end
+ local inp = ann:get(1):nElement()
+ if n ~= inp then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', prefix, inp, n)
+ return false
+ end
+ else
+ if n ~= ann:get_inputs() then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+ ' is found in the cache', prefix, ann:get_inputs(), n)
+ return false
+ end
+ local layers = ann:get_layers()
- return true
+ if not layers or #layers ~= rule.nlayers then
+ rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+ prefix, #layers)
+ return false
+ end
+
+ return true
+ end
end
end
@@ -348,15 +377,23 @@ local function fann_scores_filter(task)
-- Add filtered meta tokens
fun.each(function(e) table.insert(fann_data, e) end, mt)
- local out = fanns[id].fann:test(fann_data)
- local symscore = string.format('%.3f', out[1])
+ local score
+ if torch then
+ local out = fanns[id].fann:forward(torch.Tensor(fann_data))
+ score = out[1]
+ else
+ local out = fanns[id].fann:test(fann_data)
+ score = out[1]
+ end
+
+ local symscore = string.format('%.3f', score)
rspamd_logger.infox(task, 'fann score: %s', symscore)
- if out[1] > 0 then
- local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+ if score > 0 then
+ local result = rspamd_util.normalize_prob(score / 2.0, 0)
task:insert_result(rule.symbol_spam, result, symscore, id)
else
- local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+ local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
task:insert_result(rule.symbol_ham, result, symscore, id)
end
end
@@ -364,14 +401,25 @@ local function fann_scores_filter(task)
end
local function create_fann(n, nlayers)
- local layers = {}
- local div = 1.0
- for _ = 1, nlayers - 1 do
- table.insert(layers, math.floor(n / div))
- div = div * 2
- end
- table.insert(layers, 1)
- return rspamd_fann.create(nlayers, layers)
+ if torch then
+ -- We ignore number of layers so far when using torch
+ local ann = nn.Sequential()
+ local nhidden = math.floor((n + 1) / 2)
+ ann:add(nn.Linear(n, nhidden))
+ ann:add(nn.PReLU())
+ ann:add(nn.Linear(nhidden, 1))
+
+ return ann
+ else
+ local layers = {}
+ local div = 1.0
+ for _ = 1, nlayers - 1 do
+ table.insert(layers, math.floor(n / div))
+ div = div * 2
+ end
+ table.insert(layers, 1)
+ return rspamd_fann.create(nlayers, layers)
+ end
end
local function create_train_fann(rule, n, id)
@@ -382,13 +430,7 @@ local function create_train_fann(rule, n, id)
end
-- Fix that for flexibe layers number
if fanns[id].fann then
- if n ~= fanns[id].fann:get_inputs() or --
- (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
- rspamd_logger.infox(rspamd_config,
- 'recreate ANN %s as it has a wrong number of inputs, version %s',
- prefix,
- fanns[id].version)
-
+ if not is_fann_valid(rule, prefix, fanns[id].fann) then
fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].fann = nil
elseif fanns[id].version % rule.max_usages == 0 then
@@ -421,7 +463,11 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
- ann = rspamd_fann.load_data(ann_data)
+ if torch then
+ ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
+ else
+ ann = rspamd_fann.load_data(ann_data)
+ end
end
if is_fann_valid(rule, prefix, ann) then
@@ -582,7 +628,15 @@ local function train_fann(rule, _, ev_base, elt)
else
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
prefix, train_mse)
- local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+ local ann_data
+ if torch then
+ local f = torch.MemoryFile()
+ f:writeObject(fanns[elt].fann_train)
+ ann_data = rspamd_util.zstd_compress(f:storage():string())
+ else
+ ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+ end
+
fanns[elt].version = fanns[elt].version + 1
fanns[elt].fann = fanns[elt].fann_train
fanns[elt].fann_train = nil
@@ -636,8 +690,7 @@ local function train_fann(rule, _, ev_base, elt)
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
-- Now we can train fann
- if not fanns[elt] or not fanns[elt].fann_train
- or n ~= fanns[elt].fann_train:get_inputs() then
+ if not fanns[elt] or not fanns[elt].fann_train then
-- Create fann if it does not exist
create_train_fann(rule, n, elt)
end