]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Neural: Allow to have flat classification if needed
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 29 Apr 2021 18:44:40 +0000 (19:44 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 29 Apr 2021 18:44:40 +0000 (19:44 +0100)
lualib/plugins/neural.lua
src/plugins/lua/neural.lua

index f0d5cf582c10558521b1417d982422a90d50b53f..5571335913f494c74c4f0164bda72c1db750f7ea 100644 (file)
@@ -57,6 +57,7 @@ local default_options = {
   -- Check ROC curve and AUC in the ML literature
   spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
   ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
+  flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
   symbol_spam = 'NEURAL_SPAM',
   symbol_ham = 'NEURAL_HAM',
   max_inputs = nil, -- when PCA is used
index ca11d9e66697cf1368665d463bee71f58a412b0b..2ac8df59f33e55dbfdc471aff6da530624cf53d1 100644 (file)
@@ -121,7 +121,11 @@ local function ann_scores_filter(task)
         local result = score
 
         if not rule.spam_score_threshold or result >= rule.spam_score_threshold then
-          task:insert_result(rule.symbol_spam, result, symscore)
+          if rule.flat_threshold_curve then
+            task:insert_result(rule.symbol_spam, 1.0, symscore)
+          else
+            task:insert_result(rule.symbol_spam, result, symscore)
+          end
         else
           lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)',
               rule.prefix, set.name, set.ann.version, symscore,
@@ -131,7 +135,11 @@ local function ann_scores_filter(task)
         local result = -(score)
 
         if not rule.ham_score_threshold or result >= rule.ham_score_threshold then
-          task:insert_result(rule.symbol_ham, result, symscore)
+          if rule.flat_threshold_curve then
+            task:insert_result(rule.symbol_ham, 1.0, symscore)
+          else
+            task:insert_result(rule.symbol_ham, result, symscore)
+          end
         else
           lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)',
               rule.prefix, set.name, set.ann.version, result,