diff options
-rw-r--r-- | src/plugins/lua/neural.lua | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index fcf9ac5c2..e3518d3bd 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -490,9 +490,7 @@ local function ann_push_task_result(rule, task, verdict, score, set) if not err and type(data) == 'table' then local nspam,nham = data[1],data[2] - if nspam > 0 and nham > 0 and - can_push_train_vector(rule, task, learn_type, nspam, nham) then - + if can_push_train_vector(rule, task, learn_type, nspam, nham) then local vec = result_to_vector(task, set) local str = rspamd_util.zstd_compress(table.concat(vec, ';')) @@ -518,6 +516,11 @@ local function ann_push_task_result(rule, task, verdict, score, set) 'LPUSH', -- command { target_key, str } -- arguments ) + else + lua_util.debugm(N, task, + "do not add %s train data for ANN rule " .. + "%s:%s", + learn_type, rule.prefix, set.name) end else if err then @@ -1100,6 +1103,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) -- at least (10 * (1 - 0.25)) = 8 trains local max_len = math.max(lua_util.unpack(lua_util.values(lens))) + local min_len = math.min(lua_util.unpack(lua_util.values(lens))) if rule.train.learn_type == 'balanced' then local len_bias_check_pred = function(_, l) @@ -1117,7 +1121,7 @@ local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) end else -- Probabilistic mode, just ensure that at least one vector is okay - if max_len >= rule.train.max_trains then + if min_len > 0 and max_len >= rule.train.max_trains then rspamd_logger.debugm(N, rspamd_config, 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', ann_key, lens, rule.train.max_trains, what) |