end
-- This is a utility function for PCA training
-local function fill_scatter(inputs, meanv)
+local function fill_scatter(inputs)
local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
local row_len = #inputs[1]
+ if type(inputs) == 'table' then
+ -- Convert to a tensor
+ inputs = rspamd_tensor.fromtable(inputs)
+ end
+
+ local meanv = inputs:mean()
+
for i=1,row_len do
local col = rspamd_tensor.new(1, #inputs)
for j=1,#inputs do
-- This function takes all inputs, applies PCA transformation and returns the final
-- PCA matrix as rspamd_tensor
local function learn_pca(inputs, max_inputs)
- local meanv = inputs:mean()
- local scatter_matrix = fill_scatter(inputs, meanv)
+ local scatter_matrix = fill_scatter(inputs)
local eigenvals = scatter_matrix:eigen()
-- scatter matrix is not filled with eigenvectors
lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
return w
end
+-- Fills ANN data for a specific settings element.
+-- Ensures that `set.ann` exists: when it is missing (nil or false) a fresh
+-- descriptor table is created and stored in `set.ann`; an already present
+-- `set.ann` is left untouched.
+-- @param set settings element table; its `symbols` and `digest` fields are
+-- referenced (not deep-copied) by the new descriptor
+-- @param ann_key value stored in the `redis_key` field of the new descriptor
+-- @return nothing, `set` is modified in place
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ -- A newly created descriptor starts at version 0
+ version = 0,
+ }
+ end
+end
+
-- This function receives training vectors, checks them, spawns learning and saves ANN in Redis
local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_vec)
-- Check training data sanity
meta_functions.rspamd_count_metatokens()
-- Now we can train ann
- local train_ann = create_ann(n, 3, rule)
+ local train_ann = create_ann(rule.max_inputs or n, 3, rule)
if #ham_vec + #spam_vec < rule.train.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid
end
end
- train_ann:train1(inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = train_cb,
- pca = set.ann.pca
- })
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ rule.prefix, set.name)
+
+ local ret,err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = train_cb,
+ pca = set.ann.pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ rule.prefix, set.name, err)
+
+ return nil
+ end
if not seen_nan then
local out = train_ann:save()
if set.ann.pca then
pca_data = rspamd_util.zstd_compress(set.ann.pca:save())
end
- if not set.ann then
- set.ann = {
- symbols = set.symbols,
- distance = 0,
- digest = set.digest,
- redis_key = ann_key,
- }
- end
+ fill_set_ann(set, ann_key)
-- Deserialise ANN from the child process
ann_trained = rspamd_kann.load(data)
local version = (set.ann.version or 0) + 1
end
if rule.max_inputs then
+ fill_set_ann(set, ann_key)
-- Train PCA in the main process, presumably it is not that long
set.ann.pca = learn_pca(inputs, rule.max_inputs)
end
lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
rule.prefix, set.name, ann_key)
end
- if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+ if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
-- PCA table
local _err,pca_data = rspamd_util.zstd_decompress(data[2])
if pca_data then