From 853a6c50dbb9514a1757ec71aa979145944dad6d Mon Sep 17 00:00:00 2001
From: Vsevolod Stakhov
Date: Fri, 4 Nov 2016 15:27:40 +0000
Subject: [PATCH] [Rework] Implement loading/invalidating

---
Review notes (text between the three-dash line and the diff is not part of
the commit message).

This copy of the patch was re-flowed from a whitespace-collapsed paste:
indentation and the position of blank lines are a best-effort
reconstruction, chosen so that every hunk body matches its header counts.
The commit id above and the blob ids on the "index" line are those of the
originally posted patch.

Changes made to the added lines, relative to the posted patch:

 * is_fann_valid(): add the missing "return true". Without it a valid ANN
   falls off the end of the function and yields nil, so
   load_or_invalidate_fann() always takes the invalidate branch.
 * load_or_invalidate_fann(): rename the first parameter from "data" to
   "ann_blob". The old name shadowed the module-level "data" table that
   "data[id].fann = ann" is meant to update.
 * load_or_invalidate_fann(): pass the missing argument to the
   'invalidate ANN %s' log call, and use infox (not info) with "data"
   (not the nil "err") in the success branch of redis_invalidate_cb.
 * redis_lua_script_maybe_invalidate: return the string '1' instead of the
   number 1, so that the "type(data) == 'string'" check in the callback can
   match, like the tostring() results of redis_lua_script_can_train.
 * can_train_cb(): drop the leftover 'data: %s, err: %s' debug message that
   was logged at error level on every callback.

One line added, one line removed: the diffstat is unchanged and only the
headers of hunks 3, 4 and 5 differ from the posted patch.

 src/plugins/lua/fann_scores.lua | 126 +++++++++++++++++++++++---------
 1 file changed, 90 insertions(+), 36 deletions(-)

diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua
index 0a238db29..3c46cda2f 100644
--- a/src/plugins/lua/fann_scores.lua
+++ b/src/plugins/lua/fann_scores.lua
@@ -55,9 +55,9 @@ local redis_lua_script_can_train = [[
   if ret then nham = tonumber(ret) end
 
   if KEYS[3] == 'spam' then
-    if nham + 1 >= nspam then return tostring(nspam) end
+    if nham + 1 >= nspam then return tostring(nspam + 1) end
   else
-    if nspam + 1 >= nham then return tostring(nham) end
+    if nspam + 1 >= nham then return tostring(nham + 1) end
   end
 
   return tostring(0)
@@ -80,12 +80,28 @@ local redis_lua_script_maybe_load = [[
 
   return false
 ]]
-local redis_fann_maybe_load_sha = nil
+local redis_maybe_load_sha = nil
+
+-- Lua script to invalidate ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+local redis_lua_script_maybe_invalidate = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+  redis.call('SET', KEYS[1] .. '_locked', '1')
+  redis.call('SET', KEYS[1] .. '_version', '0')
+  redis.call('DEL', KEYS[1] .. '_spam')
+  redis.call('DEL', KEYS[1] .. '_ham')
+  redis.call('DEL', KEYS[1] .. '_data')
+  redis.call('DEL', KEYS[1] .. '_locked')
+  return '1'
+]]
+local redis_maybe_invalidate_sha = nil
 
 local redis_params
 redis_params = rspamd_parse_redis_server('fann_scores')
 
-local fann_prefix = 'RF'
+local fann_prefix = 'RFANN'
 local max_trains = 1000
 local max_epoch = 100
 local use_settings = false
@@ -385,29 +401,26 @@ local function gen_fann_prefix(id)
   end
 end
 
-local function is_fann_valid(id)
-  if data[id].fann then
+local function is_fann_valid(ann)
+  if ann then
     local n = rspamd_config:get_symbols_count() + count_metatokens()
 
-    if n ~= data[id].fann:get_inputs() then
+    if n ~= ann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', data[id].fann:get_inputs(), n)
-      data[id].fann = nil
+      ' is found in the cache', ann:get_inputs(), n)
+      return false
     end
 
-    local layers = data[id].fann:get_layers()
+    local layers = ann:get_layers()
 
     if not layers or #layers ~= 5 then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
         #layers)
-      data[id].fann = nil
+      return false
     end
-
+    return true
   end
-  if data[id].fann then return true end
-
-  return false
 end
 
 
 local function fann_scores_filter(task)
@@ -448,6 +461,39 @@ local function create_train_fann(n, id)
   data[id].epoch = 0
 end
 
+local function load_or_invalidate_fann(ann_blob, id, ev_base)
+  local err,ann_data = rspamd_util.zstd_decompress(ann_blob)
+  local ann
+
+  if err or not ann_data then
+    rspamd_logger.errx('cannot decompress ann: %s', err)
+  else
+    ann = rspamd_fann.load_data(ann_data)
+  end
+
+  if is_fann_valid(ann) then
+    data[id].fann = ann
+  else
+    local function redis_invalidate_cb(err, data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err)
+      elseif type(data) == 'string' then
+        rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', id, data)
+      end
+    end
+    -- Invalidate ANN
+    rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', id)
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      redis_invalidate_cb, --callback
+      'EVALSHA', -- command
+      {redis_maybe_invalidate_sha, 1, fann_prefix .. id}
+    )
+  end
+end
+
 local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
   local fname = gen_fann_prefix(id)
 
@@ -468,7 +514,13 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
   local k
   if learn_spam then k = 'spam' else k = 'ham' end
 
+  local function learn_vec_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err)
+    end
+  end
+
   local function can_train_cb(err, data)
     if not err and tonumber(data) > 0 then
       local learn_data = symbols_to_fann_vector(
         map(function(r) return r[1] end, results),
@@ -476,13 +528,13 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
       )
       -- Add filtered meta tokens
       each(function(e) table.insert(learn_data, e) end, extra)
-      local str = table.concat(learn_data, ';')
+      local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
 
       redis_make_request(ev_base,
         rspamd_config,
         nil,
         true, -- is write
-        learn_cb, --callback
+        learn_vec_cb, --callback
         'LPUSH', -- command
         {fname .. '_' .. k, str} -- arguments
       )
@@ -510,25 +562,11 @@ local function check_fanns(cfg, ev_base)
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
       each(function(i, elt)
-        local redis_load_cb = function(err, data)
-          if err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
-          elseif type(data) == 'string' then
-            --load_fann(data, elt)
-          end
-        end
         local redis_update_cb = function(err, data)
           if err then
             rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
-          elseif data then
-            redis_make_request(ev_base,
-              rspamd_config,
-              nil,
-              false, -- is write
-              redis_load_cb, --callback
-              'GET', -- command
-              {fann_prefix, fann_prefix .. elt .. '_data'} -- arguments
-            )
+          elseif data and type(data) == 'string' then
+            load_or_invalidate_fann(data, elt, ev_base)
           end
         end
 
@@ -545,14 +583,14 @@ local function check_fanns(cfg, ev_base)
           false, -- is write
           redis_update_cb, --callback
           'EVALSHA', -- command
-          {redis_fann_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
+          {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)}
         )
       end,
       data)
     end
   end
 
-  if not redis_fann_maybe_load_sha then
+  if not redis_maybe_load_sha then
     -- Plan new event early
     return 1.0
   end
@@ -663,7 +701,7 @@ else
       if err or not data or type(data) ~= 'string' then
         rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
       else
-        redis_fann_maybe_load_sha = tostring(data)
+        redis_maybe_load_sha = tostring(data)
 
         rspamd_config:add_periodic(ev_base, 0.0,
           function(cfg, ev_base)
@@ -679,5 +717,21 @@ else
       'SCRIPT', -- command
       {'LOAD', redis_lua_script_maybe_load} -- arguments
     )
+
+    local function maybe_invalidate_sha_cb(err, data)
+      if err or not data or type(data) ~= 'string' then
+        rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
+      else
+        redis_maybe_invalidate_sha = tostring(data)
+      end
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      true, -- is write
+      maybe_invalidate_sha_cb, --callback
+      'SCRIPT', -- command
+      {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
+    )
   end)
 end
-- 
2.39.5