From 4b199ace126a7d50e8ed8a5d533ea70edc3d5d07 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 15 Nov 2016 14:13:00 +0000 Subject: [PATCH] [Fix] Multiple issues in fann_redis --- src/plugins/lua/fann_redis.lua | 78 ++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index c55f376de..361d82303 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -28,7 +28,7 @@ local ucl = require "ucl" local module_log_id = 0x200 -- Module vars -- ANNs indexed by settings id -local data = { +local fanns = { ['0'] = { version = 0, } @@ -80,7 +80,7 @@ local redis_lua_script_maybe_load = [[ local ver = 0 local ret = redis.call('GET', KEYS[1] .. '_version') if ret then ver = tonumber(ret) end - if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end + if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_ann') end return false ]] @@ -135,6 +135,7 @@ local max_epoch = 100 local use_settings = false local watch_interval = 60.0 local mse = 0.0001 +local nlayers = 4 local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) if not ev_base or not redis_params or not callback or not command then @@ -222,7 +223,7 @@ local function is_fann_valid(ann) end local layers = ann:get_layers() - if not layers or #layers ~= 5 then + if not layers or #layers ~= nlayers then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', #layers) return false @@ -241,7 +242,7 @@ local function fann_scores_filter(task) end end - if data[id].fann then + if fanns[id].fann then local symbols,scores = task:get_symbols_numeric() local fann_data = symbols_to_fann_vector(symbols, scores) local mt = rspamd_gen_metatokens(task) @@ -250,7 +251,7 @@ local function fann_scores_filter(task) table.insert(fann_data, tok) end - local out = data[id].fann:test(fann_data) + local out = fanns[id].fann:test(fann_data) local symscore = string.format('%.3f', out[1]) rspamd_logger.infox(task, 'fann score: %s', symscore) @@ -265,8 +266,18 @@ local function fann_scores_filter(task) end local function create_train_fann(n, id) - data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1) - data[id].version = 0 + id = tostring(id) + if not fanns[id] then + fanns[id] = {} + end + + if fanns[id].fann then + fanns[id].fann_train = fanns[id].fann + fanns[id].fann = nil + else + fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1) + fanns[id].version = 0 + end end local function load_or_invalidate_fann(data, id, ev_base) @@ -280,7 +291,7 @@ local function load_or_invalidate_fann(data, id, ev_base) end if is_fann_valid(ann) then - data[id].fann = ann + fanns[id].fann = ann else local function redis_invalidate_cb(err, data) if err then @@ -367,6 +378,7 @@ end local function train_fann(cfg, ev_base, elt) local spam_elts = {} local ham_elts = {} + elt = tostring(elt) local function redis_unlock_cb(err, data) if err then @@ -398,7 +410,9 @@ local function train_fann(cfg, ev_base, elt) rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', fann_prefix .. elt, train_mse) local ann_data = rspamd_util.zstd_compress(data[elt].fann:data()) - data[elt].version = data[elt].version + 1 + fanns[elt].version = fanns[elt].version + 1 + fanns[elt].fann = fanns[elt].fann_train + fanns[elt].fann_train = nil redis_make_request(ev_base, rspamd_config, nil, @@ -424,32 +438,32 @@ local function train_fann(cfg, ev_base, elt) ) else -- Decompress and convert to numbers each training vector - ham_elts = map(function(i, tok) - local str = tostring(rspamd_util.zstd_decompress(tok)) - return map(tonumber, rspamd_str_split(str, ';')) + ham_elts = map(function(tok) + local _,str = rspamd_util.zstd_decompress(tok) + return map(tonumber, rspamd_str_split(tostring(str), ';')) end, data) -- Now we need to join inputs and create the appropriate test vectors local inputs = {} local outputs = {} - each(function(i, sample) + each(function(sample) table.insert(inputs, totable(sample)) - table.insert(outputs, 1.0) + table.insert(outputs, {1.0}) end, spam_elts) - each(function(i, sample) + each(function(sample) table.insert(inputs, totable(sample)) - table.insert(outputs, -1.0) - end, spam_elts) + table.insert(outputs, {-1.0}) + end, ham_elts) -- Now we can train fann local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() - if not data[elt].fann then + if not fanns[elt] or not fanns[elt].fann_train then -- Create fann if it does not exist create_train_fann(n, elt) end - data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base, + fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, {max_epochs = max_epoch, desired_mse = mse}) end end @@ -468,9 +482,9 @@ local function train_fann(cfg, ev_base, elt) ) else -- Decompress and convert to numbers each training vector - spam_elts = map(function(i, tok) - local str = tostring(rspamd_util.zstd_decompress(tok)) - return map(tonumber, rspamd_str_split(str, ';')) + spam_elts = map(function(tok) + local _,str = rspamd_util.zstd_decompress(tok) + return map(tonumber, rspamd_str_split(tostring(str), ';')) end, data) redis_make_request(ev_base, rspamd_config, @@ -514,7 +528,8 @@ local function maybe_train_fanns(cfg, ev_base) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) elseif type(data) == 'table' then - each(function(i, elt) + each(function(elt) + elt = tostring(elt) local redis_len_cb = function(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err) @@ -527,9 +542,9 @@ local function maybe_train_fanns(cfg, ev_base) local local_ver = 0 local numelt = tonumber(elt) - if data[numelt] then - if data[numelt].version then - local_ver = data[numelt].version + if fanns[numelt] then + if fanns[numelt].version then + local_ver = fanns[numelt].version end end redis_make_request(ev_base, @@ -567,7 +582,8 @@ local function check_fanns(cfg, ev_base) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) elseif type(data) == 'table' then - each(function(i, elt) + each(function(elt) + elt = tostring(elt) local redis_update_cb = function(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err) @@ -578,9 +594,9 @@ local function check_fanns(cfg, ev_base) local local_ver = 0 local numelt = tonumber(elt) - if data[numelt] then - if data[numelt].version then - local_ver = data[numelt].version + if fanns[numelt] then + if fanns[numelt].version then + local_ver = fanns[numelt].version end end redis_make_request(ev_base, @@ -683,7 +699,7 @@ else end end) -- This is needed to pass extra tokens from worker to log_helper - rspamd_plugins["fann_score"] = { + rspamd_plugins["fann_redis"] = { log_callback = function(task) return totable(map( function(tok) return {module_log_id, tok} end, -- 2.39.5