|
|
@@ -639,10 +639,17 @@ local function register_lock_extender(rule, set, ev_base, ann_key) |
|
|
|
end |
|
|
|
|
|
|
|
-- This is an utility function for PCA training |
|
|
|
local function fill_scatter(inputs, meanv) |
|
|
|
local function fill_scatter(inputs) |
|
|
|
local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs) |
|
|
|
local row_len = #inputs[1] |
|
|
|
|
|
|
|
if type(inputs) == 'table' then |
|
|
|
-- Convert to a tensor |
|
|
|
inputs = rspamd_tensor.fromtable(inputs) |
|
|
|
end |
|
|
|
|
|
|
|
local meanv = inputs:mean() |
|
|
|
|
|
|
|
for i=1,row_len do |
|
|
|
local col = rspamd_tensor.new(1, #inputs) |
|
|
|
for j=1,#inputs do |
|
|
@@ -663,8 +670,7 @@ end |
|
|
|
-- This function takes all inputs, applies PCA transformation and returns the final |
|
|
|
-- PCA matrix as rspamd_tensor |
|
|
|
local function learn_pca(inputs, max_inputs) |
|
|
|
local meanv = inputs:mean() |
|
|
|
local scatter_matrix = fill_scatter(inputs, meanv) |
|
|
|
local scatter_matrix = fill_scatter(inputs) |
|
|
|
local eigenvals = scatter_matrix:eigen() |
|
|
|
-- scatter matrix is not filled with eigenvectors |
|
|
|
lua_util.debugm(N, 'eigenvalues: %s', eigenvals) |
|
|
@@ -676,6 +682,19 @@ local function learn_pca(inputs, max_inputs) |
|
|
|
return w |
|
|
|
end |
|
|
|
|
|
|
|
-- Fills ANN data for a specific settings element |
|
|
|
local function fill_set_ann(set, ann_key) |
|
|
|
if not set.ann then |
|
|
|
set.ann = { |
|
|
|
symbols = set.symbols, |
|
|
|
distance = 0, |
|
|
|
digest = set.digest, |
|
|
|
redis_key = ann_key, |
|
|
|
version = 0, |
|
|
|
} |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis |
|
|
|
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) |
|
|
|
-- Check training data sanity |
|
|
@@ -684,7 +703,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
meta_functions.rspamd_count_metatokens() |
|
|
|
|
|
|
|
-- Now we can train ann |
|
|
|
local train_ann = create_ann(n, 3, rule) |
|
|
|
local train_ann = create_ann(rule.max_inputs or n, 3, rule) |
|
|
|
|
|
|
|
if #ham_vec + #spam_vec < rule.train.max_trains / 2 then |
|
|
|
-- Invalidate ANN as it is definitely invalid |
|
|
@@ -749,12 +768,23 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
end |
|
|
|
end |
|
|
|
|
|
|
|
train_ann:train1(inputs, outputs, { |
|
|
|
lr = rule.train.learning_rate, |
|
|
|
max_epoch = rule.train.max_iterations, |
|
|
|
cb = train_cb, |
|
|
|
pca = set.ann.pca |
|
|
|
}) |
|
|
|
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", |
|
|
|
rule.prefix, set.name) |
|
|
|
|
|
|
|
local ret,err = pcall(train_ann.train1, train_ann, |
|
|
|
inputs, outputs, { |
|
|
|
lr = rule.train.learning_rate, |
|
|
|
max_epoch = rule.train.max_iterations, |
|
|
|
cb = train_cb, |
|
|
|
pca = set.ann.pca |
|
|
|
}) |
|
|
|
|
|
|
|
if not ret then |
|
|
|
rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", |
|
|
|
rule.prefix, set.name, err) |
|
|
|
|
|
|
|
return nil |
|
|
|
end |
|
|
|
|
|
|
|
if not seen_nan then |
|
|
|
local out = train_ann:save() |
|
|
@@ -806,14 +836,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
if set.ann.pca then |
|
|
|
pca_data = rspamd_util.zstd_compress(set.ann.pca:save()) |
|
|
|
end |
|
|
|
if not set.ann then |
|
|
|
set.ann = { |
|
|
|
symbols = set.symbols, |
|
|
|
distance = 0, |
|
|
|
digest = set.digest, |
|
|
|
redis_key = ann_key, |
|
|
|
} |
|
|
|
end |
|
|
|
fill_set_ann(set, ann_key) |
|
|
|
-- Deserialise ANN from the child process |
|
|
|
ann_trained = rspamd_kann.load(data) |
|
|
|
local version = (set.ann.version or 0) + 1 |
|
|
@@ -852,6 +875,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve |
|
|
|
end |
|
|
|
|
|
|
|
if rule.max_inputs then |
|
|
|
fill_set_ann(set, ann_key) |
|
|
|
-- Train PCA in the main process, presumably it is not that long |
|
|
|
set.ann.pca = learn_pca(inputs, rule.max_inputs) |
|
|
|
end |
|
|
@@ -1045,7 +1069,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) |
|
|
|
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', |
|
|
|
rule.prefix, set.name, ann_key) |
|
|
|
end |
|
|
|
if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then |
|
|
|
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then |
|
|
|
-- PCA table |
|
|
|
local _err,pca_data = rspamd_util.zstd_decompress(data[2]) |
|
|
|
if pca_data then |