diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-12-18 16:06:53 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-12-18 16:07:19 +0000 |
commit | 7ec92c421a5d1da6d26b6f2cdfed3cc585481155 (patch) | |
tree | 4011d641bea14f6e1a779f9fa350a1d6ba6a55b5 /lualib/plugins | |
parent | f589cc476697815f75bcb34149ed7a4c175cc2c2 (diff) | |
download | rspamd-7ec92c421a5d1da6d26b6f2cdfed3cc585481155.tar.gz rspamd-7ec92c421a5d1da6d26b6f2cdfed3cc585481155.zip |
[Feature] Neural: Move PCA learning to a subprocess
Diffstat (limited to 'lualib/plugins')
-rw-r--r-- | lualib/plugins/neural.lua | 50 |
1 files changed, 39 insertions, 11 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 4d4c44b5d..6f82089a4 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -23,6 +23,7 @@ local rspamd_kann = require "rspamd_kann" local rspamd_logger = require "rspamd_logger" local rspamd_tensor = require "rspamd_tensor" local rspamd_util = require "rspamd_util" +local ucl = require "ucl" local N = 'neural' @@ -464,12 +465,22 @@ local function spawn_train(params) lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", params.rule.prefix, params.set.name) + local pca + if params.rule.max_inputs then + -- Train PCA in the main process, presumably it is not that long + lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s", + params.rule.prefix, params.set.name) + pca = learn_pca(inputs, params.rule.max_inputs) + end + + lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s", + params.rule.prefix, params.set.name) local ret,err = pcall(train_ann.train1, train_ann, inputs, outputs, { lr = params.rule.train.learning_rate, max_epoch = params.rule.train.max_iterations, cb = train_cb, - pca = (params.set.ann or {}).pca + pca = pca }) if not ret then @@ -477,11 +488,26 @@ local function spawn_train(params) params.rule.prefix, params.set.name, err) return nil + else + lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s", + params.rule.prefix, params.set.name) end if not seen_nan then - local out = train_ann:save() - return out + -- Convert to strings as ucl cannot rspamd_text properly + local pca_data + if pca then + pca_data = tostring(pca:save()) + end + local out = { + ann_data = tostring(train_ann:save()), + pca_data = pca_data, + } + + local final_data = ucl.to_format(out, 'msgpack') + lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes", + params.rule.prefix, params.set.name, #final_data) + return final_data else return nil end @@ -523,15 +549,20 @@ local function spawn_train(params) {params.ann_key, 'lock'} ) else - local ann_data = rspamd_util.zstd_compress(data) - local pca_data + local parser = ucl.parser() + local ok, parse_err = parser:parse_text(data, 'msgpack') + assert(ok, parse_err) + local parsed = parser:get_object() + local ann_data = rspamd_util.zstd_compress(parsed.ann_data) + local pca_data = parsed.pca_data fill_set_ann(params.set, params.ann_key) - if params.set.ann.pca then - pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save()) + if pca_data then + params.set.ann.pca = rspamd_tensor.load(pca_data) + pca_data = rspamd_util.zstd_compress(pca_data) end -- Deserialise ANN from the child process - ann_trained = rspamd_kann.load(data) + ann_trained = rspamd_kann.load(parsed.ann_data) local version = (params.set.ann.version or 0) + 1 params.set.ann.version = version params.set.ann.ann = ann_trained @@ -545,7 +576,6 @@ local function spawn_train(params) version = version } - local ucl = require "ucl" local profile_serialized = ucl.to_format(profile, 'json-compact', true) rspamd_logger.infox(rspamd_config, @@ -572,8 +602,6 @@ local function spawn_train(params) if params.rule.max_inputs then fill_set_ann(params.set, params.ann_key) - -- Train PCA in the main process, presumably it is not that long - params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs) end params.worker:spawn_process{ |