summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-11-15 14:13:00 +0000
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2016-11-15 14:13:00 +0000
commit 4b199ace126a7d50e8ed8a5d533ea70edc3d5d07 (patch)
tree 43d1153a5b287c677484ec2dbbd620be9a824d7e /src
parent 56d39f7e6210a784f739171d8ac0de9e0bd97201 (diff)
download rspamd-4b199ace126a7d50e8ed8a5d533ea70edc3d5d07.tar.gz
rspamd-4b199ace126a7d50e8ed8a5d533ea70edc3d5d07.zip
[Fix] Multiple issues in fann_redis
Diffstat (limited to 'src')
-rw-r--r-- src/plugins/lua/fann_redis.lua 78
1 file changed, 47 insertions, 31 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index c55f376de..361d82303 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -28,7 +28,7 @@ local ucl = require "ucl"
local module_log_id = 0x200
-- Module vars
-- ANNs indexed by settings id
-local data = {
+local fanns = {
['0'] = {
version = 0,
}
@@ -80,7 +80,7 @@ local redis_lua_script_maybe_load = [[
local ver = 0
local ret = redis.call('GET', KEYS[1] .. '_version')
if ret then ver = tonumber(ret) end
- if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end
+ if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_ann') end
return false
]]
@@ -135,6 +135,7 @@ local max_epoch = 100
local use_settings = false
local watch_interval = 60.0
local mse = 0.0001
+local nlayers = 4
local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
if not ev_base or not redis_params or not callback or not command then
@@ -222,7 +223,7 @@ local function is_fann_valid(ann)
end
local layers = ann:get_layers()
- if not layers or #layers ~= 5 then
+ if not layers or #layers ~= nlayers then
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
#layers)
return false
@@ -241,7 +242,7 @@ local function fann_scores_filter(task)
end
end
- if data[id].fann then
+ if fanns[id].fann then
local symbols,scores = task:get_symbols_numeric()
local fann_data = symbols_to_fann_vector(symbols, scores)
local mt = rspamd_gen_metatokens(task)
@@ -250,7 +251,7 @@ local function fann_scores_filter(task)
table.insert(fann_data, tok)
end
- local out = data[id].fann:test(fann_data)
+ local out = fanns[id].fann:test(fann_data)
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann score: %s', symscore)
@@ -265,8 +266,18 @@ local function fann_scores_filter(task)
end
local function create_train_fann(n, id)
- data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
- data[id].version = 0
+ id = tostring(id)
+ if not fanns[id] then
+ fanns[id] = {}
+ end
+
+ if fanns[id].fann then
+ fanns[id].fann_train = fanns[id].fann
+ fanns[id].fann = nil
+ else
+ fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+ fanns[id].version = 0
+ end
end
local function load_or_invalidate_fann(data, id, ev_base)
@@ -280,7 +291,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
end
if is_fann_valid(ann) then
- data[id].fann = ann
+ fanns[id].fann = ann
else
local function redis_invalidate_cb(err, data)
if err then
@@ -367,6 +378,7 @@ end
local function train_fann(cfg, ev_base, elt)
local spam_elts = {}
local ham_elts = {}
+ elt = tostring(elt)
local function redis_unlock_cb(err, data)
if err then
@@ -398,7 +410,9 @@ local function train_fann(cfg, ev_base, elt)
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
fann_prefix .. elt, train_mse)
local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
- data[elt].version = data[elt].version + 1
+ fanns[elt].version = fanns[elt].version + 1
+ fanns[elt].fann = fanns[elt].fann_train
+ fanns[elt].fann_train = nil
redis_make_request(ev_base,
rspamd_config,
nil,
@@ -424,32 +438,32 @@ local function train_fann(cfg, ev_base, elt)
)
else
-- Decompress and convert to numbers each training vector
- ham_elts = map(function(i, tok)
- local str = tostring(rspamd_util.zstd_decompress(tok))
- return map(tonumber, rspamd_str_split(str, ';'))
+ ham_elts = map(function(tok)
+ local _,str = rspamd_util.zstd_decompress(tok)
+ return map(tonumber, rspamd_str_split(tostring(str), ';'))
end, data)
-- Now we need to join inputs and create the appropriate test vectors
local inputs = {}
local outputs = {}
- each(function(i, sample)
+ each(function(sample)
table.insert(inputs, totable(sample))
- table.insert(outputs, 1.0)
+ table.insert(outputs, {1.0})
end, spam_elts)
- each(function(i, sample)
+ each(function(sample)
table.insert(inputs, totable(sample))
- table.insert(outputs, -1.0)
- end, spam_elts)
+ table.insert(outputs, {-1.0})
+ end, ham_elts)
-- Now we can train fann
local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
- if not data[elt].fann then
+ if not fanns[elt] or not fanns[elt].fann_train then
-- Create fann if it does not exist
create_train_fann(n, elt)
end
- data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base,
+ fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
{max_epochs = max_epoch, desired_mse = mse})
end
end
@@ -468,9 +482,9 @@ local function train_fann(cfg, ev_base, elt)
)
else
-- Decompress and convert to numbers each training vector
- spam_elts = map(function(i, tok)
- local str = tostring(rspamd_util.zstd_decompress(tok))
- return map(tonumber, rspamd_str_split(str, ';'))
+ spam_elts = map(function(tok)
+ local _,str = rspamd_util.zstd_decompress(tok)
+ return map(tonumber, rspamd_str_split(tostring(str), ';'))
end, data)
redis_make_request(ev_base,
rspamd_config,
@@ -514,7 +528,8 @@ local function maybe_train_fanns(cfg, ev_base)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
- each(function(i, elt)
+ each(function(elt)
+ elt = tostring(elt)
local redis_len_cb = function(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
@@ -527,9 +542,9 @@ local function maybe_train_fanns(cfg, ev_base)
local local_ver = 0
local numelt = tonumber(elt)
- if data[numelt] then
- if data[numelt].version then
- local_ver = data[numelt].version
+ if fanns[numelt] then
+ if fanns[numelt].version then
+ local_ver = fanns[numelt].version
end
end
redis_make_request(ev_base,
@@ -567,7 +582,8 @@ local function check_fanns(cfg, ev_base)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
elseif type(data) == 'table' then
- each(function(i, elt)
+ each(function(elt)
+ elt = tostring(elt)
local redis_update_cb = function(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
@@ -578,9 +594,9 @@ local function check_fanns(cfg, ev_base)
local local_ver = 0
local numelt = tonumber(elt)
- if data[numelt] then
- if data[numelt].version then
- local_ver = data[numelt].version
+ if fanns[numelt] then
+ if fanns[numelt].version then
+ local_ver = fanns[numelt].version
end
end
redis_make_request(ev_base,
@@ -683,7 +699,7 @@ else
end
end)
-- This is needed to pass extra tokens from worker to log_helper
- rspamd_plugins["fann_score"] = {
+ rspamd_plugins["fann_redis"] = {
log_callback = function(task)
return totable(map(
function(tok) return {module_log_id, tok} end,