aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/plugins/neural.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-12-18 16:06:53 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2020-12-18 16:07:19 +0000
commit7ec92c421a5d1da6d26b6f2cdfed3cc585481155 (patch)
tree4011d641bea14f6e1a779f9fa350a1d6ba6a55b5 /lualib/plugins/neural.lua
parentf589cc476697815f75bcb34149ed7a4c175cc2c2 (diff)
downloadrspamd-7ec92c421a5d1da6d26b6f2cdfed3cc585481155.tar.gz
rspamd-7ec92c421a5d1da6d26b6f2cdfed3cc585481155.zip
[Feature] Neural: Move PCA learning to a subprocess
Diffstat (limited to 'lualib/plugins/neural.lua')
-rw-r--r--lualib/plugins/neural.lua50
1 files changed, 39 insertions, 11 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 4d4c44b5d..6f82089a4 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -23,6 +23,7 @@ local rspamd_kann = require "rspamd_kann"
local rspamd_logger = require "rspamd_logger"
local rspamd_tensor = require "rspamd_tensor"
local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
local N = 'neural'
@@ -464,12 +465,22 @@ local function spawn_train(params)
lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
params.rule.prefix, params.set.name)
+ local pca
+ if params.rule.max_inputs then
+ -- Train PCA in the main process, presumably it is not that long
+ lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
local ret,err = pcall(train_ann.train1, train_ann,
inputs, outputs, {
lr = params.rule.train.learning_rate,
max_epoch = params.rule.train.max_iterations,
cb = train_cb,
- pca = (params.set.ann or {}).pca
+ pca = pca
})
if not ret then
@@ -477,11 +488,26 @@ local function spawn_train(params)
params.rule.prefix, params.set.name, err)
return nil
+ else
+ lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
end
if not seen_nan then
- local out = train_ann:save()
- return out
+ -- Convert to strings as ucl cannot rspamd_text properly
+ local pca_data
+ if pca then
+ pca_data = tostring(pca:save())
+ end
+ local out = {
+ ann_data = tostring(train_ann:save()),
+ pca_data = pca_data,
+ }
+
+ local final_data = ucl.to_format(out, 'msgpack')
+ lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
+ params.rule.prefix, params.set.name, #final_data)
+ return final_data
else
return nil
end
@@ -523,15 +549,20 @@ local function spawn_train(params)
{params.ann_key, 'lock'}
)
else
- local ann_data = rspamd_util.zstd_compress(data)
- local pca_data
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data, 'msgpack')
+ assert(ok, parse_err)
+ local parsed = parser:get_object()
+ local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
+ local pca_data = parsed.pca_data
fill_set_ann(params.set, params.ann_key)
- if params.set.ann.pca then
- pca_data = rspamd_util.zstd_compress(params.set.ann.pca:save())
+ if pca_data then
+ params.set.ann.pca = rspamd_tensor.load(pca_data)
+ pca_data = rspamd_util.zstd_compress(pca_data)
end
-- Deserialise ANN from the child process
- ann_trained = rspamd_kann.load(data)
+ ann_trained = rspamd_kann.load(parsed.ann_data)
local version = (params.set.ann.version or 0) + 1
params.set.ann.version = version
params.set.ann.ann = ann_trained
@@ -545,7 +576,6 @@ local function spawn_train(params)
version = version
}
- local ucl = require "ucl"
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
@@ -572,8 +602,6 @@ local function spawn_train(params)
if params.rule.max_inputs then
fill_set_ann(params.set, params.ann_key)
- -- Train PCA in the main process, presumably it is not that long
- params.set.ann.pca = learn_pca(inputs, params.rule.max_inputs)
end
params.worker:spawn_process{