From bef70607af40943fa1626d1c0a32f94925d4f15a Mon Sep 17 00:00:00 2001
From: Pragadeesh Chandiran
Date: Mon, 8 Nov 2021 00:13:04 -0500
Subject: [PATCH] [Feature] Add ROC feature to neural network plugin

Add optional ROC-based selection of the neural spam/ham score thresholds.

* get_roc_thresholds() scores the training vectors with the freshly
  trained ANN and returns the score that minimises
  (1 - p) * alpha * FPR + p * beta * FNR, where p is the share of spam
  samples in the training set.
* The thresholds are saved as JSON in the `roc_thresholds` field of the ANN
  hash (KEYS[8] of the save/unlock script; PCA moves to KEYS[9]) and are
  read back as the second HMGET field in load_new_ann (PCA moves to
  data[3]).
* ann_scores_filter uses them only when `roc_enabled` is set and no
  explicit spam_score_threshold/ham_score_threshold is configured.

Review fixes folded into this revision:

* ann_scores_filter: the ROC branch was guarded by
  `not set.ann.roc_thresholds` and then indexed that nil value; it now
  requires the table and the element to be present and otherwise keeps
  the threshold at 0.
* spawn_train: the "ROC thresholds" message indexed `roc_thresholds`
  even when `roc_enabled` was off (the default), where it is nil; the
  message is now logged only when thresholds were computed.
* spawn_train: never hand a nil to ucl.to_format() for the value that is
  passed positionally to the Redis script; an empty table is serialised
  instead.
* get_roc_thresholds: get_scores() iterates over its own argument.

---
Notes (ignored by git am):
 - This patch was re-flowed from a copy whose line breaks and indentation
   had been lost. Context-line whitespace and the position of blank
   context lines are reconstructed, so apply it with
   `git apply --ignore-whitespace` (or re-generate it from a tree).
 - Hunk sizes and the diffstat are unchanged by the review fixes; the
   post-image blob id in the first `index` line was not recomputed.

 lualib/plugins/neural.lua  | 161 +++++++++++++++++++++++++++++++++++--
 src/plugins/lua/neural.lua |  51 +++++++++---
 2 files changed, 198 insertions(+), 14 deletions(-)

diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 64d21ce37..f677119fe 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -54,7 +54,8 @@ local default_options = {
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
   hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
-  -- Check ROC curve and AUC in the ML literature
+  roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+  roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
   spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
   ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
   flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
@@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[
 -- key5 - expire in seconds
 -- key6 - current time
 -- key7 - old key
--- key8 - optional PCA
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
 local redis_lua_script_save_unlock = [[
   local now = tonumber(KEYS[6])
   redis.call('ZADD', KEYS[2], now, KEYS[4])
@@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[
   redis.call('HDEL', KEYS[1], 'lock')
   redis.call('HDEL', KEYS[7], 'lock')
   redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
-  if KEYS[8] then
-    redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+  redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+  if KEYS[9] then
+    redis.call('HSET', KEYS[1], 'pca', KEYS[9])
   end
   return 1
 ]]
@@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs)
   return w
 end
 
+-- Computes the optimal score threshold using ROC for the given training set.
+-- Returns the ANN score (nil when `inputs` is empty) that minimizes:
+--      (1 - p) * alpha * false_positive_rate + p * beta * false_negative_rate
+-- where p is the share of spam samples (outputs[i][1] ~= 0), alpha is the cost
+-- of a false positive result and beta is the cost of a false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+  -- Sorts list x and list y based on the values in list x.
+  local sort_relative = function(x, y)
+
+    local r = {}
+
+    assert(#x == #y)
+    local n = #x
+
+    local a = {}
+    local b = {}
+    for i=1,n do
+      r[i] = i
+    end
+
+    local cmp = function(p, q) return p < q end
+
+    table.sort(r, function(p, q) return cmp(x[p], x[q]) end)
+
+    for i=1,n do
+      a[i] = x[r[i]]
+      b[i] = y[r[i]]
+    end
+
+    return a, b
+  end
+
+  local function get_scores(nn, input_vectors)
+    local scores = {}
+    for i=1,#input_vectors do
+      local score = nn:apply1(input_vectors[i], nn.pca)[1]
+      scores[#scores+1] = score
+    end
+
+    return scores
+  end
+
+  local fpr = {}
+  local fnr = {}
+  local scores = get_scores(ann, inputs)
+
+  scores, outputs = sort_relative(scores, outputs)
+
+  local n_samples = #outputs
+  local n_spam = 0
+  local n_ham = 0
+  local ham_count_ahead = {}
+  local spam_count_ahead = {}
+  local ham_count_behind = {}
+  local spam_count_behind = {}
+
+  ham_count_ahead[n_samples + 1] = 0
+  spam_count_ahead[n_samples + 1] = 0
+
+  for i=n_samples,1,-1 do
+
+    if outputs[i][1] == 0 then
+      n_ham = n_ham + 1
+      ham_count_ahead[i] = 1
+      spam_count_ahead[i] = 0
+    else
+      n_spam = n_spam + 1
+      ham_count_ahead[i] = 0
+      spam_count_ahead[i] = 1
+    end
+
+    ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+    spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+  end
+
+  for i=1,n_samples do
+    if outputs[i][1] == 0 then
+      ham_count_behind[i] = 1
+      spam_count_behind[i] = 0
+    else
+      ham_count_behind[i] = 0
+      spam_count_behind[i] = 1
+    end
+
+    if i ~= 1 then
+      ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+      spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+    end
+  end
+
+  for i=1,n_samples do
+    fpr[i] = 0
+    fnr[i] = 0
+
+    if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+      fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+    end
+
+    if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+      fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+    end
+  end
+
+  local p = n_spam / (n_spam + n_ham)
+
+  local cost = {}
+  local min_cost_idx = 0
+  local min_cost = math.huge
+  for i=1,n_samples do
+    cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+    if min_cost >= cost[i] then
+      min_cost = cost[i]
+      min_cost_idx = i
+    end
+  end
+
+  return scores[min_cost_idx]
+end
+
 -- This function is intended to extend lock for ANN during training
 -- It registers periodic that increases locked key each 30 seconds unless
 -- `set.learning_spawned` is set to `true`
@@ -497,6 +620,24 @@ local function spawn_train(params)
             params.rule.prefix, params.set.name)
       end
 
+      local roc_thresholds
+      if params.rule.roc_enabled then
+        local spam_threshold = get_roc_thresholds(train_ann,
+            inputs,
+            outputs,
+            1 - params.rule.roc_misclassification_cost,
+            params.rule.roc_misclassification_cost)
+        local ham_threshold = get_roc_thresholds(train_ann,
+            inputs,
+            outputs,
+            params.rule.roc_misclassification_cost,
+            1 - params.rule.roc_misclassification_cost)
+        roc_thresholds = {spam_threshold, ham_threshold}
+        -- Log only here: roc_thresholds stays nil when ROC is disabled
+        rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+            roc_thresholds[1], roc_thresholds[2])
+      end
+
       if not seen_nan then
         -- Convert to strings as ucl cannot rspamd_text properly
         local pca_data
@@ -506,6 +647,7 @@ local function spawn_train(params)
         local out = {
           ann_data = tostring(train_ann:save()),
           pca_data = pca_data,
+          roc_thresholds = roc_thresholds,
         }
 
         local final_data = ucl.to_format(out, 'msgpack')
@@ -559,12 +701,19 @@ local function spawn_train(params)
             local parsed = parser:get_object()
             local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
             local pca_data = parsed.pca_data
+            local roc_thresholds = parsed.roc_thresholds
 
             fill_set_ann(params.set, params.ann_key)
             if pca_data then
               params.set.ann.pca = rspamd_tensor.load(pca_data)
               pca_data = rspamd_util.zstd_compress(pca_data)
             end
+
+            if roc_thresholds then
+              params.set.ann.roc_thresholds = roc_thresholds
+            end
+
+
             -- Deserialise ANN from the child process
             ann_trained = rspamd_kann.load(parsed.ann_data)
             local version = (params.set.ann.version or 0) + 1
@@ -581,6 +730,7 @@ local function spawn_train(params)
             }
 
             local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+            local roc_thresholds_serialized = ucl.to_format(roc_thresholds or {}, 'json-compact', true)
 
             rspamd_logger.infox(rspamd_config,
                 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
@@ -599,7 +749,8 @@ local function spawn_train(params)
                   tostring(params.rule.ann_expire),
                   tostring(os.time()),
                   params.ann_key, -- old key to unlock...
-                  pca_data
+                  roc_thresholds_serialized,
+                  pca_data,
                 })
           end
         end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 5458dd007..36eb9adaf 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -120,31 +120,47 @@ local function ann_scores_filter(task)
 
       if score > 0 then
         local result = score
+
+        -- If spam_score_threshold is defined, override all other thresholds.
+        local spam_threshold = 0
+        if rule.spam_score_threshold then
+          spam_threshold = rule.spam_score_threshold
+        elseif rule.roc_enabled and set.ann.roc_thresholds and set.ann.roc_thresholds[1] then
+          spam_threshold = set.ann.roc_thresholds[1]
+        end
 
-        if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+        if result >= spam_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_spam, 1.0, symscore)
           else
             task:insert_result(rule.symbol_spam, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
               rule.prefix, set.name, set.ann.version, symscore,
-              rule.spam_score_threshold)
+              spam_threshold)
         end
       else
         local result = -(score)
 
-        if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+        -- If ham_score_threshold is defined, override all other thresholds.
+        local ham_threshold = 0
+        if rule.ham_score_threshold then
+          ham_threshold = rule.ham_score_threshold
+        elseif rule.roc_enabled and set.ann.roc_thresholds and set.ann.roc_thresholds[2] then
+          ham_threshold = set.ann.roc_thresholds[2]
+        end
+
+        if result >= ham_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_ham, 1.0, symscore)
           else
             task:insert_result(rule.symbol_ham, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
               rule.prefix, set.name, set.ann.version, result,
-              rule.ham_score_threshold)
+              ham_threshold)
         end
       end
     end
@@ -481,16 +497,32 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
                 lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
                     rule.prefix, set.name, ann_key)
               end
+
               if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+                if rule.roc_enabled then
+                  local ucl = require "ucl"
+                  local parser = ucl.parser()
+                  local ok, parse_err = parser:parse_text(data[2])
+                  assert(ok, parse_err)
+                  local roc_thresholds = parser:get_object()
+                  set.ann.roc_thresholds = roc_thresholds
+                  rspamd_logger.infox(rspamd_config,
+                      'loaded ROC thresholds for %s:%s; version=%s',
+                      rule.prefix, set.name, profile.version)
+                  rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
+                end
+              end
+
+              if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
                 -- PCA table
-                local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+                local _err,pca_data = rspamd_util.zstd_decompress(data[3])
                 if pca_data then
                   if rule.max_inputs then
                     -- We can use PCA
                     set.ann.pca = rspamd_tensor.load(pca_data)
                     rspamd_logger.infox(rspamd_config,
                         'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                        rule.prefix, set.name, ann_key, #data[2], profile.version)
+                        rule.prefix, set.name, ann_key, #data[3], profile.version)
                   else
                     -- no need in pca, why is it there?
                     rspamd_logger.warnx(rspamd_config,
@@ -509,6 +541,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
                 end
               end
             end
+
           else
             lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
                 rule.prefix, set.name, ann_key)
@@ -522,7 +555,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       false, -- is write
       data_cb, --callback
       'HMGET', -- command
-      {ann_key, 'ann', 'pca'}, -- arguments
+      {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
       {opaque_data = true}
   )
 end
-- 
2.39.5