From ec7c781f2f263913ebc9d7caf6c9208bda9071ba Mon Sep 17 00:00:00 2001
From: Vsevolod Stakhov
Date: Sun, 3 Sep 2017 12:20:36 +0100
Subject: [PATCH] [Feature] Implement torch based ANN learning

---
 src/plugins/lua/fann_redis.lua | 116 ++++++++++++++++++++++++++-------
 1 file changed, 94 insertions(+), 22 deletions(-)

diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index b0cbdefab..5c444e74c 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -579,7 +579,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
     end
   end
 
-local function train_fann(rule, _, ev_base, elt)
+local function train_fann(rule, _, ev_base, elt, worker)
   local spam_elts = {}
   local ham_elts = {}
   elt = tostring(elt)
@@ -652,6 +652,43 @@ local function train_fann(rule, _, ev_base, elt)
     end
   end
 
+  local function ann_trained_torch(err, data)
+    rule.learning_spawned = false
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
+        prefix, err)
+      rspamd_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        redis_unlock_cb, --callback
+        'DEL', -- command
+        {prefix .. '_locked'}
+      )
+    else
+      rspamd_logger.infox(rspamd_config, 'trained ANN %s',
+        prefix)
+      local ann_data
+      local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
+      ann_data = rspamd_util.zstd_compress(f:storage():string())
+      fanns[elt].fann_train = f:readObject()
+
+      fanns[elt].version = fanns[elt].version + 1
+      fanns[elt].fann = fanns[elt].fann_train
+      fanns[elt].fann_train = nil
+      rspamd_redis.redis_make_request_taskless(ev_base,
+        rspamd_config,
+        rule.redis,
+        nil,
+        true, -- is write
+        redis_save_cb, --callback
+        'EVALSHA', -- command
+        {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+      )
+    end
+  end
+
   local function redis_ham_cb(err, data)
     if err or type(data) ~= 'table' then
       rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
@@ -673,29 +710,19 @@ local function train_fann(rule, _, ev_base, elt)
       end, data))
 
       -- Now we need to join inputs and create the appropriate test vectors
-      local inputs = {}
-      local outputs = {}
-
       local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens()
       local filt = function(elts)
         return #elts == n
       end
-      fun.each(function(spam_sample, ham_sample)
-        table.insert(inputs, spam_sample)
-        table.insert(outputs, {1.0})
-        table.insert(inputs, ham_sample)
-        table.insert(outputs, {-1.0})
-      end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
 
       -- Now we can train fann
       if not fanns[elt] or not fanns[elt].fann_train then
         -- Create fann if it does not exist
         create_train_fann(rule, n, elt)
       end
 
-      if #inputs < rule.max_trains / 2 then
+      if #spam_elts + #ham_elts < rule.max_trains / 2 then
         -- Invalidate ANN as it is definitely invalid
         local function redis_invalidate_cb(_err, _data)
           if _err then
@@ -717,13 +744,58 @@ local function train_fann(rule, _, ev_base, elt)
           {redis_locked_invalidate_sha, 1, prefix}
         )
       else
-        rule.learning_spawned = true
-        rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-        fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
-          ev_base, {
-            max_epochs = rule.train.max_epoch,
-            desired_mse = rule.train.mse
-          })
+        if torch then
+          -- For torch we do not need to mix samples as they would be flushed
+          local dataset = {}
+          fun.each(function(s)
+            table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
+          end, spam_elts)
+          fun.each(function(s)
+            table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
+          end, ham_elts)
+          -- Needed for torch
+          dataset.size = function(tbl) return #tbl end
+
+          local function train_torch()
+            local criterion = nn.MSECriterion()
+            local trainer = nn.StochasticGradient(fanns[elt].fann_train,
+              criterion)
+            trainer.learning_rate = 0.01
+            trainer.hookIteration = function(self, iteration, currentError)
+              rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
+                iteration, currentError)
+            end
+
+            trainer:train(dataset)
+            local out = torch.MemoryFile()
+            out:writeObject(fanns[elt].fann_train)
+            local st = out:storage():string()
+            return out
+          end
+
+          worker:spawn_process{
+            func = train_torch,
+            on_complete = ann_trained_torch,
+          }
+        else
+          local inputs = {}
+          local outputs = {}
+
+          fun.each(function(spam_sample, ham_sample)
+            table.insert(inputs, spam_sample)
+            table.insert(outputs, {1.0})
+            table.insert(inputs, ham_sample)
+            table.insert(outputs, {-1.0})
+          end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
+          rule.learning_spawned = true
+          rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
+          fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+            ev_base, {
+              max_epochs = rule.train.max_epoch,
+              desired_mse = rule.train.mse
+            })
+        end
+      end
       end
     end
@@ -827,7 +899,7 @@ local function train_fann(rule, _, ev_base, elt)
   )
 end
 
-local function maybe_train_fanns(rule, cfg, ev_base)
+local function maybe_train_fanns(rule, cfg, ev_base, worker)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
@@ -844,7 +916,7 @@ local function maybe_train_fanns(rule, cfg, ev_base)
           rspamd_logger.infox(rspamd_config,
             'need to learn ANN %s after %s learn vectors (%s required)',
             prefix, tonumber(_data), rule.max_trains)
-          train_fann(rule, cfg, ev_base, elt)
+          train_fann(rule, cfg, ev_base, elt, worker)
         end
       end
     end
@@ -1032,7 +1104,7 @@ else
       -- We also want to train neural nets when they have enough data
       rspamd_config:add_periodic(ev_base, 0.0,
         function(_, _)
-          return maybe_train_fanns(rule, cfg, ev_base)
+          return maybe_train_fanns(rule, cfg, ev_base, worker)
        end)
     end
   end)
-- 
2.39.5
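For readers unfamiliar with the torch7 'nn' API that this patch relies on, below is a small standalone Lua sketch of the same training and serialization scheme, independent of rspamd. The network shape, feature size and sample data are invented purely for illustration; only the dataset contract expected by nn.StochasticGradient (dataset[i] = {input, target} plus a size() method), the MSECriterion-driven training with a hookIteration callback, and the torch.MemoryFile round trip mirror what the patch does. Note that stock torch7 exposes the trainer's step size as learningRate; the patch assigns learning_rate, so the sketch uses the documented spelling.

require 'torch'
require 'nn'

-- Hypothetical tiny feature space; rspamd derives the real size from
-- get_symbols_count() + rspamd_count_metatokens()
local n_inputs = 4

-- Stand-in network; the plugin builds its own in create_train_fann()
local net = nn.Sequential()
net:add(nn.Linear(n_inputs, 3))
net:add(nn.Tanh())
net:add(nn.Linear(3, 1))

-- Dataset in the form StochasticGradient expects: an array of
-- {input, target} tensor pairs plus a size() method
local dataset = {}
for i = 1, 100 do
  local x = torch.rand(n_inputs)
  local y = torch.Tensor({i % 2 == 0 and 1.0 or -1.0}) -- fake spam/ham label
  table.insert(dataset, {x, y})
end
dataset.size = function(tbl) return #tbl end

local trainer = nn.StochasticGradient(net, nn.MSECriterion())
trainer.learningRate = 0.01
trainer.maxIteration = 10
trainer.hookIteration = function(self, iteration, currentError)
  print(string.format('iteration %d, error %f', iteration, currentError))
end
trainer:train(dataset)

-- Serialize the trained model to a string, as the patch does before
-- zstd-compressing it and saving it to Redis via EVALSHA
local out = torch.MemoryFile()
out:writeObject(net)
local serialized = out:storage():string()

-- Restore it, mirroring the loading side of ann_trained_torch()
local inp = torch.MemoryFile(torch.CharStorage():string(serialized))
local restored = inp:readObject()

The patch's comment about not needing to mix spam and ham samples presumably relies on StochasticGradient's own index shuffling (shuffleIndices is enabled by default), which is why the sketch also appends all samples of one class followed by the other.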