From e63df7971ade42e4c1eaadbab99cee116bcfeded Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 16 Dec 2017 18:03:52 +0000 Subject: [PATCH] [Minor] Rework fann_redis to use redis scripts framework --- src/plugins/lua/fann_redis.lua | 254 +++++++-------------------------- 1 file changed, 49 insertions(+), 205 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 2b9b06f28..3120d8b18 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -115,7 +115,7 @@ local redis_lua_script_can_train = [[ return tostring(0) ]] -local redis_can_train_sha = nil +local redis_can_train_id = nil -- Lua script to load ANN from redis -- Uses the following keys @@ -132,7 +132,7 @@ local redis_lua_script_maybe_load = [[ return tonumber(ret) or 0 ]] -local redis_maybe_load_sha = nil +local redis_maybe_load_id = nil -- Lua script to invalidate ANN from redis -- Uses the following keys @@ -149,7 +149,7 @@ local redis_lua_script_maybe_invalidate = [[ redis.call('DEL', KEYS[1] .. '_hostname') return 1 ]] -local redis_maybe_invalidate_sha = nil +local redis_maybe_invalidate_id = nil -- Lua script to invalidate ANN from redis -- Uses the following keys @@ -163,7 +163,7 @@ local redis_lua_script_locked_invalidate = [[ redis.call('DEL', KEYS[1] .. '_hostname') return 1 ]] -local redis_locked_invalidate_sha = nil +local redis_locked_invalidate_id = nil -- Lua script to invalidate ANN from redis -- Uses the following keys @@ -182,7 +182,7 @@ local redis_lua_script_maybe_lock = [[ redis.call('SET', KEYS[1] .. '_hostname', KEYS[4]) return 1 ]] -local redis_maybe_lock_sha = nil +local redis_maybe_lock_id = nil -- Lua script to save and unlock ANN in redis -- Uses the following keys @@ -200,119 +200,23 @@ local redis_lua_script_save_unlock = [[ redis.call('EXPIRE', KEYS[1] .. '_version', KEYS[3]) return 1 ]] -local redis_save_unlock_sha = nil +local redis_save_unlock_id = nil local redis_params -local function load_scripts(cfg, ev_base, on_load_cb) - local function can_train_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) - else - redis_can_train_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - can_train_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_can_train} -- arguments - ) - - local function maybe_load_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err) - else - redis_maybe_load_sha = tostring(data) - - if on_load_cb then - rspamd_config:add_periodic(ev_base, 0.0, - function(_cfg, _ev_base) - return on_load_cb(_cfg, _ev_base) - end) - end - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - maybe_load_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_maybe_load} -- arguments - ) - - local function maybe_invalidate_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err) - else - redis_maybe_invalidate_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - maybe_invalidate_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_maybe_invalidate} -- arguments - ) - - local function locked_invalidate_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis locked invalidate script: %s', err) - else - redis_locked_invalidate_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - locked_invalidate_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_locked_invalidate} -- arguments - ) - - local function maybe_lock_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err) - else - redis_maybe_lock_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - maybe_lock_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_maybe_lock} -- arguments - ) - - local function save_unlock_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err) - else - redis_save_unlock_sha = tostring(data) - end - end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - redis_params, - nil, - true, -- is write - save_unlock_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_save_unlock} -- arguments - ) +local function load_scripts(params) + redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train, + params) + redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load, + params) + redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate, + params) + redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate, + params) + redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock, + params) + redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock, + params) end local function gen_fann_prefix(rule, id) @@ -490,9 +394,6 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) local function redis_invalidate_cb(_err, _data) if _err then rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) - if string.match(_err, 'NOSCRIPT') then - load_scripts(rspamd_config, ev_base, nil) - end elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) fanns[id].version = 0 @@ -500,15 +401,10 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) end -- Invalidate ANN rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_invalidate_cb, --callback - 'EVALSHA', -- command - {redis_maybe_invalidate_sha, 1, prefix} - ) + rspamd_redis.exec_redis_script(redis_maybe_invalidate_id, + {ev_base = ev_base, is_write = true}, + redis_invalidate_cb, + {prefix}) end end @@ -589,9 +485,6 @@ local function fann_train_callback(rule, task, score, required_score, id) else if err then rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) - if string.match(err, 'NOSCRIPT') then - load_scripts(rspamd_config, task:get_ev_base(), nil) - end elseif tonumber(data) < 0 then rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s", fname, k, -tonumber(data)) @@ -599,15 +492,10 @@ local function fann_train_callback(rule, task, score, required_score, id) end end - rspamd_redis.rspamd_redis_make_request(task, - rule.redis, - nil, - true, -- is write - can_train_cb, --callback - 'EVALSHA', -- command - {redis_can_train_sha, '4', gen_fann_prefix(rule, nil), - suffix, k, tostring(train_opts.max_trains)} -- arguments - ) + rspamd_redis.exec_redis_script(redis_can_train_id, + {task = task, is_write = true}, + can_train_cb, + {gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)}) end end @@ -637,9 +525,6 @@ local function train_fann(rule, _, ev_base, elt, worker) 'DEL', -- command {prefix .. '_locked'} ) - if string.match(err, 'NOSCRIPT') then - load_scripts(rspamd_config, ev_base, nil) - end else rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix) end @@ -674,15 +559,10 @@ local function train_fann(rule, _, ev_base, elt, worker) fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train fanns[elt].fann_train = nil - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_save_cb, --callback - 'EVALSHA', -- command - {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} - ) + rspamd_redis.exec_redis_script(redis_save_unlock_id, + {ev_base = ev_base, is_write = true}, + redis_save_cb, + {prefix, tostring(ann_data), tostring(rule.ann_expire)}) end end @@ -711,15 +591,10 @@ local function train_fann(rule, _, ev_base, elt, worker) fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train fanns[elt].fann_train = nil - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_save_cb, --callback - 'EVALSHA', -- command - {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} - ) + rspamd_redis.exec_redis_script(redis_save_unlock_id, + {ev_base = ev_base, is_write = true}, + redis_save_cb, + {prefix, tostring(ann_data), tostring(rule.ann_expire)}) end end @@ -768,15 +643,10 @@ local function train_fann(rule, _, ev_base, elt, worker) end -- Invalidate ANN rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_invalidate_cb, --callback - 'EVALSHA', -- command - {redis_locked_invalidate_sha, 1, prefix} - ) + rspamd_redis.exec_redis_script(redis_locked_invalidate_id, + {ev_base = ev_base, is_write = true}, + redis_invalidate_cb, + {prefix}) else if use_torch then -- For torch we do not need to mix samples as they would be flushed @@ -874,9 +744,6 @@ local function train_fann(rule, _, ev_base, elt, worker) if err then rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', prefix, err) - if string.match(err, 'NOSCRIPT') then - load_scripts(rspamd_config, ev_base, nil) - end elseif type(data) == 'number' then -- Can train ANN rspamd_redis.redis_make_request_taskless(ev_base, @@ -926,16 +793,10 @@ local function train_fann(rule, _, ev_base, elt, worker) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) return end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - redis_lock_cb, --callback - 'EVALSHA', -- command - {redis_maybe_lock_sha, '4', prefix, tostring(os.time()), - tostring(rule.lock_expire), rspamd_util.get_hostname()} - ) + rspamd_redis.exec_redis_script(redis_maybe_lock_id, + {ev_base = ev_base, is_write = true}, + redis_lock_cb, + {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()}) end local function maybe_train_fanns(rule, cfg, ev_base, worker) @@ -979,10 +840,6 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) end end - if not redis_maybe_load_sha then - -- Plan new event early - return 1.0 - end -- First we need to get all fanns stored in our Redis rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, @@ -1009,9 +866,6 @@ local function check_fanns(rule, _, ev_base) if _err then rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err) - if string.match(_err, 'NOSCRIPT') then - load_scripts(rspamd_config, ev_base, nil) - end elseif _data and type(_data) == 'table' then load_or_invalidate_fann(rule, _data, elt, ev_base) else @@ -1028,24 +882,15 @@ local function check_fanns(rule, _, ev_base) local_ver = fanns[elt].version end end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - false, -- is write - redis_update_cb, --callback - 'EVALSHA', -- command - {redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)} - ) + rspamd_redis.exec_redis_script(redis_maybe_load_id, + {ev_base = ev_base, is_write = false}, + redis_update_cb, + {gen_fann_prefix(rule, elt), tostring(local_ver)}) end, data) end end - if not redis_maybe_load_sha then - -- Plan new event early - return 1.0 - end -- First we need to get all fanns stored in our Redis rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, @@ -1187,10 +1032,9 @@ else -- Add training scripts for _,rule in pairs(settings.rules) do + load_scripts(rule.redis) rspamd_config:add_on_load(function(cfg, ev_base, worker) - load_scripts(cfg, ev_base, function(_, _) - return check_fanns(rule, cfg, ev_base) - end) + check_fanns(rule, cfg, ev_base) if worker:get_name() == 'controller' and worker:get_index() == 0 then -- We also want to train neural nets when they have enough data -- 2.39.5