From 052fb86191e4f5137882103f7a1d3574401b75be Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 15 Nov 2016 15:48:36 +0000 Subject: [Fix] More fixes to fann_redis --- src/plugins/lua/fann_redis.lua | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 08216fe5e..ab4d74afa 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -404,6 +404,7 @@ local function train_fann(_, ev_base, elt) end local function ann_trained(errcode, errmsg, train_mse) + learning_spawned = false if errcode ~= 0 then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', fann_prefix .. elt, errmsg) @@ -418,7 +419,6 @@ local function train_fann(_, ev_base, elt) else rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s', fann_prefix .. elt, train_mse) - learning_spawned = false local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) fanns[elt].version = fanns[elt].version + 1 fanns[elt].fann = fanns[elt].fann_train @@ -457,10 +457,10 @@ local function train_fann(_, ev_base, elt) local inputs = {} local outputs = {} - fun.each(function(sample) - table.insert(inputs, fun.totable(sample[1])) + fun.each(function(spam_sample, ham_sample) + table.insert(inputs, fun.totable(spam_sample)) table.insert(outputs, {1.0}) - table.insert(inputs, fun.totable(sample[2])) + table.insert(inputs, fun.totable(ham_sample)) table.insert(outputs, {-1.0}) end, fun.zip(spam_elts, ham_elts)) @@ -472,6 +472,7 @@ local function train_fann(_, ev_base, elt) end learning_spawned = true + rspamd_logger.infox(rspamd_config, 'start learning ANN %s', elt) fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base, {max_epochs = max_epoch, desired_mse = mse}) end @@ -520,9 +521,15 @@ local function train_fann(_, ev_base, elt) 'LRANGE', -- command {fann_prefix .. elt .. '_spam', '0', '-1'} ) + rspamd_logger.infox(rspamd_config, 'lock ANN %s for learning', elt) + else + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', elt) end end - if learning_spawned then return end + if learning_spawned then + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN') + return + end redis_make_request(ev_base, rspamd_config, nil, @@ -547,7 +554,7 @@ local function maybe_train_fanns(cfg, ev_base) elseif _data and type(_data) == 'number' or type(_data) == 'string' then if tonumber(_data) and tonumber(_data) > max_trains then rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', - tonumber(_data), max_trains) + elt, tonumber(_data), max_trains) train_fann(cfg, ev_base, elt) end end -- cgit v1.2.3