]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add ROC feature to neural network plugin 3980/head
authorPragadeesh Chandiran <pchandiran@mimecast.com>
Mon, 8 Nov 2021 05:13:04 +0000 (00:13 -0500)
committerPragadeesh Chandiran <pchandiran@mimecast.com>
Mon, 15 Nov 2021 08:12:33 +0000 (03:12 -0500)
lualib/plugins/neural.lua
src/plugins/lua/neural.lua

index 64d21ce37360a63b55dd366faa307bdc65ea574c..f677119fe24b5ad54813a3fb89c850574155a4de 100644 (file)
@@ -54,7 +54,8 @@ local default_options = {
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
   hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
-  -- Check ROC curve and AUC in the ML literature
+  roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+  roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
   spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
   ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
   flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
@@ -170,7 +171,8 @@ local redis_lua_script_maybe_lock = [[
 -- key5 - expire in seconds
 -- key6 - current time
 -- key7 - old key
--- key8 - optional PCA
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
 local redis_lua_script_save_unlock = [[
   local now = tonumber(KEYS[6])
   redis.call('ZADD', KEYS[2], now, KEYS[4])
@@ -180,8 +182,9 @@ local redis_lua_script_save_unlock = [[
   redis.call('HDEL', KEYS[1], 'lock')
   redis.call('HDEL', KEYS[7], 'lock')
   redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
-  if KEYS[8] then
-    redis.call('HSET', KEYS[1], 'pca', KEYS[8])
+  redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+  if KEYS[9] then
+    redis.call('HSET', KEYS[1], 'pca', KEYS[9])
   end
   return 1
 ]]
@@ -239,6 +242,126 @@ local function learn_pca(inputs, max_inputs)
   return w
 end
 
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+--        alpha * (false_positive_rate)  +  beta * (false_negative_rate)
+--        Where alpha is cost of false positive result
+--              beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+  -- Sorts list x and list y based on the values in list x.
+  local sort_relative = function(x, y)
+
+    local r = {}
+
+    assert(#x == #y)
+    local n = #x
+
+    local a = {}
+    local b = {}
+    for i=1,n do
+      r[i] = i
+    end
+
+    local cmp = function(p, q) return p < q end
+
+    table.sort(r, function(p, q) return cmp(x[p], x[q]) end)
+
+    for i=1,n do 
+      a[i] = x[r[i]]
+      b[i] = y[r[i]]
+    end
+
+    return a, b
+  end
+
+  local function get_scores(nn, input_vectors)
+    local scores = {}
+    for i=1,#inputs do
+      local score = nn:apply1(input_vectors[i], nn.pca)[1]
+      scores[#scores+1] = score
+    end
+
+    return scores
+  end
+
+  local fpr = {}
+       local fnr = {}
+       local scores = get_scores(ann, inputs)
+
+       scores, outputs = sort_relative(scores, outputs)
+
+       local n_samples = #outputs
+       local n_spam = 0
+       local n_ham = 0
+       local ham_count_ahead = {}
+       local spam_count_ahead = {}
+       local ham_count_behind = {}
+       local spam_count_behind = {}
+
+       ham_count_ahead[n_samples + 1] = 0
+       spam_count_ahead[n_samples + 1] = 0
+
+       for i=n_samples,1,-1 do
+
+               if outputs[i][1] == 0 then
+                       n_ham = n_ham + 1
+                       ham_count_ahead[i] = 1
+                       spam_count_ahead[i] = 0
+               else
+                       n_spam = n_spam + 1
+                       ham_count_ahead[i] = 0
+                       spam_count_ahead[i] = 1
+               end
+
+               ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+               spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+       end
+
+       for i=1,n_samples do
+    if outputs[i][1] == 0 then
+                       ham_count_behind[i] = 1
+                       spam_count_behind[i] = 0
+               else
+                       ham_count_behind[i] = 0
+                       spam_count_behind[i] = 1
+               end
+
+               if i ~= 1 then
+                       ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+                       spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+               end
+       end
+
+       for i=1,n_samples do
+               fpr[i] = 0
+               fnr[i] = 0
+
+               if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+                       fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+               end
+
+               if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+                       fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+               end
+       end
+
+       local p = n_spam / (n_spam + n_ham)
+
+       local cost = {}
+       local min_cost_idx = 0
+       local min_cost = math.huge
+       for i=1,n_samples do
+               cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+               if min_cost >= cost[i] then
+                       min_cost = cost[i]
+                       min_cost_idx = i
+               end
+       end
+
+       return scores[min_cost_idx]
+end
+
 -- This function is intended to extend lock for ANN during training
 -- It registers periodic that increases locked key each 30 seconds unless
 -- `set.learning_spawned` is set to `true`
@@ -497,6 +620,24 @@ local function spawn_train(params)
             params.rule.prefix, params.set.name)
       end
 
+      local roc_thresholds
+      if params.rule.roc_enabled then
+        local spam_threshold = get_roc_thresholds(train_ann,
+                                                  inputs,
+                                                  outputs,
+                                                  1 - params.rule.roc_misclassification_cost,
+                                                  params.rule.roc_misclassification_cost)
+        local ham_threshold = get_roc_thresholds(train_ann,
+                                                  inputs,
+                                                  outputs,
+                                                  params.rule.roc_misclassification_cost,
+                                                  1 - params.rule.roc_misclassification_cost)
+        roc_thresholds = {spam_threshold, ham_threshold}
+      end
+
+      rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+                              roc_thresholds[1], roc_thresholds[2])
+
       if not seen_nan then
         -- Convert to strings as ucl cannot rspamd_text properly
         local pca_data
@@ -506,6 +647,7 @@ local function spawn_train(params)
         local out = {
           ann_data = tostring(train_ann:save()),
           pca_data = pca_data,
+          roc_thresholds = roc_thresholds,
         }
 
         local final_data = ucl.to_format(out, 'msgpack')
@@ -559,12 +701,19 @@ local function spawn_train(params)
         local parsed = parser:get_object()
         local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
         local pca_data = parsed.pca_data
+        local roc_thresholds = parsed.roc_thresholds
 
         fill_set_ann(params.set, params.ann_key)
         if pca_data then
           params.set.ann.pca = rspamd_tensor.load(pca_data)
           pca_data = rspamd_util.zstd_compress(pca_data)
         end
+
+        if roc_thresholds then
+          params.set.ann.roc_thresholds = roc_thresholds
+        end
+
+
         -- Deserialise ANN from the child process
         ann_trained = rspamd_kann.load(parsed.ann_data)
         local version = (params.set.ann.version or 0) + 1
@@ -581,6 +730,7 @@ local function spawn_train(params)
         }
 
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+        local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
 
         rspamd_logger.infox(rspamd_config,
             'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
@@ -599,7 +749,8 @@ local function spawn_train(params)
              tostring(params.rule.ann_expire),
              tostring(os.time()),
              params.ann_key, -- old key to unlock...
-             pca_data
+             roc_thresholds_serialized,
+             pca_data,
             })
       end
     end
index 5458dd007a80b04fb3360b71fa51793d0902e700..36eb9adaf7a8e82d83c1e606d9881d66329c2056 100644 (file)
@@ -120,31 +120,47 @@ local function ann_scores_filter(task)
 
       if score > 0 then
         local result = score
+        
+        -- If spam_score_threshold is defined, override all other thresholds.
+        local spam_threshold = 0
+        if rule.spam_score_threshold then
+          spam_threshold = rule.spam_score_threshold
+        elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          spam_threshold = set.ann.roc_thresholds[1]
+        end
 
-        if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+        if result >= spam_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_spam, 1.0, symscore)
           else
             task:insert_result(rule.symbol_spam, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
               rule.prefix, set.name, set.ann.version, symscore,
-              rule.spam_score_threshold)
+              spam_threshold)
         end
       else
         local result = -(score)
 
-        if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+        -- If ham_score_threshold is defined, override all other thresholds.
+        local ham_threshold = 0
+        if rule.ham_score_threshold then
+          ham_threshold = rule.ham_score_threshold
+        elseif rule.roc_enabled and not set.ann.roc_thresholds then
+          ham_threshold = set.ann.roc_thresholds[2]
+        end
+
+        if result >= ham_threshold then
           if rule.flat_threshold_curve then
             task:insert_result(rule.symbol_ham, 1.0, symscore)
           else
             task:insert_result(rule.symbol_ham, result, symscore)
           end
         else
-          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
               rule.prefix, set.name, set.ann.version, result,
-              rule.ham_score_threshold)
+              ham_threshold)
         end
       end
     end
@@ -481,16 +497,32 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
           lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
               rule.prefix, set.name, ann_key)
         end
+
         if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+          if rule.roc_enabled then
+            local ucl = require "ucl"
+            local parser = ucl.parser()
+            local ok, parse_err = parser:parse_text(data[2])
+            assert(ok, parse_err)
+            local roc_thresholds = parser:get_object()
+            set.ann.roc_thresholds = roc_thresholds
+            rspamd_logger.infox(rspamd_config, 
+                                'loaded ROC thresholds for %s:%s; version=%s',
+                                rule.prefix, set.name, profile.version)
+            rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
+          end
+        end
+
+        if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
           -- PCA table
-          local _err,pca_data = rspamd_util.zstd_decompress(data[2])
+          local _err,pca_data = rspamd_util.zstd_decompress(data[3])
           if pca_data then
             if rule.max_inputs then
               -- We can use PCA
               set.ann.pca = rspamd_tensor.load(pca_data)
               rspamd_logger.infox(rspamd_config,
                   'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
-                  rule.prefix, set.name, ann_key, #data[2], profile.version)
+                  rule.prefix, set.name, ann_key, #data[3], profile.version)
             else
               -- no need in pca, why is it there?
               rspamd_logger.warnx(rspamd_config,
@@ -509,6 +541,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
             end
           end
         end
+
       else
         lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
             rule.prefix, set.name, ann_key)
@@ -522,7 +555,7 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff)
       false, -- is write
       data_cb, --callback
       'HMGET', -- command
-      {ann_key, 'ann', 'pca'}, -- arguments
+      {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
       {opaque_data = true}
   )
 end