aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/plugins
diff options
context:
space:
mode:
authorPragadeesh Chandiran <pchandiran@mimecast.com>2021-11-08 00:13:04 -0500
committerPragadeesh Chandiran <pchandiran@mimecast.com>2021-11-15 03:12:33 -0500
commitbef70607af40943fa1626d1c0a32f94925d4f15a (patch)
tree13569f5c909aca17201ad28f580b20d28170e33a /lualib/plugins
parent711dca480131632fa2f352c44264fa8d65496fd3 (diff)
downloadrspamd-bef70607af40943fa1626d1c0a32f94925d4f15a.tar.gz
rspamd-bef70607af40943fa1626d1c0a32f94925d4f15a.zip
[Feature] Add ROC feature to neural network plugin
Diffstat (limited to 'lualib/plugins')
-rw-r--r--lualib/plugins/neural.lua161
1 files changed, 156 insertions, 5 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 64d21ce37..f677119fe 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -54,7 +54,8 @@ local default_options = {
learning_spawned = false,
ann_expire = 60 * 60 * 24 * 2, -- 2 days
hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
- -- Check ROC curve and AUC in the ML literature
+ roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+ roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
@@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[
-- key5 - expire in seconds
-- key6 - current time
-- key7 - old key
--- key8 - optional PCA
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
local redis_lua_script_save_unlock = [[
local now = tonumber(KEYS[6])
redis.call('ZADD', KEYS[2], now, KEYS[4])
@@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[
redis.call('HDEL', KEYS[1], 'lock')
redis.call('HDEL', KEYS[7], 'lock')
redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- if KEYS[8] then
- redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+ redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+ if KEYS[9] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[9])
end
return 1
]]
@@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs)
return w
end
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+-- alpha * (false_positive_rate) + beta * (false_negative_rate)
+-- Where alpha is cost of false positive result
+-- beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+ -- Sorts list x and list y based on the values in list x.
+ local sort_relative = function(x, y)
+
+ local r = {}
+
+ assert(#x == #y)
+ local n = #x
+
+ local a = {}
+ local b = {}
+ for i=1,n do
+ r[i] = i
+ end
+
+ local cmp = function(p, q) return p < q end
+
+ table.sort(r, function(p, q) return cmp(x[p], x[q]) end)
+
+ for i=1,n do
+ a[i] = x[r[i]]
+ b[i] = y[r[i]]
+ end
+
+ return a, b
+ end
+
+ local function get_scores(nn, input_vectors)
+ local scores = {}
+ for i=1,#inputs do
+ local score = nn:apply1(input_vectors[i], nn.pca)[1]
+ scores[#scores+1] = score
+ end
+
+ return scores
+ end
+
+ local fpr = {}
+ local fnr = {}
+ local scores = get_scores(ann, inputs)
+
+ scores, outputs = sort_relative(scores, outputs)
+
+ local n_samples = #outputs
+ local n_spam = 0
+ local n_ham = 0
+ local ham_count_ahead = {}
+ local spam_count_ahead = {}
+ local ham_count_behind = {}
+ local spam_count_behind = {}
+
+ ham_count_ahead[n_samples + 1] = 0
+ spam_count_ahead[n_samples + 1] = 0
+
+ for i=n_samples,1,-1 do
+
+ if outputs[i][1] == 0 then
+ n_ham = n_ham + 1
+ ham_count_ahead[i] = 1
+ spam_count_ahead[i] = 0
+ else
+ n_spam = n_spam + 1
+ ham_count_ahead[i] = 0
+ spam_count_ahead[i] = 1
+ end
+
+ ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+ spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+ end
+
+ for i=1,n_samples do
+ if outputs[i][1] == 0 then
+ ham_count_behind[i] = 1
+ spam_count_behind[i] = 0
+ else
+ ham_count_behind[i] = 0
+ spam_count_behind[i] = 1
+ end
+
+ if i ~= 1 then
+ ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+ spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+ end
+ end
+
+ for i=1,n_samples do
+ fpr[i] = 0
+ fnr[i] = 0
+
+ if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+ fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+ end
+
+ if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+ fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+ end
+ end
+
+ local p = n_spam / (n_spam + n_ham)
+
+ local cost = {}
+ local min_cost_idx = 0
+ local min_cost = math.huge
+ for i=1,n_samples do
+ cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+ if min_cost >= cost[i] then
+ min_cost = cost[i]
+ min_cost_idx = i
+ end
+ end
+
+ return scores[min_cost_idx]
+end
+
-- This function is intended to extend lock for ANN during training
-- It registers periodic that increases locked key each 30 seconds unless
-- `set.learning_spawned` is set to `true`
@@ -497,6 +620,24 @@ local function spawn_train(params)
params.rule.prefix, params.set.name)
end
+ local roc_thresholds
+ if params.rule.roc_enabled then
+ local spam_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ 1 - params.rule.roc_misclassification_cost,
+ params.rule.roc_misclassification_cost)
+ local ham_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ params.rule.roc_misclassification_cost,
+ 1 - params.rule.roc_misclassification_cost)
+ roc_thresholds = {spam_threshold, ham_threshold}
+ end
+
+ rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+ roc_thresholds[1], roc_thresholds[2])
+
if not seen_nan then
-- Convert to strings as ucl cannot rspamd_text properly
local pca_data
@@ -506,6 +647,7 @@ local function spawn_train(params)
local out = {
ann_data = tostring(train_ann:save()),
pca_data = pca_data,
+ roc_thresholds = roc_thresholds,
}
local final_data = ucl.to_format(out, 'msgpack')
@@ -559,12 +701,19 @@ local function spawn_train(params)
local parsed = parser:get_object()
local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
local pca_data = parsed.pca_data
+ local roc_thresholds = parsed.roc_thresholds
fill_set_ann(params.set, params.ann_key)
if pca_data then
params.set.ann.pca = rspamd_tensor.load(pca_data)
pca_data = rspamd_util.zstd_compress(pca_data)
end
+
+ if roc_thresholds then
+ params.set.ann.roc_thresholds = roc_thresholds
+ end
+
+
-- Deserialise ANN from the child process
ann_trained = rspamd_kann.load(parsed.ann_data)
local version = (params.set.ann.version or 0) + 1
@@ -581,6 +730,7 @@ local function spawn_train(params)
}
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
@@ -599,7 +749,8 @@ local function spawn_train(params)
tostring(params.rule.ann_expire),
tostring(os.time()),
params.ann_key, -- old key to unlock...
- pca_data
+ roc_thresholds_serialized,
+ pca_data,
})
end
end