diff options
-rw-r--r-- | src/lua/lua_config.c | 2 | ||||
-rw-r--r-- | src/lua/lua_task.c | 17 | ||||
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 129 |
3 files changed, 108 insertions, 40 deletions
diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 21cc4bbe7..6dfd45e12 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -694,7 +694,7 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF (config, add_example), LUA_INTERFACE_DEF (config, set_peak_cb), LUA_INTERFACE_DEF (config, get_cpu_flags), - LUA_INTERFACE_DEF (config, get_cpu_flags), + LUA_INTERFACE_DEF (config, has_torch), {"__tostring", rspamd_lua_class_tostring}, {"__newindex", lua_config_newindex}, {NULL, NULL} diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 61ef6d075..9d139044f 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -3134,6 +3134,7 @@ struct tokens_foreach_cbdata { struct rspamd_task *task; lua_State *L; gint idx; + gboolean normalize; }; static void @@ -3150,7 +3151,12 @@ tokens_foreach_cb (gint id, const gchar *sym, gint flags, gpointer ud) mres = cbd->task->result; if (mres && (s = g_hash_table_lookup (mres->symbols, sym)) != NULL) { - lua_pushnumber (cbd->L, tanh (s->score)); + if (cbd->normalize) { + lua_pushnumber (cbd->L, tanh (s->score)); + } + else { + lua_pushnumber (cbd->L, s->score); + } } else { lua_pushnumber (cbd->L, 0.0); @@ -3168,6 +3174,15 @@ lua_task_get_symbols_tokens (lua_State *L) cbd.task = task; cbd.L = L; cbd.idx = 1; + cbd.normalize = TRUE; + + if (lua_type (L, 2) == LUA_TBOOLEAN) { + cbd.normalize = lua_toboolean (L, 2); + } + else { + cbd.normalize = TRUE; + } + lua_createtable (L, rspamd_symbols_cache_symbols_count (task->cfg->cache), 0); rspamd_symbols_cache_foreach (task->cfg->cache, tokens_foreach_cb, &cbd); diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 531a740d9..b0cbdefab 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -27,6 +27,15 @@ local rspamd_util = require "rspamd_util" local rspamd_redis = require "lua_redis" local fun = require "fun" local meta_functions = require "meta_functions" +local use_torch = false +local torch +local nn + +if rspamd_config:has_torch() then + use_torch = true + torch = require "torch" + nn = require "nn" +end -- Module vars local default_options = { @@ -298,11 +307,15 @@ local function gen_fann_prefix(rule, id) local cksum = rspamd_config:get_symbols_cksum():hex() -- We also need to count metatokens: local n = meta_functions.rspamd_count_metatokens() + local tprefix = '' + if use_torch then + tprefix = 't'; + end if id then - return string.format('%s%s%d%s', rule.prefix, cksum, n, id), + return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), rule.prefix .. id else - return string.format('%s%s%d', rule.prefix, cksum, n), nil + return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil end end @@ -311,20 +324,36 @@ local function is_fann_valid(rule, prefix, ann) local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens() - if n ~= ann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', prefix, ann:get_inputs(), n) - return false - end - local layers = ann:get_layers() + if torch then + local nlayers = #ann + if nlayers ~= rule.nlayers then + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', + prefix, nlayers) + return false + end - if not layers or #layers ~= rule.nlayers then - rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', - prefix, #layers) - return false - end + local inp = ann:get(1):nElement() + if n ~= inp then + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', prefix, inp, n) + return false + end + else + if n ~= ann:get_inputs() then + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', prefix, ann:get_inputs(), n) + return false + end + local layers = ann:get_layers() - return true + if not layers or #layers ~= rule.nlayers then + rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s', + prefix, #layers) + return false + end + + return true + end end end @@ -348,15 +377,23 @@ local function fann_scores_filter(task) -- Add filtered meta tokens fun.each(function(e) table.insert(fann_data, e) end, mt) - local out = fanns[id].fann:test(fann_data) - local symscore = string.format('%.3f', out[1]) + local score + if torch then + local out = fanns[id].fann:forward(torch.Tensor(fann_data)) + score = out[1] + else + local out = fanns[id].fann:test(fann_data) + score = out[1] + end + + local symscore = string.format('%.3f', score) rspamd_logger.infox(task, 'fann score: %s', symscore) - if out[1] > 0 then - local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) + if score > 0 then + local result = rspamd_util.normalize_prob(score / 2.0, 0) task:insert_result(rule.symbol_spam, result, symscore, id) else - local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) + local result = rspamd_util.normalize_prob((-score) / 2.0, 0) task:insert_result(rule.symbol_ham, result, symscore, id) end end @@ -364,14 +401,25 @@ local function fann_scores_filter(task) end local function create_fann(n, nlayers) - local layers = {} - local div = 1.0 - for _ = 1, nlayers - 1 do - table.insert(layers, math.floor(n / div)) - div = div * 2 - end - table.insert(layers, 1) - return rspamd_fann.create(nlayers, layers) + if torch then + -- We ignore number of layers so far when using torch + local ann = nn.Sequential() + local nhidden = math.floor((n + 1) / 2) + ann:add(nn.Linear(n, nhidden)) + ann:add(nn.PReLU()) + ann:add(nn.Linear(nhidden, 1)) + + return ann + else + local layers = {} + local div = 1.0 + for _ = 1, nlayers - 1 do + table.insert(layers, math.floor(n / div)) + div = div * 2 + end + table.insert(layers, 1) + return rspamd_fann.create(nlayers, layers) + end end local function create_train_fann(rule, n, id) @@ -382,13 +430,7 @@ local function create_train_fann(rule, n, id) end -- Fix that for flexibe layers number if fanns[id].fann then - if n ~= fanns[id].fann:get_inputs() or -- - (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then - rspamd_logger.infox(rspamd_config, - 'recreate ANN %s as it has a wrong number of inputs, version %s', - prefix, - fanns[id].version) - + if not is_fann_valid(rule, prefix, fanns[id].fann) then fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil elseif fanns[id].version % rule.max_usages == 0 then @@ -421,7 +463,11 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) return else - ann = rspamd_fann.load_data(ann_data) + if torch then + ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject() + else + ann = rspamd_fann.load_data(ann_data) + end end if is_fann_valid(rule, prefix, ann) then @@ -582,7 +628,15 @@ local function train_fann(rule, _, ev_base, elt) else rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', prefix, train_mse) - local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) + local ann_data + if torch then + local f = torch.MemoryFile() + f:writeObject(fanns[elt].fann_train) + ann_data = rspamd_util.zstd_compress(f:storage():string()) + else + ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) + end + fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train fanns[elt].fann_train = nil @@ -636,8 +690,7 @@ local function train_fann(rule, _, ev_base, elt) end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts))) -- Now we can train fann - if not fanns[elt] or not fanns[elt].fann_train - or n ~= fanns[elt].fann_train:get_inputs() then + if not fanns[elt] or not fanns[elt].fann_train then -- Create fann if it does not exist create_train_fann(rule, n, elt) end |