
[Feature] Neural: Introduce classes bias that allows non-equal classes learning

tags/2.5
Vsevolod Stakhov, 4 years ago
commit 195e79d69b
2 changed files with 39 additions and 18 deletions:
  1. .luacheckrc (+3, -5)
  2. src/plugins/lua/neural.lua (+36, -13)

.luacheckrc (+3, -5)

@@ -34,7 +34,9 @@ globals = {
'rspamadm_ev_base',
'rspamadm_session',
'rspamadm_dns_resolver',
- 'jit'
+ 'jit',
+ 'table.unpack',
+ 'unpack',
}

ignore = {
@@ -55,10 +57,6 @@ files['/**/src/plugins/lua/reputation.lua'].globals = {
'math.tanh',
}

- files['/**/lualib/lua_util.lua'].globals = {
- 'table.unpack',
- 'unpack',
- }

files['/**/lualib/lua_redis.lua'].globals = {
'rspamadm_ev_base',

src/plugins/lua/neural.lua (+36, -13)

@@ -43,6 +43,7 @@ local default_options = {
train_prob = 1.0,
learn_threads = 1,
learning_rate = 0.01,
+ classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias)
},
watch_interval = 60.0,
lock_expire = 600,
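
The option defaults to 0.0, i.e. strictly balanced classes. As a rough illustration of how it could be tuned, assuming the usual local.d override layout for the neural plugin (file name and surrounding values are illustrative, not part of this commit):

    # local.d/neural.conf (illustrative)
    train {
      max_trains = 200;     # per-class target before training starts
      classes_bias = 0.25;  # tolerate up to 25% imbalance between classes
    }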
@@ -99,6 +100,7 @@ end
-- key2 - spam or ham
-- key3 - maximum trains
-- key4 - sampling coin (as Redis scripts do not allow math.random calls)
+ -- key5 - classes bias
-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
local redis_lua_script_can_store_train_vec = [[
local prefix = KEYS[1]
@@ -108,6 +110,7 @@ local redis_lua_script_can_store_train_vec = [[
local nham = 0
local lim = tonumber(KEYS[3])
local coin = tonumber(KEYS[4])
+ local classes_bias = tonumber(KEYS[5])

local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
@@ -119,8 +122,8 @@ local redis_lua_script_can_store_train_vec = [[
if nspam > nham then
-- Apply sampling
local skip_rate = 1.0 - nham / (nspam + 1)
- if coin < skip_rate then
- return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+ if coin < skip_rate - classes_bias then
+ return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
end
end
return {tostring(nspam),'can learn'}
@@ -132,8 +135,8 @@ local redis_lua_script_can_store_train_vec = [[
if nham > nspam then
-- Apply sampling
local skip_rate = 1.0 - nspam / (nham + 1)
- if coin < skip_rate then
- return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+ if coin < skip_rate - classes_bias then
+ return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
end
end
return {tostring(nham),'can learn'}
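
In both branches the over-represented class is down-sampled: skip_rate grows with the imbalance, and the externally supplied coin (KEYS[4]) must reach it for the vector to be stored. Subtracting classes_bias lowers that threshold, so a non-zero bias keeps proportionally more vectors of the larger class. A self-contained Lua sketch of the same decision (function and variable names are illustrative; the real check runs inside the Redis script above):

    -- Sampling decision for the over-represented class: nbig/nsmall are the
    -- current list lengths, coin is a uniform [0,1) draw made outside Redis.
    local function can_store(nbig, nsmall, coin, classes_bias)
      local skip_rate = 1.0 - nsmall / (nbig + 1)
      if coin < skip_rate - classes_bias then
        return false -- sampled out
      end
      return true
    end

    -- With 100 spam and 50 ham vectors skip_rate is ~0.505, so bias 0.0 drops
    -- about half of the new spam vectors while bias 0.25 drops only ~25%.
    print(can_store(100, 50, 0.3, 0.0))   -- false: 0.3 < 0.505
    print(can_store(100, 50, 0.3, 0.25))  -- true:  0.3 >= 0.255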
@@ -505,6 +508,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
learn_type,
tostring(train_opts.max_trains),
tostring(math.random()),
+ tostring(train_opts.classes_bias)
})
else
lua_util.debugm(N, task,
@@ -1014,6 +1018,10 @@ end
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
local my_symbols = set.symbols
local sel_elt
+ local lens = {
+ spam = 0,
+ ham = 0,
+ }

for _,elt in fun.iter(profiles) do
if elt and elt.symbols then
@@ -1040,21 +1048,36 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rspamd_logger.errx(rspamd_config,
'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
- if tonumber(data) and tonumber(data) >= rule.train.max_trains then
- if is_final then
+ local ntrains = tonumber(data) or 0
+ lens[what] = ntrains
+ if is_final then
+ local unpack = rawget(table, "unpack") or unpack
+ -- Ensure that we have the following:
+ -- one class has reached max_trains
+ -- other class(es) are at least as full as classes_bias
+ -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
+ -- one class must have 10 or more trains whilst another should have
+ -- at least (10 * (1 - 0.25)) = 8 trains
+
+ local max_len = math.max(unpack(lens))
+ local len_bias_check_pred = function(l)
+ return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+ end
+ if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
rspamd_logger.debugm(N, rspamd_config,
'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, tonumber(data), rule.train.max_trains, what)
+ ann_key, lens, rule.train.max_trains, what)
cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, tonumber(data), rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
- cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, tonumber(data), rule.train.max_trains)
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
end
end
end
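
With the per-class counters collected in lens, training is now gated on two conditions: the fullest class must reach max_trains, and every class must hold at least max_trains * (1 - classes_bias) vectors, as the inline comment explains. A minimal sketch of that gate in plain Lua (it uses an array of counts and avoids the fun library used in the patch; names are illustrative):

    -- Class-balance gate applied before ANN training starts.
    local function can_start_training(counts, max_trains, classes_bias)
      local max_len = 0
      local min_required = max_trains * (1.0 - classes_bias)
      for _, len in ipairs(counts) do
        if len > max_len then max_len = len end
        if len < min_required then return false end
      end
      return max_len >= max_trains
    end

    -- classes_bias = 0.25 with max_trains = 10: one class must reach 10 trains,
    -- the other needs at least 10 * (1 - 0.25) = 8.
    print(can_start_training({10, 8}, 10, 0.25)) -- true
    print(can_start_training({10, 7}, 10, 0.25)) -- false
    print(can_start_training({9, 9},  10, 0.25)) -- false: no class reached max_trains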
@@ -1064,7 +1087,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
local function initiate_train()
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s required learn vectors',
- ann_key, rule.train.max_trains)
+ ann_key, lens)
do_train_ann(worker, ev_base, rule, set, ann_key)
end

