From c29fd5981b0f24bae5d7f6389f3f1f87ddab4b93 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 13 Sep 2016 13:22:38 +0100 Subject: [PATCH] [Feature] Normalize all ANN inputs --- src/plugins/lua/fann_scores.lua | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 4508afc17..30f7a618b 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -103,13 +103,19 @@ local function fann_images_function(task) ntotal = ntotal + 1 end end - + if ntotal > 0 then + njpg = njpg / ntotal + npng = npng / ntotal + nlarge = nlarge / ntotal + nsmall = nsmall / ntotal + end return {ntotal,njpg,npng,nlarge,nsmall} end local function fann_nparts_function(task) local nattachments = 0 local ntextparts = 0 + local totalparts = 1 local tp = task:get_text_parts() if tp then @@ -123,10 +129,11 @@ local function fann_nparts_function(task) if p:get_filename() then nattachments = nattachments + 1 end + totalparts = totalparts + 1 end end - return {ntextparts, nattachments} + return {ntextparts/totalparts, nattachments/totalparts} end local function fann_encoding_function(task) @@ -158,16 +165,26 @@ local function fann_recipients_function(task) nsmtp = #(task:get_recipients('smtp')) end + if nmime > 0 then nmime = 1.0 / nmime end + if nsmtp > 0 then nsmtp = 1.0 / nsmtp end + return {nmime,nsmtp} end local function fann_received_function(task) - return {#(task:get_received_headers())} + local ret = 0 + local rh = task:get_received_headers() + + if rh and #rh > 0 then + ret = 1 / #rh + end + + return {ret} end local function fann_urls_function(task) if task:has_urls() then - return {#(task:get_urls())} + return {1.0 / #(task:get_urls())} end return {0} @@ -246,12 +263,12 @@ local function symbols_to_fann_vector(syms, scores) local n = rspamd_config:get_symbols_count() each(function(s, score) - matched_symbols[s + 1] = score + matched_symbols[s + 1] = rspamd_util.tanh(score) end, zip(syms, scores)) for i=1,n do if matched_symbols[i] then - learn_data[i] = math.abs(matched_symbols[i]) + learn_data[i] = matched_symbols[i] else learn_data[i] = 0 end -- 2.39.5