local inputs, outputs = {}, {}
-- Used to show sparsed vectors in a convenient format (for debugging only)
- --[[
local function debug_vec(t)
local ret = {}
for i,v in ipairs(t) do
return ret
end
- ]]--
-- Make training set by joining vectors
-- KANN automatically shuffles those samples
-- Called in child process
local function train()
local log_thresh = rule.train.max_iterations / 10
- train_ann:train1(inputs, outputs, {
- lr = rule.train.learning_rate,
- max_epoch = rule.train.max_iterations,
- cb = function(iter, train_cost, _)
- if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
- rspamd_logger.infox(rspamd_config,
- "ANN %s:%s: learned from %s redis key in %s iterations, error: %s",
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lot's of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
rule.prefix, set.name,
- ann_key,
- iter, train_cost)
+ value_cost)
+ for i,e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ rule.prefix, set.name,
+ ann_key,
+ iter,
+ train_cost,
+ value_cost)
end
+ end
+
+ train_ann:train1(inputs, outputs, {
+ lr = rule.train.learning_rate,
+ max_epoch = rule.train.max_iterations,
+ cb = train_cb,
})
- local out = train_ann:save()
- return out
+ if not seen_nan then
+ local out = train_ann:save()
+ return out
+ else
+ return nil
+ end
end
set.learning_spawned = true