From 195e79d69ba3b0ab8b21e9f81eb76ee98e3858f6 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 16 Mar 2020 11:15:12 +0000 Subject: [Feature] Neural: Introduce classes bias that allows non-equal classes learning --- src/plugins/lua/neural.lua | 49 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 13 deletions(-) (limited to 'src/plugins/lua/neural.lua') diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 1897f0843..affb07307 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -43,6 +43,7 @@ local default_options = { train_prob = 1.0, learn_threads = 1, learning_rate = 0.01, + classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias) }, watch_interval = 60.0, lock_expire = 600, @@ -99,6 +100,7 @@ end -- key2 - spam or ham -- key3 - maximum trains -- key4 - sampling coin (as Redis scripts do not allow math.random calls) +-- key5 - classes bias -- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn local redis_lua_script_can_store_train_vec = [[ local prefix = KEYS[1] @@ -108,6 +110,7 @@ local redis_lua_script_can_store_train_vec = [[ local nham = 0 local lim = tonumber(KEYS[3]) local coin = tonumber(KEYS[4]) + local classes_bias = tonumber(KEYS[5]) local ret = redis.call('LLEN', prefix .. '_spam') if ret then nspam = tonumber(ret) end @@ -119,8 +122,8 @@ local redis_lua_script_can_store_train_vec = [[ if nspam > nham then -- Apply sampling local skip_rate = 1.0 - nham / (nspam + 1) - if coin < skip_rate then - return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)} + if coin < skip_rate - classes_bias then + return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)} end end return {tostring(nspam),'can learn'} @@ -132,8 +135,8 @@ local redis_lua_script_can_store_train_vec = [[ if nham > nspam then -- Apply sampling local skip_rate = 1.0 - nspam / (nham + 1) - if coin < skip_rate then - return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)} + if coin < skip_rate - classes_bias then + return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)} end end return {tostring(nham),'can learn'} @@ -505,6 +508,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) learn_type, tostring(train_opts.max_trains), tostring(math.random()), + tostring(train_opts.classes_bias) }) else lua_util.debugm(N, task, @@ -1014,6 +1018,10 @@ end local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) local my_symbols = set.symbols local sel_elt + local lens = { + spam = 0, + ham = 0, + } for _,elt in fun.iter(profiles) do if elt and elt.symbols then @@ -1040,21 +1048,36 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) rspamd_logger.errx(rspamd_config, 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) elseif data and type(data) == 'number' or type(data) == 'string' then - if tonumber(data) and tonumber(data) >= rule.train.max_trains then - if is_final then + local ntrains = tonumber(data) or 0 + lens[what] = ntrains + if is_final then + local unpack = rawget(table, "unpack") or unpack + -- Ensure that we have the following: + -- one class has reached max_trains + -- other class(es) are at least as full as classes_bias + -- e.g. if classes_bias = 0.25 and we have 10 max_trains then + -- one class must have 10 or more trains whilst another should have + -- at least (10 * (1 - 0.25)) = 8 trains + + local max_len = math.max(unpack(lens)) + local len_bias_check_pred = function(l) + return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) + end + if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then rspamd_logger.debugm(N, rspamd_config, 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, tonumber(data), rule.train.max_trains, what) + ann_key, lens, rule.train.max_trains, what) + cont_cb() else rspamd_logger.debugm(N, rspamd_config, - 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', - what, ann_key, tonumber(data), rule.train.max_trains) + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) end - cont_cb() + else rspamd_logger.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, tonumber(data), rule.train.max_trains) + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, ntrains, rule.train.max_trains) end end end @@ -1064,7 +1087,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) local function initiate_train() rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s required learn vectors', - ann_key, rule.train.max_trains) + ann_key, lens) do_train_ann(worker, ev_base, rule, set, ann_key) end -- cgit v1.2.3