From: Vsevolod Stakhov Date: Fri, 18 Oct 2019 16:18:26 +0000 (+0100) Subject: [Minor] Neural: Add nan check and extensive logging X-Git-Tag: 2.1~68 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=43205d7e865312939ed452223442e5128b0e2a6a;p=rspamd.git [Minor] Neural: Add nan check and extensive logging --- diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 1ff1f40d7..e6ffe41be 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -564,7 +564,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve local inputs, outputs = {}, {} -- Used to show sparsed vectors in a convenient format (for debugging only) - --[[ local function debug_vec(t) local ret = {} for i,v in ipairs(t) do @@ -575,7 +574,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve return ret end - ]]-- -- Make training set by joining vectors -- KANN automatically shuffles those samples @@ -595,22 +593,44 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve -- Called in child process local function train() local log_thresh = rule.train.max_iterations / 10 - train_ann:train1(inputs, outputs, { - lr = rule.train.learning_rate, - max_epoch = rule.train.max_iterations, - cb = function(iter, train_cost, _) - if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then - rspamd_logger.infox(rspamd_config, - "ANN %s:%s: learned from %s redis key in %s iterations, error: %s", + local seen_nan = false + + local function train_cb(iter, train_cost, value_cost) + if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then + if train_cost ~= train_cost and not seen_nan then + -- We have nan :( try to log lot's of stuff to dig into a problem + seen_nan = true + rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', rule.prefix, set.name, - ann_key, - iter, train_cost) + value_cost) + for i,e in ipairs(inputs) do + lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', + debug_vec(e), outputs[i][1]) + end end + + rspamd_logger.infox(rspamd_config, + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", + rule.prefix, set.name, + ann_key, + iter, + train_cost, + value_cost) end + end + + train_ann:train1(inputs, outputs, { + lr = rule.train.learning_rate, + max_epoch = rule.train.max_iterations, + cb = train_cb, }) - local out = train_ann:save() - return out + if not seen_nan then + local out = train_ann:save() + return out + else + return nil + end end set.learning_spawned = true