source.dussan.org Git - rspamd.git/commitdiff
[Feature] Normalize all ANN inputs
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 13 Sep 2016 12:22:38 +0000 (13:22 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 13 Sep 2016 12:22:38 +0000 (13:22 +0100)
src/plugins/lua/fann_scores.lua

index 4508afc17d6e7c3a7a64fff243aa3bdd5e39027b..30f7a618b22437bed904fbf0b483ab1a23b72659 100644 (file)
@@ -103,13 +103,19 @@ local function fann_images_function(task)
       ntotal = ntotal + 1
     end
   end
-
+  if ntotal > 0 then
+    njpg = njpg / ntotal
+    npng = npng / ntotal
+    nlarge = nlarge / ntotal
+    nsmall = nsmall / ntotal
+  end
   return {ntotal,njpg,npng,nlarge,nsmall}
 end
 
 local function fann_nparts_function(task)
   local nattachments = 0
   local ntextparts = 0
+  local totalparts = 1
 
   local tp = task:get_text_parts()
   if tp then
@@ -123,10 +129,11 @@ local function fann_nparts_function(task)
       if p:get_filename() then
         nattachments = nattachments + 1
       end
+      totalparts = totalparts + 1
     end
   end
 
-  return {ntextparts, nattachments}
+  return {ntextparts/totalparts, nattachments/totalparts}
 end
 
 local function fann_encoding_function(task)
@@ -158,16 +165,26 @@ local function fann_recipients_function(task)
     nsmtp = #(task:get_recipients('smtp'))
   end
 
+  if nmime > 0 then nmime = 1.0 / nmime end
+  if nsmtp > 0 then nsmtp = 1.0 / nsmtp end
+
   return {nmime,nsmtp}
 end
 
 local function fann_received_function(task)
-  return {#(task:get_received_headers())}
+  local ret = 0
+  local rh = task:get_received_headers()
+
+  if rh and #rh > 0 then
+    ret = 1 / #rh
+  end
+
+  return {ret}
 end
 
 local function fann_urls_function(task)
   if task:has_urls() then
-    return {#(task:get_urls())}
+    return {1.0 / #(task:get_urls())}
   end
 
   return {0}
@@ -246,12 +263,12 @@ local function symbols_to_fann_vector(syms, scores)
   local n = rspamd_config:get_symbols_count()
 
   each(function(s, score)
-     matched_symbols[s + 1] = score
+     matched_symbols[s + 1] = rspamd_util.tanh(score)
   end, zip(syms, scores))
 
   for i=1,n do
     if matched_symbols[i] then
-      learn_data[i] = math.abs(matched_symbols[i])
+      learn_data[i] = matched_symbols[i]
     else
       learn_data[i] = 0
     end