From 1d26ec302293d9eeeafb5e14f4d3a0d73c126f4f Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 27 Aug 2020 22:00:38 +0100 Subject: [PATCH] [Project] Neural: Implement PCA in learning --- src/plugins/lua/neural.lua | 62 ++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 91caa8e07..5b4ff8b3b 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -639,10 +639,17 @@ local function register_lock_extender(rule, set, ev_base, ann_key) end -- This is an utility function for PCA training -local function fill_scatter(inputs, meanv) +local function fill_scatter(inputs) local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs) local row_len = #inputs[1] + if type(inputs) == 'table' then + -- Convert to a tensor + inputs = rspamd_tensor.fromtable(inputs) + end + + local meanv = inputs:mean() + for i=1,row_len do local col = rspamd_tensor.new(1, #inputs) for j=1,#inputs do @@ -663,8 +670,7 @@ end -- This function takes all inputs, applies PCA transformation and returns the final -- PCA matrix as rspamd_tensor local function learn_pca(inputs, max_inputs) - local meanv = inputs:mean() - local scatter_matrix = fill_scatter(inputs, meanv) + local scatter_matrix = fill_scatter(inputs) local eigenvals = scatter_matrix:eigen() -- scatter matrix is not filled with eigenvectors lua_util.debugm(N, 'eigenvalues: %s', eigenvals) @@ -676,6 +682,19 @@ local function learn_pca(inputs, max_inputs) return w end +-- Fills ANN data for a specific settings element +local function fill_set_ann(set, ann_key) + if not set.ann then + set.ann = { + symbols = set.symbols, + distance = 0, + digest = set.digest, + redis_key = ann_key, + version = 0, + } + end +end + -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec) -- Check training data sanity @@ -684,7 +703,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve meta_functions.rspamd_count_metatokens() -- Now we can train ann - local train_ann = create_ann(n, 3, rule) + local train_ann = create_ann(rule.max_inputs or n, 3, rule) if #ham_vec + #spam_vec < rule.train.max_trains / 2 then -- Invalidate ANN as it is definitely invalid @@ -749,12 +768,23 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve end end - train_ann:train1(inputs, outputs, { - lr = rule.train.learning_rate, - max_epoch = rule.train.max_iterations, - cb = train_cb, - pca = set.ann.pca - }) + lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", + rule.prefix, set.name) + + local ret,err = pcall(train_ann.train1, train_ann, + inputs, outputs, { + lr = rule.train.learning_rate, + max_epoch = rule.train.max_iterations, + cb = train_cb, + pca = set.ann.pca + }) + + if not ret then + rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", + rule.prefix, set.name, err) + + return nil + end if not seen_nan then local out = train_ann:save() @@ -806,14 +836,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve if set.ann.pca then pca_data = rspamd_util.zstd_compress(set.ann.pca:save()) end - if not set.ann then - set.ann = { - symbols = set.symbols, - distance = 0, - digest = set.digest, - redis_key = ann_key, - } - end + fill_set_ann(set, ann_key) -- Deserialise ANN from the child process ann_trained = rspamd_kann.load(data) local version = (set.ann.version or 0) + 1 @@ -852,6 +875,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve end if rule.max_inputs then + fill_set_ann(set, ann_key) -- Train PCA in the main process, presumably it is not that long set.ann.pca = learn_pca(inputs, rule.max_inputs) end @@ -1045,7 +1069,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', rule.prefix, set.name, ann_key) end - if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then + if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then -- PCA table local _err,pca_data = rspamd_util.zstd_decompress(data[2]) if pca_data then -- 2.39.5