end
end
-local function train_fann(rule, _, ev_base, elt)
+local function train_fann(rule, _, ev_base, elt, worker)
local spam_elts = {}
local ham_elts = {}
elt = tostring(elt)
end
end
+ -- Completion callback for the torch training process.
+ -- err:  failure description, or nil/false when training succeeded
+ -- data: payload handed back by the training process; it is converted with
+ --       tostring() and deserialized as a torch object (presumably the
+ --       serialized trained network -- confirm against the spawner)
+ local function ann_trained_torch(err, data)
+   -- Training is over either way, so allow the next learning round
+   rule.learning_spawned = false
+
+   if err then
+     rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+       prefix, err)
+     -- Remove the '<prefix>_locked' key so the ANN is not left locked
+     rspamd_redis.redis_make_request_taskless(ev_base,
+       rspamd_config,
+       rule.redis,
+       nil,
+       true, -- is write
+       redis_unlock_cb, --callback
+       'DEL', -- command
+       {prefix .. '_locked'}
+     )
+     return
+   end
+
+   rspamd_logger.infox(rspamd_config, 'trained ANN %s',
+     prefix)
+
+   -- Wrap the received bytes into an in-memory torch file; the same bytes are
+   -- zstd-compressed for storage before the object is read back from the file
+   local mem = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
+   local packed = rspamd_util.zstd_compress(mem:storage():string())
+
+   local slot = fanns[elt]
+   slot.fann_train = mem:readObject()
+
+   -- Promote the freshly trained network to the active one
+   slot.version = slot.version + 1
+   slot.fann = slot.fann_train
+   slot.fann_train = nil
+
+   -- Save the compressed ANN via the save/unlock script
+   rspamd_redis.redis_make_request_taskless(ev_base,
+     rspamd_config,
+     rule.redis,
+     nil,
+     true, -- is write
+     redis_save_cb, --callback
+     'EVALSHA', -- command
+     {redis_save_unlock_sha, '2', prefix, packed, tostring(rule.ann_expire)}
+   )
+ end
+
local function redis_ham_cb(err, data)
if err or type(data) ~= 'table' then
rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
end, data))
-- Now we need to join inputs and create the appropriate test vectors
- local inputs = {}
- local outputs = {}
-
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
local filt = function(elts)
return #elts == n
end
- fun.each(function(spam_sample, ham_sample)
- table.insert(inputs, spam_sample)
- table.insert(outputs, {1.0})
- table.insert(inputs, ham_sample)
- table.insert(outputs, {-1.0})
- end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
-
-- Now we can train fann
if not fanns[elt] or not fanns[elt].fann_train then
-- Create fann if it does not exist
create_train_fann(rule, n, elt)
end
- if #inputs < rule.max_trains / 2 then
+ if #spam_elts + #ham_elts < rule.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid
local function redis_invalidate_cb(_err, _data)
if _err then
{redis_locked_invalidate_sha, 1, prefix}
)
else
- rule.learning_spawned = true
- rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
- ev_base, {
- max_epochs = rule.train.max_epoch,
- desired_mse = rule.train.mse
- })
+ if torch then
+   -- Torch path: the network is trained in a separate process spawned by the
+   -- worker and the result is delivered to ann_trained_torch.
+   -- Samples do not need to be interleaved here, as nn.StochasticGradient
+   -- shuffles them by default. Only vectors with the expected number of
+   -- elements (`filt`) are used, same as in the FANN path below.
+   local dataset = {}
+   fun.each(function(s)
+     table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
+   end, fun.filter(filt, spam_elts))
+   fun.each(function(s)
+     table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
+   end, fun.filter(filt, ham_elts))
+   -- nn.StochasticGradient calls dataset:size()
+   dataset.size = function(tbl) return #tbl end
+
+   -- Runs inside the spawned process: trains fanns[elt].fann_train on
+   -- `dataset` and returns the trained network serialized to a string
+   local function train_torch()
+     local criterion = nn.MSECriterion()
+     local trainer = nn.StochasticGradient(fanns[elt].fann_train,
+       criterion)
+     -- The trainer's field is `learningRate`; `learning_rate` is ignored
+     trainer.learningRate = 0.01
+     trainer.hookIteration = function(_, iteration, currentError)
+       rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
+         iteration, currentError)
+     end
+
+     trainer:train(dataset)
+     local out = torch.MemoryFile()
+     out:writeObject(fanns[elt].fann_train)
+     -- Return the serialized bytes, not the MemoryFile object: the
+     -- completion callback rebuilds the network from tostring(data)
+     return out:storage():string()
+   end
+
+   -- Mark learning as in progress; ann_trained_torch resets the flag
+   rule.learning_spawned = true
+   rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
+   worker:spawn_process{
+     func = train_torch,
+     on_complete = ann_trained_torch,
+   }
+ else
+   -- FANN path: interleave spam/ham samples pairwise and train in a thread
+   local inputs = {}
+   local outputs = {}
+
+   fun.each(function(spam_sample, ham_sample)
+     table.insert(inputs, spam_sample)
+     table.insert(outputs, {1.0})
+     table.insert(inputs, ham_sample)
+     table.insert(outputs, {-1.0})
+   end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
+   rule.learning_spawned = true
+   rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
+   fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+     ev_base, {
+       max_epochs = rule.train.max_epoch,
+       desired_mse = rule.train.mse
+     })
+ end
+
end
end
end
)
end
-local function maybe_train_fanns(rule, cfg, ev_base)
+local function maybe_train_fanns(rule, cfg, ev_base, worker)
local function members_cb(err, data)
if err then
rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
prefix, tonumber(_data), rule.max_trains)
- train_fann(rule, cfg, ev_base, elt)
+ train_fann(rule, cfg, ev_base, elt, worker)
end
end
end
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)
- return maybe_train_fanns(rule, cfg, ev_base)
+ return maybe_train_fanns(rule, cfg, ev_base, worker)
end)
end
end)