From 0d554c93d984693f66b931b113165b576267fae3 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sun, 17 Sep 2017 09:01:09 +0100 Subject: [PATCH] [Fix] Multiple fixes in torch based ANN plugins - Fix ANNs load - Fix disabling torch - Remove normalisation as we have tanh on output --- src/plugins/lua/fann_redis.lua | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index f07a84033..2751b5d79 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -127,7 +127,7 @@ local redis_lua_script_maybe_load = [[ return {redis.call('GET', KEYS[1] .. '_data'), ret} end - return false + return tonumber(ret) ]] local redis_maybe_load_sha = nil @@ -332,7 +332,7 @@ local function is_fann_valid(rule, prefix, ann) local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens() - if torch then + if use_torch then return true else if n ~= ann:get_inputs() then @@ -375,7 +375,7 @@ local function fann_scores_filter(task) fun.each(function(e) table.insert(fann_data, e) end, mt) local score - if torch then + if use_torch then local out = fanns[id].fann:forward(torch.Tensor(fann_data)) score = out[1] else @@ -387,10 +387,16 @@ local function fann_scores_filter(task) rspamd_logger.infox(task, 'fann score: %s', symscore) if score > 0 then - local result = rspamd_util.normalize_prob(score / 2.0, 0) + local result = score + if not use_torch then + result = rspamd_util.normalize_prob(score / 2.0, 0) + end task:insert_result(rule.symbol_spam, result, symscore, id) else - local result = rspamd_util.normalize_prob((-score) / 2.0, 0) + local result = -(score) + if not use_torch then + result = rspamd_util.normalize_prob(-(score) / 2.0, 0) + end task:insert_result(rule.symbol_ham, result, symscore, id) end end @@ -398,7 +404,7 @@ local function fann_scores_filter(task) end local function create_fann(n, nlayers) - if torch then + if use_torch then -- We ignore number of layers so far when using torch local ann = nn.Sequential() local nhidden = math.floor((n + 1) / 2) @@ -464,7 +470,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) return else - if torch then + if use_torch then ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject() else ann = rspamd_fann.load_data(ann_data) @@ -647,7 +653,7 @@ local function train_fann(rule, _, ev_base, elt, worker) rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', prefix, train_mse) local ann_data - if torch then + if use_torch then local f = torch.MemoryFile() f:writeObject(fanns[elt].fann_train) ann_data = rspamd_util.zstd_compress(f:storage():string()) @@ -762,7 +768,7 @@ local function train_fann(rule, _, ev_base, elt, worker) {redis_locked_invalidate_sha, 1, prefix} ) else - if torch then + if use_torch then -- For torch we do not need to mix samples as they would be flushed local dataset = {} fun.each(function(s) @@ -996,8 +1002,12 @@ local function check_fanns(rule, _, ev_base) elseif _data and type(_data) == 'table' then load_or_invalidate_fann(rule, _data, elt, ev_base) else - rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s', - type(_data), elt) + if type(_data) == 'number' then + -- no new version + else + rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s', + type(_data), elt) + end end end @@ -1161,7 +1171,7 @@ else for _,rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) load_scripts(cfg, ev_base, function(_, _) - check_fanns(rule, cfg, ev_base) + return check_fanns(rule, cfg, ev_base) end) if worker:get_name() == 'controller' and worker:get_index() == 0 then -- 2.39.5