diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2021-04-29 19:41:03 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2021-04-29 19:41:43 +0100 |
commit | 82e588390a7f0dc000e74497cfb84e25dcbfafe5 (patch) | |
tree | a21472fe460d3c46c4ac284dad7c3fa12865dda5 | |
parent | b9f9beebe869b5800e80a8dbbc3d49a9c9457062 (diff) | |
download | rspamd-82e588390a7f0dc000e74497cfb84e25dcbfafe5.tar.gz rspamd-82e588390a7f0dc000e74497cfb84e25dcbfafe5.zip |
[Feature] Neural: Allow to balance FP/FN for the network
-rw-r--r-- | lualib/plugins/neural.lua | 3 | ||||
-rw-r--r-- | src/plugins/lua/neural.lua | 18 |
2 files changed, 19 insertions, 2 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index c35fc0eeb..f0d5cf582 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -54,6 +54,9 @@ local default_options = { learning_spawned = false, ann_expire = 60 * 60 * 24 * 2, -- 2 days hidden_layer_mult = 1.5, -- number of neurons in the hidden layer + -- Check ROC curve and AUC in the ML literature + spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) + ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) symbol_spam = 'NEURAL_SPAM', symbol_ham = 'NEURAL_HAM', max_inputs = nil, -- when PCA is used diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 894d42e30..ca11d9e66 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -119,10 +119,24 @@ local function ann_scores_filter(task) if score > 0 then local result = score - task:insert_result(rule.symbol_spam, result, symscore) + + if not rule.spam_score_threshold or result >= rule.spam_score_threshold then + task:insert_result(rule.symbol_spam, result, symscore) + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam_score_threshold)', + rule.prefix, set.name, set.ann.version, symscore, + rule.spam_score_threshold) + end else local result = -(score) - task:insert_result(rule.symbol_ham, result, symscore) + + if not rule.ham_score_threshold or result >= rule.ham_score_threshold then + task:insert_result(rule.symbol_ham, result, symscore) + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham_score_threshold)', + rule.prefix, set.name, set.ann.version, result, + rule.ham_score_threshold) + end end end end |