diff options
Diffstat (limited to 'lualib/plugins/neural.lua')
-rw-r--r-- | lualib/plugins/neural.lua | 264 |
1 files changed, 136 insertions, 128 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 05dace489..6e88ef21c 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -96,7 +96,6 @@ local module_config = rspamd_config:get_all_opt(N) settings = lua_util.override_defaults(settings, module_config) local redis_params = lua_redis.parse_redis_server('neural') - local redis_lua_script_vectors_len = "neural_train_size.lua" local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua" local redis_lua_script_maybe_lock = "neural_maybe_lock.lua" @@ -106,17 +105,17 @@ local redis_script_id = {} local function load_scripts() redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len, - redis_params) + redis_params) redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate, - redis_params) + redis_params) redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock, - redis_params) + redis_params) redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock, - redis_params) + redis_params) end local function create_ann(n, nlayers, rule) - -- We ignore number of layers so far when using kann + -- We ignore number of layers so far when using kann local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) local t = rspamd_kann.layer.input(n) t = rspamd_kann.transform.relu(t) @@ -146,7 +145,7 @@ local function learn_pca(inputs, max_inputs) -- scatter matrix is not filled with eigenvectors lua_util.debugm(N, 'eigenvalues: %s', eigenvals) local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) - for i=1,max_inputs do + for i = 1, max_inputs do w[i] = scatter_matrix[#scatter_matrix - i + 1] end @@ -172,15 +171,19 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) local a = {} local b = {} - for i=1,n do + for i = 1, n do r[i] = i end - local cmp = function(p, q) return p < q end + local cmp = function(p, q) + return p < q + end - table.sort(r, function(p, q) return cmp(x[p], x[q]) end) + table.sort(r, function(p, q) + return cmp(x[p], x[q]) + end) - for i=1,n do + for i = 1, n do a[i] = x[r[i]] b[i] = y[r[i]] end @@ -190,89 +193,89 @@ local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) local function get_scores(nn, input_vectors) local scores = {} - for i=1,#inputs do + for i = 1, #inputs do local score = nn:apply1(input_vectors[i], nn.pca)[1] - scores[#scores+1] = score + scores[#scores + 1] = score end return scores end local fpr = {} - local fnr = {} - local scores = get_scores(ann, inputs) - - scores, outputs = sort_relative(scores, outputs) - - local n_samples = #outputs - local n_spam = 0 - local n_ham = 0 - local ham_count_ahead = {} - local spam_count_ahead = {} - local ham_count_behind = {} - local spam_count_behind = {} - - ham_count_ahead[n_samples + 1] = 0 - spam_count_ahead[n_samples + 1] = 0 - - for i=n_samples,1,-1 do - - if outputs[i][1] == 0 then - n_ham = n_ham + 1 - ham_count_ahead[i] = 1 - spam_count_ahead[i] = 0 - else - n_spam = n_spam + 1 - ham_count_ahead[i] = 0 - spam_count_ahead[i] = 1 - end - - ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1] - spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1] - end - - for i=1,n_samples do + local fnr = {} + local scores = get_scores(ann, inputs) + + scores, outputs = sort_relative(scores, outputs) + + local n_samples = #outputs + local n_spam = 0 + local n_ham = 0 + local ham_count_ahead = {} + local spam_count_ahead = {} + local ham_count_behind = {} + local spam_count_behind = {} + + ham_count_ahead[n_samples + 1] = 0 + spam_count_ahead[n_samples + 1] = 0 + + for i = n_samples, 1, -1 do + + if outputs[i][1] == 0 then + n_ham = n_ham + 1 + ham_count_ahead[i] = 1 + spam_count_ahead[i] = 0 + else + n_spam = n_spam + 1 + ham_count_ahead[i] = 0 + spam_count_ahead[i] = 1 + end + + ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1] + spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1] + end + + for i = 1, n_samples do if outputs[i][1] == 0 then - ham_count_behind[i] = 1 - spam_count_behind[i] = 0 - else - ham_count_behind[i] = 0 - spam_count_behind[i] = 1 - end - - if i ~= 1 then - ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1] - spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1] - end - end - - for i=1,n_samples do - fpr[i] = 0 - fnr[i] = 0 - - if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then - fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i]) - end - - if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then - fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1]) - end - end - - local p = n_spam / (n_spam + n_ham) - - local cost = {} - local min_cost_idx = 0 - local min_cost = math.huge - for i=1,n_samples do - cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i]) - if min_cost >= cost[i] then - min_cost = cost[i] - min_cost_idx = i - end - end - - return scores[min_cost_idx] + ham_count_behind[i] = 1 + spam_count_behind[i] = 0 + else + ham_count_behind[i] = 0 + spam_count_behind[i] = 1 + end + + if i ~= 1 then + ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1] + spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1] + end + end + + for i = 1, n_samples do + fpr[i] = 0 + fnr[i] = 0 + + if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then + fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i]) + end + + if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then + fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1]) + end + end + + local p = n_spam / (n_spam + n_ham) + + local cost = {} + local min_cost_idx = 0 + local min_cost = math.huge + for i = 1, n_samples do + cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i]) + if min_cost >= cost[i] then + min_cost = cost[i] + min_cost_idx = i + end + end + + return scores[min_cost_idx] end -- This function is intended to extend lock for ANN during training @@ -299,7 +302,7 @@ local function register_lock_extender(rule, set, ev_base, ann_key) true, -- is write redis_lock_extend_cb, --callback 'HINCRBY', -- command - {ann_key, 'lock', '30'} + { ann_key, 'lock', '30' } ) else lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") @@ -337,7 +340,8 @@ local function can_push_train_vector(rule, task, learn_type, nspam, nham) end end return true - else -- Enough learns + else + -- Enough learns rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', learn_type, nspam) @@ -403,7 +407,7 @@ end -- Closure generator for unlock function local function gen_unlock_cb(rule, set, ann_key) - return function (err) + return function(err) if err then rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', rule.prefix, set.name, ann_key, err) @@ -426,7 +430,7 @@ local function redis_ann_prefix(rule, settings_name) -- We also need to count metatokens: local n = meta_functions.version return string.format('%s%d_%s_%d_%s', - settings.prefix, plugin_ver, rule.prefix, n, settings_name) + settings.prefix, plugin_ver, rule.prefix, n, settings_name) end -- This function receives training vectors, checks them, spawn learning and saves ANN in Redis @@ -449,7 +453,7 @@ local function spawn_train(params) -- Used to show parsed vectors in a convenient format (for debugging only) local function debug_vec(t) local ret = {} - for i,v in ipairs(t) do + for i, v in ipairs(t) do if v ~= 0 then ret[#ret + 1] = string.format('%d=%.2f', i, v) end @@ -462,14 +466,14 @@ local function spawn_train(params) -- KANN automatically shuffles those samples -- 1.0 is used for spam and -1.0 is used for ham -- It implies that output layer can express that (e.g. tanh output) - for _,e in ipairs(params.spam_vec) do + for _, e in ipairs(params.spam_vec) do inputs[#inputs + 1] = e - outputs[#outputs + 1] = {1.0} + outputs[#outputs + 1] = { 1.0 } --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) end - for _,e in ipairs(params.ham_vec) do + for _, e in ipairs(params.ham_vec) do inputs[#inputs + 1] = e - outputs[#outputs + 1] = {-1.0} + outputs[#outputs + 1] = { -1.0 } --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) end @@ -486,7 +490,7 @@ local function spawn_train(params) rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', params.rule.prefix, params.set.name, value_cost) - for i,e in ipairs(inputs) do + for i, e in ipairs(inputs) do lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', debug_vec(e), outputs[i][1]) end @@ -515,7 +519,7 @@ local function spawn_train(params) lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s", params.rule.prefix, params.set.name) - local ret,err = pcall(train_ann.train1, train_ann, + local ret, err = pcall(train_ann.train1, train_ann, inputs, outputs, { lr = params.rule.train.learning_rate, max_epoch = params.rule.train.max_iterations, @@ -536,19 +540,19 @@ local function spawn_train(params) local roc_thresholds = {} if params.rule.roc_enabled then local spam_threshold = get_roc_thresholds(train_ann, - inputs, - outputs, - 1 - params.rule.roc_misclassification_cost, - params.rule.roc_misclassification_cost) + inputs, + outputs, + 1 - params.rule.roc_misclassification_cost, + params.rule.roc_misclassification_cost) local ham_threshold = get_roc_thresholds(train_ann, - inputs, - outputs, - params.rule.roc_misclassification_cost, - 1 - params.rule.roc_misclassification_cost) - roc_thresholds = {spam_threshold, ham_threshold} + inputs, + outputs, + params.rule.roc_misclassification_cost, + 1 - params.rule.roc_misclassification_cost) + roc_thresholds = { spam_threshold, ham_threshold } rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)", - roc_thresholds[1], roc_thresholds[2]) + roc_thresholds[1], roc_thresholds[2]) end if not seen_nan then @@ -585,7 +589,7 @@ local function spawn_train(params) false, -- is write gen_unlock_cb(params.rule, params.set, params.ann_key), --callback 'HDEL', -- command - {params.ann_key, 'lock'} + { params.ann_key, 'lock' } ) else rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', @@ -605,7 +609,7 @@ local function spawn_train(params) true, -- is write gen_unlock_cb(params.rule, params.set, params.ann_key), --callback 'HDEL', -- command - {params.ann_key, 'lock'} + { params.ann_key, 'lock' } ) else local parser = ucl.parser() @@ -653,17 +657,17 @@ local function spawn_train(params) params.set.ann.redis_key, params.ann_key) lua_redis.exec_redis_script(redis_script_id.save_unlock, - {ev_base = params.ev_base, is_write = true}, + { ev_base = params.ev_base, is_write = true }, redis_save_cb, - {profile.redis_key, - redis_ann_prefix(params.rule, params.set.name), - ann_data, - profile_serialized, - tostring(params.rule.ann_expire), - tostring(os.time()), - params.ann_key, -- old key to unlock... - roc_thresholds_serialized, - pca_data, + { profile.redis_key, + redis_ann_prefix(params.rule, params.set.name), + ann_data, + profile_serialized, + tostring(params.rule.ann_expire), + tostring(os.time()), + params.ann_key, -- old key to unlock... + roc_thresholds_serialized, + pca_data, }) end end @@ -672,7 +676,7 @@ local function spawn_train(params) fill_set_ann(params.set, params.ann_key) end - params.worker:spawn_process{ + params.worker:spawn_process { func = train, on_complete = ann_trained, proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name), @@ -695,7 +699,9 @@ local function process_rules_settings() -- Ensure that we have an array... lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", rule.prefix, selt.name, profile) - if not profile[1] then profile = lua_util.keys(profile) end + if not profile[1] then + profile = lua_util.keys(profile) + end selt.symbols = profile else lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", @@ -758,7 +764,7 @@ local function process_rules_settings() }) end - for k,rule in pairs(settings.rules) do + for k, rule in pairs(settings.rules) do if not rule.allowed_settings then rule.allowed_settings = {} elseif rule.allowed_settings == 'all' then @@ -788,7 +794,7 @@ local function process_rules_settings() -- Now, for each allowed settings, we store sorted symbols + digest -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } - for s,_ in pairs(rule.allowed_settings) do + for s, _ in pairs(rule.allowed_settings) do -- Here, we have a name, set of symbols and local settings_id = s if type(settings_id) ~= 'number' then @@ -802,7 +808,7 @@ local function process_rules_settings() } process_settings_elt(rule, nelt) - for id,ex in pairs(rule.settings) do + for id, ex in pairs(rule.settings) do if type(ex) == 'table' then if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then -- Equal symbols, add reference @@ -829,7 +835,9 @@ local function get_rule_settings(task, rule) local sid = task:get_settings_id() or -1 local set = rule.settings[sid] - if not set then return nil end + if not set then + return nil + end while type(set) == 'number' do -- Reference to another settings! @@ -843,10 +851,10 @@ local function result_to_vector(task, profile) if not profile.zeros then -- Fill zeros vector local zeros = {} - for i=1,meta_functions.count_metatokens() do + for i = 1, meta_functions.count_metatokens() do zeros[i] = 0.0 end - for _,_ in ipairs(profile.symbols) do + for _, _ in ipairs(profile.symbols) do zeros[#zeros + 1] = 0.0 end profile.zeros = zeros @@ -855,7 +863,7 @@ local function result_to_vector(task, profile) local vec = lua_util.shallowcopy(profile.zeros) local mt = meta_functions.rspamd_gen_metatokens(task) - for i,v in ipairs(mt) do + for i, v in ipairs(mt) do vec[i] = v end |