From: Vsevolod Stakhov Date: Mon, 15 Jul 2019 14:42:59 +0000 (+0100) Subject: [Minor] Neural: Various fixes X-Git-Tag: 2.0~589 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=b8a7db17236cf6fc3757e593c5ba2b3429ed1dc6;p=rspamd.git [Minor] Neural: Various fixes --- diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 51a33e6e1..2e4c8e7cc 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -273,7 +273,7 @@ local function new_ann_profile(task, rule, set, version) local profile = { symbols = set.symbols, redis_key = ann_key, - version = version or 0, + version = version, digest = set.digest, distance = 0 -- Since we are using our own profile } @@ -334,8 +334,8 @@ local function ann_scores_filter(task) score = out[1] local symscore = string.format('%.3f', score) - lua_util.debugm(N, task, '%s:%s ann score: %s', - rule.prefix, set.name, symscore) + lua_util.debugm(N, task, '%s:%s:%s ann score: %s', + rule.prefix, set.name, set.ann.version, symscore) if score > 0 then local result = score @@ -425,6 +425,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) local vec = result_to_vector(task, set) local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = set.ann.redis_key .. '_' .. learn_type local function learn_vec_cb(_err) if _err then @@ -432,8 +433,9 @@ local function ann_push_task_result(rule, task, verdict, score, set) rule.prefix, set.name, _err) else lua_util.debugm(N, task, - "add train data for ANN rule %s:%s, save %s vector of %s elts; %s bytes compressed", - rule.prefix, set.name, learn_type, #vec, #str) + "add train data for ANN rule" .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) end end @@ -443,7 +445,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) true, -- is write learn_vec_cb, --callback 'LPUSH', -- command - { set.ann.redis_key .. '_' .. learn_type, str} -- arguments + { target_key, str } -- arguments ) else if err then @@ -458,7 +460,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) if not set.ann then -- Need to create or load a profile corresponding to the current configuration - set.ann = new_ann_profile(task, rule, set) + set.ann = new_ann_profile(task, rule, set, 0) end -- Check if we can learn lua_redis.exec_redis_script(redis_can_store_train_vec_id, @@ -606,8 +608,8 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve {ann_key, 'lock'} ) else - rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes', - rule.prefix, set.name, #data) + rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s', + rule.prefix, set.name, #data, ann_key) local ann_data = rspamd_util.zstd_compress(data) if not set.ann then set.ann = { @@ -800,6 +802,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) if ann then set.ann = { + digest = profile.digest, version = profile.version, symbols = profile.symbols, distance = min_diff, @@ -943,18 +946,21 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) ann_key) -- Create continuation closure - local redis_len_cb_gen = function(cont_cb) + local redis_len_cb_gen = function(cont_cb, what) return function(err, data) if err then rspamd_logger.errx(rspamd_config, - 'cannot get ANN trains %s from redis: %s', ann_key, err) + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) elseif data and type(data) == 'number' or type(data) == 'string' then if tonumber(data) and tonumber(data) >= rule.train.max_trains then + rspamd_logger.debugm(N, rspamd_config, + 'ANN %s has %s %s learn vectors (%s required)', + ann_key, tonumber(data), what, rule.train.max_trains) cont_cb() else rspamd_logger.debugm(N, rspamd_config, - 'no need to learn ANN %s %s learn vectors (%s required)', - ann_key, tonumber(data), rule.train.max_trains) + 'no need to learn ANN %s %s %s learn vectors (%s required)', + ann_key, tonumber(data), what, rule.train.max_trains) end end end @@ -975,7 +981,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) rule.redis, nil, false, -- is write - redis_len_cb_gen(initiate_train), --callback + redis_len_cb_gen(initiate_train, 'ham'), --callback 'LLEN', -- command {ann_key .. '_ham'} ) @@ -986,7 +992,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) rule.redis, nil, false, -- is write - redis_len_cb_gen(check_ham_len), --callback + redis_len_cb_gen(check_ham_len, 'spam'), --callback 'LLEN', -- command {ann_key .. '_spam'} ) @@ -1016,14 +1022,15 @@ local function load_ann_profile(element) end -- Function to check or load ANNs from Redis -local function check_anns(worker, cfg, ev_base, rule, process_callback) +local function check_anns(worker, cfg, ev_base, rule, process_callback, what) for _,set in pairs(rule.settings) do local function members_cb(err, data) if err then rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', err) elseif type(data) == 'table' then - lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name) + lua_util.debugm(N, cfg, '%s: process element %s:%s', + what, rule.prefix, set.name) process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) end end @@ -1272,7 +1279,7 @@ end rspamd_config:register_symbol({ name = 'NEURAL_LEARN', - type = 'idempotent,nostat', + type = 'idempotent,nostat,explicit_disable', priority = 5, callback = ann_push_vector }) @@ -1284,10 +1291,13 @@ for _,rule in pairs(settings.rules) do rspamd_config:add_post_init(process_rules_settings) -- This function will check ANNs in Redis when a worker is loaded rspamd_config:add_on_load(function(cfg, ev_base, worker) - rspamd_config:add_periodic(ev_base, 0.0, - function(_, _) - return check_anns(worker, cfg, ev_base, rule, process_existing_ann) - end) + if worker:is_scanner() then + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) + end if worker:is_primary_controller() then -- We also want to train neural nets when they have enough data @@ -1295,7 +1305,8 @@ for _,rule in pairs(settings.rules) do function(_, _) -- Clean old ANNs cleanup_anns(rule, cfg, ev_base) - return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') end) end end)