From: Vsevolod Stakhov Date: Sat, 5 Nov 2016 20:48:53 +0000 (+0100) Subject: [Rework] Add redis storage feature to fann_redis X-Git-Tag: 1.4.0~110 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=5eb00bd1b49028f9f5591572108b4bb104e56ebb;p=rspamd.git [Rework] Add redis storage feature to fann_redis --- diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index f55454bf6..aabf465ce 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -30,9 +30,7 @@ local module_log_id = 0x100 -- ANNs indexed by settings id local data = { ['0'] = { - fann_mtime = 0, - ntrains = 0, - epoch = 0, + version = 0, } } @@ -108,6 +106,20 @@ local redis_lua_script_maybe_lock = [[ ]] local redis_maybe_lock_sha = nil +-- Lua script to save and unlock ANN in redis +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - compressed ANN +local redis_lua_script_save_unlock = [[ + redis.call('INCRBY', KEYS[1] .. '_version', '1') + redis.call('DEL', KEYS[1] .. '_spam') + redis.call('DEL', KEYS[1] .. '_ham') + redis.call('SET', KEYS[1] .. '_data', KEYS[2]) + redis.call('DEL', KEYS[1] .. '_locked') + return 1 +]] +local redis_save_unlock_sha = nil + local redis_params redis_params = rspamd_parse_redis_server('fann_redis') @@ -116,6 +128,7 @@ local max_trains = 1000 local max_epoch = 100 local use_settings = false local watch_interval = 60.0 +local mse = 0.0001 local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) if not ev_base or not redis_params or not callback or not command then @@ -251,8 +264,7 @@ end local function create_train_fann(n, id) data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1) - data[id].ntrains = 0 - data[id].epoch = 0 + data[id].version = 0 end local function load_or_invalidate_fann(data, id, ev_base) @@ -361,6 +373,41 @@ local function train_fann(cfg, ev_base, elt) end end + local function redis_save_unlock_sha(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s', + fann_prefix .. elt, err) + end + end + + local function ann_trained(errcode, errmsg, train_mse) + if errcode ~= 0 then + rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', + fann_prefix .. elt, errmsg) + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_unlock_cb, --callback + 'DEL', -- command + {fann_prefix .. elt .. '_lock'} + ) + else + rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', + fann_prefix .. elt, train_mse) + local ann_data = rspamd_util.zstd_compress(data[elt].fann:data()) + data[elt].version = data[elt].version + 1 + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + redis_save_cb, --callback + 'EVALSHA', -- command + {redis_save_unlock_sha, '2', fann_prefix .. elt, ann_data} + ) + end + end + local function redis_ham_cb(err, data) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', @@ -375,8 +422,8 @@ local function train_fann(cfg, ev_base, elt) ) else -- Decompress and convert to numbers each training vector - ham_elts = map(function(i, elt) - local str = tostring(rspamd_util.zstd_decompress(elt)) + ham_elts = map(function(i, tok) + local str = tostring(rspamd_util.zstd_decompress(tok)) return map(tonumber, rspamd_str_split(str, ';')) end, data) @@ -384,17 +431,24 @@ local function train_fann(cfg, ev_base, elt) local inputs = {} local outputs = {} - each(function(i, elt) - table.insert(inputs, totable(elt)) + each(function(i, sample) + table.insert(inputs, totable(sample)) table.insert(outputs, 1.0) end, spam_elts) - each(function(i, elt) - table.insert(inputs, totable(elt)) + each(function(i, sample) + table.insert(inputs, totable(sample)) table.insert(outputs, -1.0) end, spam_elts) -- Now we can train fann + local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() + if not data[elt].fann then + -- Create fann if it does not exist + create_train_fann(n, elt) + end + data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base, + {max_epochs = max_epoch, desired_mse = mse}) end end @@ -412,8 +466,8 @@ local function train_fann(cfg, ev_base, elt) ) else -- Decompress and convert to numbers each training vector - spam_elts = map(function(i, elt) - local str = tostring(rspamd_util.zstd_decompress(elt)) + spam_elts = map(function(i, tok) + local str = tostring(rspamd_util.zstd_decompress(tok)) return map(tonumber, rspamd_str_split(str, ';')) end, data) redis_make_request(ev_base, @@ -700,6 +754,22 @@ else {'LOAD', redis_lua_script_maybe_lock} -- arguments ) + local function save_unlock_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err) + else + redis_save_unlock_sha = tostring(data) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + save_unlock_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_save_unlock} -- arguments + ) + if worker:get_name() == 'normal' then -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0,