diff options
Diffstat (limited to 'lualib/plugins')
-rw-r--r-- | lualib/plugins/neural.lua | 161 |
1 files changed, 156 insertions, 5 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 64d21ce37..f677119fe 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -54,7 +54,8 @@ local default_options = { learning_spawned = false, ann_expire = 60 * 60 * 24 * 2, -- 2 days hidden_layer_mult = 1.5, -- number of neurons in the hidden layer - -- Check ROC curve and AUC in the ML literature + roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds. + roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1). spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached @@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[ -- key5 - expire in seconds -- key6 - current time -- key7 - old key --- key8 - optional PCA +-- key8 - ROC Thresholds +-- key9 - optional PCA local redis_lua_script_save_unlock = [[ local now = tonumber(KEYS[6]) redis.call('ZADD', KEYS[2], now, KEYS[4]) @@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[ redis.call('HDEL', KEYS[1], 'lock') redis.call('HDEL', KEYS[7], 'lock') redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) - if KEYS[8] then - redis.call('HSET', KEYS[1], 'pca', KEYS[8]) + redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8]) + if KEYS[9] then + redis.call('HSET', KEYS[1], 'pca', KEYS[9]) end return 1 ]] @@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs) return w end +-- This function computes optimal threshold using ROC for the given set of inputs. +-- Returns a threshold that minimizes: +-- alpha * (false_positive_rate) + beta * (false_negative_rate) +-- Where alpha is cost of false positive result +-- beta is cost of false negative result +local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) + + -- Sorts list x and list y based on the values in list x. + local sort_relative = function(x, y) + + local r = {} + + assert(#x == #y) + local n = #x + + local a = {} + local b = {} + for i=1,n do + r[i] = i + end + + local cmp = function(p, q) return p < q end + + table.sort(r, function(p, q) return cmp(x[p], x[q]) end) + + for i=1,n do + a[i] = x[r[i]] + b[i] = y[r[i]] + end + + return a, b + end + + local function get_scores(nn, input_vectors) + local scores = {} + for i=1,#inputs do + local score = nn:apply1(input_vectors[i], nn.pca)[1] + scores[#scores+1] = score + end + + return scores + end + + local fpr = {} + local fnr = {} + local scores = get_scores(ann, inputs) + + scores, outputs = sort_relative(scores, outputs) + + local n_samples = #outputs + local n_spam = 0 + local n_ham = 0 + local ham_count_ahead = {} + local spam_count_ahead = {} + local ham_count_behind = {} + local spam_count_behind = {} + + ham_count_ahead[n_samples + 1] = 0 + spam_count_ahead[n_samples + 1] = 0 + + for i=n_samples,1,-1 do + + if outputs[i][1] == 0 then + n_ham = n_ham + 1 + ham_count_ahead[i] = 1 + spam_count_ahead[i] = 0 + else + n_spam = n_spam + 1 + ham_count_ahead[i] = 0 + spam_count_ahead[i] = 1 + end + + ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1] + spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1] + end + + for i=1,n_samples do + if outputs[i][1] == 0 then + ham_count_behind[i] = 1 + spam_count_behind[i] = 0 + else + ham_count_behind[i] = 0 + spam_count_behind[i] = 1 + end + + if i ~= 1 then + ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1] + spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1] + end + end + + for i=1,n_samples do + fpr[i] = 0 + fnr[i] = 0 + + if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then + fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i]) + end + + if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then + fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1]) + end + end + + local p = n_spam / (n_spam + n_ham) + + local cost = {} + local min_cost_idx = 0 + local min_cost = math.huge + for i=1,n_samples do + cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i]) + if min_cost >= cost[i] then + min_cost = cost[i] + min_cost_idx = i + end + end + + return scores[min_cost_idx] +end + -- This function is intended to extend lock for ANN during training -- It registers periodic that increases locked key each 30 seconds unless -- `set.learning_spawned` is set to `true` @@ -497,6 +620,24 @@ local function spawn_train(params) params.rule.prefix, params.set.name) end + local roc_thresholds + if params.rule.roc_enabled then + local spam_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + 1 - params.rule.roc_misclassification_cost, + params.rule.roc_misclassification_cost) + local ham_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + params.rule.roc_misclassification_cost, + 1 - params.rule.roc_misclassification_cost) + roc_thresholds = {spam_threshold, ham_threshold} + end + + rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)", + roc_thresholds[1], roc_thresholds[2]) + if not seen_nan then -- Convert to strings as ucl cannot rspamd_text properly local pca_data @@ -506,6 +647,7 @@ local function spawn_train(params) local out = { ann_data = tostring(train_ann:save()), pca_data = pca_data, + roc_thresholds = roc_thresholds, } local final_data = ucl.to_format(out, 'msgpack') @@ -559,12 +701,19 @@ local function spawn_train(params) local parsed = parser:get_object() local ann_data = rspamd_util.zstd_compress(parsed.ann_data) local pca_data = parsed.pca_data + local roc_thresholds = parsed.roc_thresholds fill_set_ann(params.set, params.ann_key) if pca_data then params.set.ann.pca = rspamd_tensor.load(pca_data) pca_data = rspamd_util.zstd_compress(pca_data) end + + if roc_thresholds then + params.set.ann.roc_thresholds = roc_thresholds + end + + -- Deserialise ANN from the child process ann_trained = rspamd_kann.load(parsed.ann_data) local version = (params.set.ann.version or 0) + 1 @@ -581,6 +730,7 @@ local function spawn_train(params) } local profile_serialized = ucl.to_format(profile, 'json-compact', true) + local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true) rspamd_logger.infox(rspamd_config, 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', @@ -599,7 +749,8 @@ local function spawn_train(params) tostring(params.rule.ann_expire), tostring(os.time()), params.ann_key, -- old key to unlock... - pca_data + roc_thresholds_serialized, + pca_data, }) end end |