aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-17 09:01:09 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-17 09:01:09 +0100
commit0d554c93d984693f66b931b113165b576267fae3 (patch)
treeb774baac8ad65aad1489ca63e501534bc29e3b23 /src
parentd6d98ff4be5f3d7e134fda04f708c4c06061375f (diff)
downloadrspamd-0d554c93d984693f66b931b113165b576267fae3.tar.gz
rspamd-0d554c93d984693f66b931b113165b576267fae3.zip
[Fix] Multiple fixes in torch based ANN plugins
- Fix ANNs load - Fix disabling torch - Remove normalisation as we have tanh on output
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/fann_redis.lua34
1 files changed, 22 insertions, 12 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index f07a84033..2751b5d79 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -127,7 +127,7 @@ local redis_lua_script_maybe_load = [[
return {redis.call('GET', KEYS[1] .. '_data'), ret}
end
- return false
+ return tonumber(ret)
]]
local redis_maybe_load_sha = nil
@@ -332,7 +332,7 @@ local function is_fann_valid(rule, prefix, ann)
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
- if torch then
+ if use_torch then
return true
else
if n ~= ann:get_inputs() then
@@ -375,7 +375,7 @@ local function fann_scores_filter(task)
fun.each(function(e) table.insert(fann_data, e) end, mt)
local score
- if torch then
+ if use_torch then
local out = fanns[id].fann:forward(torch.Tensor(fann_data))
score = out[1]
else
@@ -387,10 +387,16 @@ local function fann_scores_filter(task)
rspamd_logger.infox(task, 'fann score: %s', symscore)
if score > 0 then
- local result = rspamd_util.normalize_prob(score / 2.0, 0)
+ local result = score
+ if not use_torch then
+ result = rspamd_util.normalize_prob(score / 2.0, 0)
+ end
task:insert_result(rule.symbol_spam, result, symscore, id)
else
- local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
+ local result = -(score)
+ if not use_torch then
+ result = rspamd_util.normalize_prob(-(score) / 2.0, 0)
+ end
task:insert_result(rule.symbol_ham, result, symscore, id)
end
end
@@ -398,7 +404,7 @@ local function fann_scores_filter(task)
end
local function create_fann(n, nlayers)
- if torch then
+ if use_torch then
-- We ignore number of layers so far when using torch
local ann = nn.Sequential()
local nhidden = math.floor((n + 1) / 2)
@@ -464,7 +470,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
- if torch then
+ if use_torch then
ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
else
ann = rspamd_fann.load_data(ann_data)
@@ -647,7 +653,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
prefix, train_mse)
local ann_data
- if torch then
+ if use_torch then
local f = torch.MemoryFile()
f:writeObject(fanns[elt].fann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
@@ -762,7 +768,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
{redis_locked_invalidate_sha, 1, prefix}
)
else
- if torch then
+ if use_torch then
-- For torch we do not need to mix samples as they would be flushed
local dataset = {}
fun.each(function(s)
@@ -996,8 +1002,12 @@ local function check_fanns(rule, _, ev_base)
elseif _data and type(_data) == 'table' then
load_or_invalidate_fann(rule, _data, elt, ev_base)
else
- rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
- type(_data), elt)
+ if type(_data) == 'number' then
+ -- no new version
+ else
+ rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
+ type(_data), elt)
+ end
end
end
@@ -1161,7 +1171,7 @@ else
for _,rule in pairs(settings.rules) do
rspamd_config:add_on_load(function(cfg, ev_base, worker)
load_scripts(cfg, ev_base, function(_, _)
- check_fanns(rule, cfg, ev_base)
+ return check_fanns(rule, cfg, ev_base)
end)
if worker:get_name() == 'controller' and worker:get_index() == 0 then