summaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/neural.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-03-16 11:15:12 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-03-16 11:15:12 +0000
commit195e79d69ba3b0ab8b21e9f81eb76ee98e3858f6 (patch)
tree79f597dce030e2a2eb7f6b4ef5adbe6c55d8a201 /src/plugins/lua/neural.lua
parent50e1bfba88a19d9d0fe225fc5900beede64b03a2 (diff)
downloadrspamd-195e79d69ba3b0ab8b21e9f81eb76ee98e3858f6.tar.gz
rspamd-195e79d69ba3b0ab8b21e9f81eb76ee98e3858f6.zip
[Feature] Neural: Introduce classes bias that allows non-equal classes learning
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r--src/plugins/lua/neural.lua49
1 files changed, 36 insertions, 13 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 1897f0843..affb07307 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -43,6 +43,7 @@ local default_options = {
train_prob = 1.0,
learn_threads = 1,
learning_rate = 0.01,
+ classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias)
},
watch_interval = 60.0,
lock_expire = 600,
@@ -99,6 +100,7 @@ end
-- key2 - spam or ham
-- key3 - maximum trains
-- key4 - sampling coin (as Redis scripts do not allow math.random calls)
+-- key5 - classes bias
-- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
local redis_lua_script_can_store_train_vec = [[
local prefix = KEYS[1]
@@ -108,6 +110,7 @@ local redis_lua_script_can_store_train_vec = [[
local nham = 0
local lim = tonumber(KEYS[3])
local coin = tonumber(KEYS[4])
+ local classes_bias = tonumber(KEYS[5])
local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
@@ -119,8 +122,8 @@ local redis_lua_script_can_store_train_vec = [[
if nspam > nham then
-- Apply sampling
local skip_rate = 1.0 - nham / (nspam + 1)
- if coin < skip_rate then
- return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+ if coin < skip_rate - classes_bias then
+ return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
end
end
return {tostring(nspam),'can learn'}
@@ -132,8 +135,8 @@ local redis_lua_script_can_store_train_vec = [[
if nham > nspam then
-- Apply sampling
local skip_rate = 1.0 - nspam / (nham + 1)
- if coin < skip_rate then
- return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+ if coin < skip_rate - classes_bias then
+ return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate - classes_bias)}
end
end
return {tostring(nham),'can learn'}
@@ -505,6 +508,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
learn_type,
tostring(train_opts.max_trains),
tostring(math.random()),
+ tostring(train_opts.classes_bias)
})
else
lua_util.debugm(N, task,
@@ -1014,6 +1018,10 @@ end
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
local my_symbols = set.symbols
local sel_elt
+ local lens = {
+ spam = 0,
+ ham = 0,
+ }
for _,elt in fun.iter(profiles) do
if elt and elt.symbols then
@@ -1040,21 +1048,36 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
rspamd_logger.errx(rspamd_config,
'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
elseif data and type(data) == 'number' or type(data) == 'string' then
- if tonumber(data) and tonumber(data) >= rule.train.max_trains then
- if is_final then
+ local ntrains = tonumber(data) or 0
+ lens[what] = ntrains
+ if is_final then
+ local unpack = rawget(table, "unpack") or unpack
+ -- Ensure that we have the following:
+ -- one class has reached max_trains
+ -- other class(es) are at least as full as classes_bias
+ -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
+ -- one class must have 10 or more trains whilst another should have
+ -- at least (10 * (1 - 0.25)) = 8 trains
+
+ local max_len = math.max(unpack(lens))
+ local len_bias_check_pred = function(l)
+ return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+ end
+ if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
rspamd_logger.debugm(N, rspamd_config,
'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, tonumber(data), rule.train.max_trains, what)
+ ann_key, lens, rule.train.max_trains, what)
+ cont_cb()
else
rspamd_logger.debugm(N, rspamd_config,
- 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
- what, ann_key, tonumber(data), rule.train.max_trains)
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
end
- cont_cb()
+
else
rspamd_logger.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, tonumber(data), rule.train.max_trains)
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
end
end
end
@@ -1064,7 +1087,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
local function initiate_train()
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s required learn vectors',
- ann_key, rule.train.max_trains)
+ ann_key, lens)
do_train_ann(worker, ev_base, rule, set, ann_key)
end