summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/lua/neural.lua174
1 files changed, 130 insertions, 44 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index e6c52912a..41a9b4f07 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -42,8 +42,11 @@ local default_options = {
autotrain = true,
train_prob = 1.0,
learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
learning_rate = 0.01,
- classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias)
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
},
watch_interval = 60.0,
lock_expire = 600,
@@ -97,26 +100,21 @@ end
-- Lua script that checks if we can store a new training vector
-- Uses the following keys:
-- key1 - ann key
--- key2 - spam or ham
--- key3 - maximum trains
--- key4 - sampling coin (as Redis scripts do not allow math.random calls)
--- key5 - classes bias
--- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
-local redis_lua_script_can_store_train_vec = [[
+-- returns nspam,nham (or nil if locked)
+local redis_lua_script_vectors_len = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
- if locked then return {tostring(-1),'locked by another process till: ' .. locked} end
+ if locked then return false end
local nspam = 0
local nham = 0
- local lim = tonumber(KEYS[3])
- local coin = tonumber(KEYS[4])
- local classes_bias = tonumber(KEYS[5])
local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
ret = redis.call('LLEN', prefix .. '_ham')
if ret then nham = tonumber(ret) end
+ return {nspam,nham}
+
if KEYS[2] == 'spam' then
if nspam <= lim then
if nspam > nham then
@@ -147,7 +145,7 @@ local redis_lua_script_can_store_train_vec = [[
return {tostring(-1),'bad input'}
]]
-local redis_can_store_train_vec_id = nil
+local redis_lua_script_vectors_len_id = nil
-- Lua script to invalidate ANNs by rank
-- Uses the following keys
@@ -220,7 +218,7 @@ local redis_save_unlock_id = nil
local redis_params
local function load_scripts(params)
- redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec,
+ redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len,
params)
redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
params)
@@ -379,6 +377,88 @@ local function create_ann(n, nlayers)
return rspamd_kann.new.kann(t)
end
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
+ skip_rate - train_opts.classes_bias)
+ return false
+ end
+ end
+ return true
+ else -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type,
+ skip_rate - train_opts.classes_bias)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+ -- if our coin drop is less than desired probability
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
@@ -436,17 +516,12 @@ local function ann_push_task_result(rule, task, verdict, score, set)
local learn_type
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
- local function can_train_cb(err, data)
+ local function vectors_len_cb(err, data)
if not err and type(data) == 'table' then
- local nsamples,reason = tonumber(data[1]),data[2]
+ local nspam,nham = data[1],data[2]
- if nsamples >= 0 then
- local coin = math.random()
-
- if coin < 1.0 - train_opts.train_prob then
- rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
- return
- end
+ if nspam > 0 and nham > 0 and
+ can_push_train_vector(rule, task, learn_type, nspam, nham) then
local vec = result_to_vector(task, set)
@@ -473,15 +548,15 @@ local function ann_push_task_result(rule, task, verdict, score, set)
'LPUSH', -- command
{ target_key, str } -- arguments
)
- else
- -- Negative result returned
- rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)",
- learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples))
end
else
if err then
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
rule.prefix, set.name, err)
+ elseif type(data) == 'userdata' then
+ -- nil return value
+ rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning",
+ learn_type, rule.prefix, set.name, set.ann.redis_key)
else
rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' ..
'please remove this key from Redis manually if you perform upgrade from the previous version',
@@ -500,15 +575,11 @@ local function ann_push_task_result(rule, task, verdict, score, set)
set.name)
end
- lua_redis.exec_redis_script(redis_can_store_train_vec_id,
- {task = task, is_write = true},
- can_train_cb,
+ lua_redis.exec_redis_script(redis_lua_script_vectors_len_id,
+ {task = task, is_write = false},
+ vectors_len_cb,
{
set.ann.redis_key,
- learn_type,
- tostring(train_opts.max_trains),
- tostring(math.random()),
- tostring(train_opts.classes_bias)
})
else
lua_util.debugm(N, task,
@@ -1059,18 +1130,33 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
-- at least (10 * (1 - 0.25)) = 8 trains
local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
- local len_bias_check_pred = function(_, l)
- return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
- end
- if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
- rspamd_logger.debugm(N, rspamd_config,
- 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
- ann_key, lens, rule.train.max_trains, what)
- cont_cb()
+
+ if rule.train.learn_type == 'balanced' then
+ local len_bias_check_pred = function(_, l)
+ return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+ end
+ if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
+ rspamd_logger.debugm(N, rspamd_config,
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
+ cont_cb()
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
+ end
else
- rspamd_logger.debugm(N, rspamd_config,
- 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
- ann_key, what, lens, rule.train.max_trains)
+ -- Probabilistic mode, just ensure that at least one vector is okay
+ if max_len >= rule.train.max_trains then
+ rspamd_logger.debugm(N, rspamd_config,
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
+ cont_cb()
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
+ end
end
else