]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Neural: Allow to balance FP/FN for the network
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 29 Apr 2021 18:41:03 +0000 (19:41 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 29 Apr 2021 18:41:43 +0000 (19:41 +0100)
lualib/plugins/neural.lua
src/plugins/lua/neural.lua

index c35fc0eebb6b1876e132a4a01f027b9d930b7514..f0d5cf582c10558521b1417d982422a90d50b53f 100644 (file)
@@ -54,6 +54,9 @@ local default_options = {
   learning_spawned = false,
   ann_expire = 60 * 60 * 24 * 2, -- 2 days
   hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+  -- Check ROC curve and AUC in the ML literature
+  spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
+  ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
   symbol_spam = 'NEURAL_SPAM',
   symbol_ham = 'NEURAL_HAM',
   max_inputs = nil, -- when PCA is used
index 894d42e30a106d0aceb4d72feb874345d6d6972b..ca11d9e66697cf1368665d463bee71f58a412b0b 100644 (file)
@@ -119,10 +119,24 @@ local function ann_scores_filter(task)
 
       if score > 0 then
         local result = score
-        task:insert_result(rule.symbol_spam, result, symscore)
+
+        if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
+          task:insert_result(rule.symbol_spam, result, symscore)
+        else
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
+              rule.prefix, set.name, set.ann.version, symscore,
+              rule.spam_score_threshold)
+        end
       else
         local result = -(score)
-        task:insert_result(rule.symbol_ham, result, symscore)
+
+        if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
+          task:insert_result(rule.symbol_ham, result, symscore)
+        else
+          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
+              rule.prefix, set.name, set.ann.version, result,
+              rule.ham_score_threshold)
+        end
       end
     end
   end