diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-09-17 09:01:09 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-09-17 09:01:09 +0100 |
commit | 0d554c93d984693f66b931b113165b576267fae3 (patch) | |
tree | b774baac8ad65aad1489ca63e501534bc29e3b23 /src | |
parent | d6d98ff4be5f3d7e134fda04f708c4c06061375f (diff) | |
download | rspamd-0d554c93d984693f66b931b113165b576267fae3.tar.gz rspamd-0d554c93d984693f66b931b113165b576267fae3.zip |
[Fix] Multiple fixes in torch-based ANN plugins
- Fix loading of ANNs
- Fix disabling torch
- Remove normalisation as we have tanh on output
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 34 |
1 file changed, 22 insertions, 12 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index f07a84033..2751b5d79 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -127,7 +127,7 @@ local redis_lua_script_maybe_load = [[ return {redis.call('GET', KEYS[1] .. '_data'), ret} end - return false + return tonumber(ret) ]] local redis_maybe_load_sha = nil @@ -332,7 +332,7 @@ local function is_fann_valid(rule, prefix, ann) local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens() - if torch then + if use_torch then return true else if n ~= ann:get_inputs() then @@ -375,7 +375,7 @@ local function fann_scores_filter(task) fun.each(function(e) table.insert(fann_data, e) end, mt) local score - if torch then + if use_torch then local out = fanns[id].fann:forward(torch.Tensor(fann_data)) score = out[1] else @@ -387,10 +387,16 @@ local function fann_scores_filter(task) rspamd_logger.infox(task, 'fann score: %s', symscore) if score > 0 then - local result = rspamd_util.normalize_prob(score / 2.0, 0) + local result = score + if not use_torch then + result = rspamd_util.normalize_prob(score / 2.0, 0) + end task:insert_result(rule.symbol_spam, result, symscore, id) else - local result = rspamd_util.normalize_prob((-score) / 2.0, 0) + local result = -(score) + if not use_torch then + result = rspamd_util.normalize_prob(-(score) / 2.0, 0) + end task:insert_result(rule.symbol_ham, result, symscore, id) end end @@ -398,7 +404,7 @@ local function fann_scores_filter(task) end local function create_fann(n, nlayers) - if torch then + if use_torch then -- We ignore number of layers so far when using torch local ann = nn.Sequential() local nhidden = math.floor((n + 1) / 2) @@ -464,7 +470,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err) return else - if torch then + if use_torch then ann = 
torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject() else ann = rspamd_fann.load_data(ann_data) @@ -647,7 +653,7 @@ local function train_fann(rule, _, ev_base, elt, worker) rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', prefix, train_mse) local ann_data - if torch then + if use_torch then local f = torch.MemoryFile() f:writeObject(fanns[elt].fann_train) ann_data = rspamd_util.zstd_compress(f:storage():string()) @@ -762,7 +768,7 @@ local function train_fann(rule, _, ev_base, elt, worker) {redis_locked_invalidate_sha, 1, prefix} ) else - if torch then + if use_torch then -- For torch we do not need to mix samples as they would be flushed local dataset = {} fun.each(function(s) @@ -996,8 +1002,12 @@ local function check_fanns(rule, _, ev_base) elseif _data and type(_data) == 'table' then load_or_invalidate_fann(rule, _data, elt, ev_base) else - rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s', - type(_data), elt) + if type(_data) == 'number' then + -- no new version + else + rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s', + type(_data), elt) + end end end @@ -1161,7 +1171,7 @@ else for _,rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) load_scripts(cfg, ev_base, function(_, _) - check_fanns(rule, cfg, ev_base) + return check_fanns(rule, cfg, ev_base) end) if worker:get_name() == 'controller' and worker:get_index() == 0 then |