Browse Source

[Project] Neural: Implement PCA in learning

tags/2.6
Vsevolod Stakhov 3 years ago
parent
commit
1d26ec3022
1 changed files with 43 additions and 19 deletions
  1. 43
    19
      src/plugins/lua/neural.lua

+ 43
- 19
src/plugins/lua/neural.lua View File

@@ -639,10 +639,17 @@ local function register_lock_extender(rule, set, ev_base, ann_key)
end

-- This is an utility function for PCA training
local function fill_scatter(inputs, meanv)
local function fill_scatter(inputs)
local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
local row_len = #inputs[1]

if type(inputs) == 'table' then
-- Convert to a tensor
inputs = rspamd_tensor.fromtable(inputs)
end

local meanv = inputs:mean()

for i=1,row_len do
local col = rspamd_tensor.new(1, #inputs)
for j=1,#inputs do
@@ -663,8 +670,7 @@ end
-- This function takes all inputs, applies PCA transformation and returns the final
-- PCA matrix as rspamd_tensor
local function learn_pca(inputs, max_inputs)
local meanv = inputs:mean()
local scatter_matrix = fill_scatter(inputs, meanv)
local scatter_matrix = fill_scatter(inputs)
local eigenvals = scatter_matrix:eigen()
-- scatter matrix is not filled with eigenvectors
lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
@@ -676,6 +682,19 @@ local function learn_pca(inputs, max_inputs)
return w
end

-- Fills ANN data for a specific settings element
local function fill_set_ann(set, ann_key)
if not set.ann then
set.ann = {
symbols = set.symbols,
distance = 0,
digest = set.digest,
redis_key = ann_key,
version = 0,
}
end
end

-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
-- Check training data sanity
@@ -684,7 +703,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
meta_functions.rspamd_count_metatokens()

-- Now we can train ann
local train_ann = create_ann(n, 3, rule)
local train_ann = create_ann(rule.max_inputs or n, 3, rule)

if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid
@@ -749,12 +768,23 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end
end

train_ann:train1(inputs, outputs, {
lr = rule.train.learning_rate,
max_epoch = rule.train.max_iterations,
cb = train_cb,
pca = set.ann.pca
})
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
rule.prefix, set.name)

local ret,err = pcall(train_ann.train1, train_ann,
inputs, outputs, {
lr = rule.train.learning_rate,
max_epoch = rule.train.max_iterations,
cb = train_cb,
pca = set.ann.pca
})

if not ret then
rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
rule.prefix, set.name, err)

return nil
end

if not seen_nan then
local out = train_ann:save()
@@ -806,14 +836,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
if set.ann.pca then
pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
end
if not set.ann then
set.ann = {
symbols = set.symbols,
distance = 0,
digest = set.digest,
redis_key = ann_key,
}
end
fill_set_ann(set, ann_key)
-- Deserialise ANN from the child process
ann_trained = rspamd_kann.load(data)
local version = (set.ann.version or 0) + 1
@@ -852,6 +875,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
end

if rule.max_inputs then
fill_set_ann(set, ann_key)
-- Train PCA in the main process, presumably it is not that long
set.ann.pca = learn_pca(inputs, rule.max_inputs)
end
@@ -1045,7 +1069,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
-- PCA table
local _err,pca_data = rspamd_util.zstd_decompress(data[2])
if pca_data then

Loading…
Cancel
Save