From ca75dbad6bcb55084cc1104c2dc5ef1109c270c0 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 18 Oct 2019 17:08:44 +0100 Subject: [PATCH] [Feature] Neural: Add sampling when storing training vectors --- src/plugins/lua/neural.lua | 114 +++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 42 deletions(-) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 0b93cd4a7..7acb0eca3 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -97,7 +97,7 @@ end -- key1 - ann key -- key2 - spam or ham -- key3 - maximum trains --- returns 1 or 0: 1 - allow learn, 0 - not allow learn +-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn local redis_lua_script_can_store_train_vec = [[ local prefix = KEYS[1] local locked = redis.call('HGET', prefix, 'lock') @@ -114,19 +114,33 @@ local redis_lua_script_can_store_train_vec = [[ if KEYS[2] == 'spam' then if nspam <= lim then - return tostring(nspam) - else - return tostring(-(nspam)) + if nspam > nham then + -- Apply sampling + local skip_rate = 1.0 - nham / (nspam + 1) + if math.random() < skip_rate then + return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)} + end + end + return {tostring(nspam),'can learn'} + else -- Enough learns + return {tostring(-(nspam)),'too many spam samples'} end else if nham <= lim then - return tostring(nham) + if nsham > nspam then + -- Apply sampling + local skip_rate = 1.0 - nspam / (nham + 1) + if math.random() < skip_rate then + return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)} + end + end + return {tostring(nham),'can learn'} else - return tostring(-(nham)) + return {tostring(-(nham)),'too many ham samples'} end end - return tostring(0) + return {tostring(0),'bad input'} ]] local redis_can_store_train_vec_id = nil @@ -416,45 +430,50 @@ local function ann_push_task_result(rule, task, verdict, score, set) if learn_spam then learn_type = 'spam' else learn_type = 'ham' end local function can_train_cb(err, data) - if not err and tonumber(data) >= 0 then - local coin = math.random() - if coin < 1.0 - train_opts.train_prob then - rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) - return - end - local vec = result_to_vector(task, set) + if not err and type(data) == 'table' then + local nsamples,reason = tonumber(data[1]),data[2] - local str = rspamd_util.zstd_compress(table.concat(vec, ';')) - local target_key = set.ann.redis_key .. '_' .. learn_type + if nsamples > 0 then + local coin = math.random() - local function learn_vec_cb(_err) - if _err then - rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', - rule.prefix, set.name, _err) - else - lua_util.debugm(N, task, - "add train data for ANN rule " .. - "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", - rule.prefix, set.name, learn_type, #vec, target_key, #str) + if coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return end - end - lua_redis.redis_make_request(task, - rule.redis, - nil, - true, -- is write - learn_vec_cb, --callback - 'LPUSH', -- command - { target_key, str } -- arguments - ) - else - if err then - rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', - rule.prefix, set.name, err) - elseif tonumber(data) < 0 then - rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s", - rule.prefix, set.name, learn_type, -tonumber(data)) + local vec = result_to_vector(task, set) + + local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = set.ann.redis_key .. '_' .. learn_type + + local function learn_vec_cb(_err) + if _err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, _err) + else + lua_util.debugm(N, task, + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'LPUSH', -- command + { target_key, str } -- arguments + ) + else + -- Negative result returned + rspamd_logger.infox(task, "cannot learn ANN %s:%s: %s (%s vectors stored)", + rule.prefix, set.name, learn_type, reason, -tonumber(nsamples)) end + else + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) end end @@ -466,7 +485,11 @@ local function ann_push_task_result(rule, task, verdict, score, set) lua_redis.exec_redis_script(redis_can_store_train_vec_id, {task = task, is_write = true}, can_train_cb, - { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)}) + { + set.ann.redis_key, + learn_type, + tostring(train_opts.max_trains), + }) else lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s', skip_reason) @@ -1128,6 +1151,13 @@ local function ann_push_vector(task) return end + if score ~= score then + lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', + verdict) + + return + end + for _,rule in pairs(settings.rules) do local set = get_rule_settings(task, rule) -- 2.39.5