From 6d11758e98e2adac29897cb45b7a243625d9b761 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 15 Jul 2019 16:40:47 +0100 Subject: [PATCH] [Fix] Neural: Another bunch of fixes --- src/plugins/lua/neural.lua | 45 +++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 2e4c8e7cc..a68d6f83a 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -541,6 +541,20 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve else local inputs, outputs = {}, {} + -- Used to show sparsed vectors in a convenient format (for debugging only) + --[[ + local function debug_vec(t) + local ret = {} + for i,v in ipairs(t) do + if v ~= 0 then + ret[#ret + 1] = string.format('%d=%.2f', i, v) + end + end + + return ret + end + ]]-- + -- Make training set by joining vectors -- KANN automatically shuffles those samples -- 1.0 is used for spam and -1.0 is used for ham @@ -548,21 +562,26 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve for _,e in ipairs(spam_vec) do inputs[#inputs + 1] = e outputs[#outputs + 1] = {1.0} + --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) end for _,e in ipairs(ham_vec) do inputs[#inputs + 1] = e outputs[#outputs + 1] = {-1.0} + --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) end -- Called in child process local function train() + local log_thresh = rule.train.max_iterations / 10 train_ann:train1(inputs, outputs, { lr = rule.train.learning_rate, max_epoch = rule.train.max_iterations, cb = function(iter, train_cost, _) - if math.floor(iter / rule.train.max_iterations * 10) % 10 == 0 then - rspamd_logger.infox(rspamd_config, "ANN %s:%s: learned %s iterations, error: %s", + if (iter * (rule.train.max_iterations / log_thresh)) % (rule.train.max_iterations) == 0 then + rspamd_logger.infox(rspamd_config, + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s", rule.prefix, set.name, + ann_key, iter, train_cost) end end @@ -589,7 +608,7 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve ) else rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', - rule.prefix, set.name, ann_key) + rule.prefix, set.name, set.ann.redis_key) end end @@ -608,8 +627,6 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve {ann_key, 'lock'} ) else - rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes; redis key: %s', - rule.prefix, set.name, #data, ann_key) local ann_data = rspamd_util.zstd_compress(data) if not set.ann then set.ann = { @@ -637,6 +654,10 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve local ucl = require "ucl" local profile_serialized = ucl.to_format(profile, 'json-compact', true) + rspamd_logger.infox(rspamd_config, + 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)', + rule.prefix, set.name, #data, set.ann.redis_key, ann_key) + lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, @@ -1131,8 +1152,20 @@ local function process_rules_settings() rule.prefix, selt.name) end + local function filter_symbols_predicate(sname) + local fl = rspamd_config:get_symbol_flags(sname) + if fl then + fl = lua_util.list_to_hash(fl) + + return not (fl.nostat or fl.idempotent or fl.skip) + end + + return false + end + -- Generic stuff - table.sort(selt.symbols) + table.sort(fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))) + selt.digest = lua_util.table_digest(selt.symbols) selt.prefix = redis_ann_prefix(rule, selt.name) -- 2.39.5