|
|
@@ -115,7 +115,7 @@ local redis_lua_script_can_train = [[ |
|
|
|
|
|
|
|
return tostring(0) |
|
|
|
]] |
|
|
|
local redis_can_train_sha = nil |
|
|
|
local redis_can_train_id = nil |
|
|
|
|
|
|
|
-- Lua script to load ANN from redis |
|
|
|
-- Uses the following keys |
|
|
@@ -132,7 +132,7 @@ local redis_lua_script_maybe_load = [[ |
|
|
|
|
|
|
|
return tonumber(ret) or 0 |
|
|
|
]] |
|
|
|
local redis_maybe_load_sha = nil |
|
|
|
local redis_maybe_load_id = nil |
|
|
|
|
|
|
|
-- Lua script to invalidate ANN from redis |
|
|
|
-- Uses the following keys |
|
|
@@ -149,7 +149,7 @@ local redis_lua_script_maybe_invalidate = [[ |
|
|
|
redis.call('DEL', KEYS[1] .. '_hostname') |
|
|
|
return 1 |
|
|
|
]] |
|
|
|
local redis_maybe_invalidate_sha = nil |
|
|
|
local redis_maybe_invalidate_id = nil |
|
|
|
|
|
|
|
-- Lua script to invalidate ANN from redis |
|
|
|
-- Uses the following keys |
|
|
@@ -163,7 +163,7 @@ local redis_lua_script_locked_invalidate = [[ |
|
|
|
redis.call('DEL', KEYS[1] .. '_hostname') |
|
|
|
return 1 |
|
|
|
]] |
|
|
|
local redis_locked_invalidate_sha = nil |
|
|
|
local redis_locked_invalidate_id = nil |
|
|
|
|
|
|
|
-- Lua script to invalidate ANN from redis |
|
|
|
-- Uses the following keys |
|
|
@@ -182,7 +182,7 @@ local redis_lua_script_maybe_lock = [[ |
|
|
|
redis.call('SET', KEYS[1] .. '_hostname', KEYS[4]) |
|
|
|
return 1 |
|
|
|
]] |
|
|
|
local redis_maybe_lock_sha = nil |
|
|
|
local redis_maybe_lock_id = nil |
|
|
|
|
|
|
|
-- Lua script to save and unlock ANN in redis |
|
|
|
-- Uses the following keys |
|
|
@@ -200,119 +200,23 @@ local redis_lua_script_save_unlock = [[ |
|
|
|
redis.call('EXPIRE', KEYS[1] .. '_version', KEYS[3]) |
|
|
|
return 1 |
|
|
|
]] |
|
|
|
local redis_save_unlock_sha = nil |
|
|
|
local redis_save_unlock_id = nil |
|
|
|
|
|
|
|
local redis_params |
|
|
|
|
|
|
|
local function load_scripts(cfg, ev_base, on_load_cb) |
|
|
|
local function can_train_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) |
|
|
|
else |
|
|
|
redis_can_train_sha = tostring(data) |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
can_train_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_can_train} -- arguments |
|
|
|
) |
|
|
|
|
|
|
|
local function maybe_load_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err) |
|
|
|
else |
|
|
|
redis_maybe_load_sha = tostring(data) |
|
|
|
|
|
|
|
if on_load_cb then |
|
|
|
rspamd_config:add_periodic(ev_base, 0.0, |
|
|
|
function(_cfg, _ev_base) |
|
|
|
return on_load_cb(_cfg, _ev_base) |
|
|
|
end) |
|
|
|
end |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
maybe_load_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_maybe_load} -- arguments |
|
|
|
) |
|
|
|
|
|
|
|
local function maybe_invalidate_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err) |
|
|
|
else |
|
|
|
redis_maybe_invalidate_sha = tostring(data) |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
maybe_invalidate_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_maybe_invalidate} -- arguments |
|
|
|
) |
|
|
|
|
|
|
|
local function locked_invalidate_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis locked invalidate script: %s', err) |
|
|
|
else |
|
|
|
redis_locked_invalidate_sha = tostring(data) |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
locked_invalidate_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_locked_invalidate} -- arguments |
|
|
|
) |
|
|
|
|
|
|
|
local function maybe_lock_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err) |
|
|
|
else |
|
|
|
redis_maybe_lock_sha = tostring(data) |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
maybe_lock_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_maybe_lock} -- arguments |
|
|
|
) |
|
|
|
|
|
|
|
local function save_unlock_sha_cb(err, data) |
|
|
|
if err or not data or type(data) ~= 'string' then |
|
|
|
rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err) |
|
|
|
else |
|
|
|
redis_save_unlock_sha = tostring(data) |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
redis_params, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
save_unlock_sha_cb, --callback |
|
|
|
'SCRIPT', -- command |
|
|
|
{'LOAD', redis_lua_script_save_unlock} -- arguments |
|
|
|
) |
|
|
|
local function load_scripts(params) |
|
|
|
redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train, |
|
|
|
params) |
|
|
|
redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load, |
|
|
|
params) |
|
|
|
redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate, |
|
|
|
params) |
|
|
|
redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate, |
|
|
|
params) |
|
|
|
redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock, |
|
|
|
params) |
|
|
|
redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock, |
|
|
|
params) |
|
|
|
end |
|
|
|
|
|
|
|
local function gen_fann_prefix(rule, id) |
|
|
@@ -490,9 +394,6 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) |
|
|
|
local function redis_invalidate_cb(_err, _data) |
|
|
|
if _err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) |
|
|
|
if string.match(_err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, ev_base, nil) |
|
|
|
end |
|
|
|
elseif type(_data) == 'string' then |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) |
|
|
|
fanns[id].version = 0 |
|
|
@@ -500,15 +401,10 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) |
|
|
|
end |
|
|
|
-- Invalidate ANN |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
redis_invalidate_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_maybe_invalidate_sha, 1, prefix} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_maybe_invalidate_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_invalidate_cb, |
|
|
|
{prefix}) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -589,9 +485,6 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
else |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) |
|
|
|
if string.match(err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, task:get_ev_base(), nil) |
|
|
|
end |
|
|
|
elseif tonumber(data) < 0 then |
|
|
|
rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s", |
|
|
|
fname, k, -tonumber(data)) |
|
|
@@ -599,15 +492,10 @@ local function fann_train_callback(rule, task, score, required_score, id) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
rspamd_redis.rspamd_redis_make_request(task, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
can_train_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_can_train_sha, '4', gen_fann_prefix(rule, nil), |
|
|
|
suffix, k, tostring(train_opts.max_trains)} -- arguments |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_can_train_id, |
|
|
|
{task = task, is_write = true}, |
|
|
|
can_train_cb, |
|
|
|
{gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)}) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -637,9 +525,6 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
'DEL', -- command |
|
|
|
{prefix .. '_locked'} |
|
|
|
) |
|
|
|
if string.match(err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, ev_base, nil) |
|
|
|
end |
|
|
|
else |
|
|
|
rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix) |
|
|
|
end |
|
|
@@ -674,15 +559,10 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
fanns[elt].version = fanns[elt].version + 1 |
|
|
|
fanns[elt].fann = fanns[elt].fann_train |
|
|
|
fanns[elt].fann_train = nil |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
redis_save_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_save_unlock_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_save_cb, |
|
|
|
{prefix, tostring(ann_data), tostring(rule.ann_expire)}) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -711,15 +591,10 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
fanns[elt].version = fanns[elt].version + 1 |
|
|
|
fanns[elt].fann = fanns[elt].fann_train |
|
|
|
fanns[elt].fann_train = nil |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
redis_save_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_save_unlock_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_save_cb, |
|
|
|
{prefix, tostring(ann_data), tostring(rule.ann_expire)}) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
@@ -768,15 +643,10 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
end |
|
|
|
-- Invalidate ANN |
|
|
|
rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
redis_invalidate_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_locked_invalidate_sha, 1, prefix} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_locked_invalidate_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_invalidate_cb, |
|
|
|
{prefix}) |
|
|
|
else |
|
|
|
if use_torch then |
|
|
|
-- For torch we do not need to mix samples as they would be flushed |
|
|
@@ -874,9 +744,6 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
if err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', |
|
|
|
prefix, err) |
|
|
|
if string.match(err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, ev_base, nil) |
|
|
|
end |
|
|
|
elseif type(data) == 'number' then |
|
|
|
-- Can train ANN |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
@@ -926,16 +793,10 @@ local function train_fann(rule, _, ev_base, elt, worker) |
|
|
|
rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) |
|
|
|
return |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
true, -- is write |
|
|
|
redis_lock_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_maybe_lock_sha, '4', prefix, tostring(os.time()), |
|
|
|
tostring(rule.lock_expire), rspamd_util.get_hostname()} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_maybe_lock_id, |
|
|
|
{ev_base = ev_base, is_write = true}, |
|
|
|
redis_lock_cb, |
|
|
|
{prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()}) |
|
|
|
end |
|
|
|
|
|
|
|
local function maybe_train_fanns(rule, cfg, ev_base, worker) |
|
|
@@ -979,10 +840,6 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
if not redis_maybe_load_sha then |
|
|
|
-- Plan new event early |
|
|
|
return 1.0 |
|
|
|
end |
|
|
|
-- First we need to get all fanns stored in our Redis |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
@@ -1009,9 +866,6 @@ local function check_fanns(rule, _, ev_base) |
|
|
|
if _err then |
|
|
|
rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', |
|
|
|
elt, _err) |
|
|
|
if string.match(_err, 'NOSCRIPT') then |
|
|
|
load_scripts(rspamd_config, ev_base, nil) |
|
|
|
end |
|
|
|
elseif _data and type(_data) == 'table' then |
|
|
|
load_or_invalidate_fann(rule, _data, elt, ev_base) |
|
|
|
else |
|
|
@@ -1028,24 +882,15 @@ local function check_fanns(rule, _, ev_base) |
|
|
|
local_ver = fanns[elt].version |
|
|
|
end |
|
|
|
end |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
|
rule.redis, |
|
|
|
nil, |
|
|
|
false, -- is write |
|
|
|
redis_update_cb, --callback |
|
|
|
'EVALSHA', -- command |
|
|
|
{redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)} |
|
|
|
) |
|
|
|
rspamd_redis.exec_redis_script(redis_maybe_load_id, |
|
|
|
{ev_base = ev_base, is_write = false}, |
|
|
|
redis_update_cb, |
|
|
|
{gen_fann_prefix(rule, elt), tostring(local_ver)}) |
|
|
|
end, |
|
|
|
data) |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
if not redis_maybe_load_sha then |
|
|
|
-- Plan new event early |
|
|
|
return 1.0 |
|
|
|
end |
|
|
|
-- First we need to get all fanns stored in our Redis |
|
|
|
rspamd_redis.redis_make_request_taskless(ev_base, |
|
|
|
rspamd_config, |
|
|
@@ -1187,10 +1032,9 @@ else |
|
|
|
|
|
|
|
-- Add training scripts |
|
|
|
for _,rule in pairs(settings.rules) do |
|
|
|
load_scripts(rule.redis) |
|
|
|
rspamd_config:add_on_load(function(cfg, ev_base, worker) |
|
|
|
load_scripts(cfg, ev_base, function(_, _) |
|
|
|
return check_fanns(rule, cfg, ev_base) |
|
|
|
end) |
|
|
|
check_fanns(rule, cfg, ev_base) |
|
|
|
|
|
|
|
if worker:get_name() == 'controller' and worker:get_index() == 0 then |
|
|
|
-- We also want to train neural nets when they have enough data |