diff options
-rw-r--r-- | src/plugins/lua/neural.lua | 174 |
1 files changed, 130 insertions, 44 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index e6c52912a..41a9b4f07 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -42,8 +42,11 @@ local default_options = { autotrain = true, train_prob = 1.0, learn_threads = 1, + learn_mode = 'balanced', -- Possible values: balanced, proportional learning_rate = 0.01, - classes_bias = 0.0, -- What difference is allowed between classes (1:1 proportion means 0 bias) + classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) + spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) + ham_skip_prob = 0.0, -- proportional mode: ham skip probability }, watch_interval = 60.0, lock_expire = 600, @@ -97,26 +100,21 @@ end -- Lua script that checks if we can store a new training vector -- Uses the following keys: -- key1 - ann key --- key2 - spam or ham --- key3 - maximum trains --- key4 - sampling coin (as Redis scripts do not allow math.random calls) --- key5 - classes bias --- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn -local redis_lua_script_can_store_train_vec = [[ +-- returns nspam,nham (or nil if locked) +local redis_lua_script_vectors_len = [[ local prefix = KEYS[1] local locked = redis.call('HGET', prefix, 'lock') - if locked then return {tostring(-1),'locked by another process till: ' .. locked} end + if locked then return false end local nspam = 0 local nham = 0 - local lim = tonumber(KEYS[3]) - local coin = tonumber(KEYS[4]) - local classes_bias = tonumber(KEYS[5]) local ret = redis.call('LLEN', prefix .. '_spam') if ret then nspam = tonumber(ret) end ret = redis.call('LLEN', prefix .. '_ham') if ret then nham = tonumber(ret) end + return {nspam,nham} + if KEYS[2] == 'spam' then if nspam <= lim then if nspam > nham then @@ -147,7 +145,7 @@ local redis_lua_script_can_store_train_vec = [[ return {tostring(-1),'bad input'} ]] -local redis_can_store_train_vec_id = nil +local redis_lua_script_vectors_len_id = nil -- Lua script to invalidate ANNs by rank -- Uses the following keys @@ -220,7 +218,7 @@ local redis_save_unlock_id = nil local redis_params local function load_scripts(params) - redis_can_store_train_vec_id = lua_redis.add_redis_script(redis_lua_script_can_store_train_vec, + redis_lua_script_vectors_len_id = lua_redis.add_redis_script(redis_lua_script_vectors_len, params) redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, params) @@ -379,6 +377,88 @@ local function create_ann(n, nlayers) return rspamd_kann.new.kann(t) end +local function can_push_train_vector(rule, task, learn_type, nspam, nham) + local train_opts = rule.train + local coin = math.random() + + if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return false + end + + if train_opts.learn_mode == 'balanced' then + -- Keep balanced training set based on number of spam and ham samples + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if nspam > nham then + -- Apply sampling + local skip_rate = 1.0 - nham / (nspam + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type, + skip_rate - train_opts.classes_bias) + return false + end + end + return true + else -- Enough learns + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type, + nspam) + end + else + if nham <= train_opts.max_trains then + if nham > nspam then + -- Apply sampling + local skip_rate = 1.0 - nspam / (nham + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; probability %s', learn_type, + skip_rate - train_opts.classes_bias) + return false + end + end + return true + else + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, + nham) + end + end + else + -- Probabilistic learn mode, we just skip learn if we already have enough samples or + -- if our coin drop is less than desired probability + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if train_opts.spam_skip_prob then + if coin <= train_opts.spam_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, + coin, train_opts.spam_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, + nspam, train_opts.max_trains) + end + else + if nham <= train_opts.max_trains then + if train_opts.ham_skip_prob then + if coin <= train_opts.ham_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilisticaly; probability %s (%s skip chance)', learn_type, + coin, train_opts.ham_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, + nham, train_opts.max_trains) + end + end + end + + return false +end local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train @@ -436,17 +516,12 @@ local function ann_push_task_result(rule, task, verdict, score, set) local learn_type if learn_spam then learn_type = 'spam' else learn_type = 'ham' end - local function can_train_cb(err, data) + local function vectors_len_cb(err, data) if not err and type(data) == 'table' then - local nsamples,reason = tonumber(data[1]),data[2] + local nspam,nham = data[1],data[2] - if nsamples >= 0 then - local coin = math.random() - - if coin < 1.0 - train_opts.train_prob then - rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) - return - end + if nspam > 0 and nham > 0 and + can_push_train_vector(rule, task, learn_type, nspam, nham) then local vec = result_to_vector(task, set) @@ -473,15 +548,15 @@ local function ann_push_task_result(rule, task, verdict, score, set) 'LPUSH', -- command { target_key, str } -- arguments ) - else - -- Negative result returned - rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: %s (%s vectors stored)", - learn_type, rule.prefix, set.name, set.ann.redis_key, reason, -tonumber(nsamples)) end else if err then rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', rule.prefix, set.name, err) + elseif type(data) == 'userdata' then + -- nil return value + rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning", + learn_type, rule.prefix, set.name, set.ann.redis_key) else rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. 'please remove this key from Redis manually if you perform upgrade from the previous version', @@ -500,15 +575,11 @@ local function ann_push_task_result(rule, task, verdict, score, set) set.name) end - lua_redis.exec_redis_script(redis_can_store_train_vec_id, - {task = task, is_write = true}, - can_train_cb, + lua_redis.exec_redis_script(redis_lua_script_vectors_len_id, + {task = task, is_write = false}, + vectors_len_cb, { set.ann.redis_key, - learn_type, - tostring(train_opts.max_trains), - tostring(math.random()), - tostring(train_opts.classes_bias) }) else lua_util.debugm(N, task, @@ -1059,18 +1130,33 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) -- at least (10 * (1 - 0.25)) = 8 trains local max_len = math.max(lua_util.unpack(lua_util.values(lens))) - local len_bias_check_pred = function(_, l) - return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) - end - if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then - rspamd_logger.debugm(N, rspamd_config, - 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', - ann_key, lens, rule.train.max_trains, what) - cont_cb() + + if rule.train.learn_type == 'balanced' then + local len_bias_check_pred = function(_, l) + return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) + end + if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end else - rspamd_logger.debugm(N, rspamd_config, - 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', - ann_key, what, lens, rule.train.max_trains) + -- Probabilistic mode, just ensure that at least one vector is okay + if max_len >= rule.train.max_trains then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end end else |