]> source.dussan.org Git - rspamd.git/commitdiff
Rescore logging and reporting improvements
authorSteve Freegard <steve@stevefreegard.com>
Mon, 19 Mar 2018 12:09:40 +0000 (12:09 +0000)
committerSteve Freegard <steve@stevefreegard.com>
Mon, 19 Mar 2018 12:09:40 +0000 (12:09 +0000)
lualib/rspamadm/rescore.lua
lualib/rspamadm/rescore_utility.lua

index c8348caa3ff8dadeec3e08d8ac6b021d7895f08a..87e0ea2c5ab3fa9cbb521913f42781c10ed811d3 100644 (file)
@@ -182,9 +182,9 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
 
   logs = update_logs(logs, new_symbol_scores)
 
-  local file_stats, _ = rescore_utility.generate_statistics_from_logs(logs, threshold)
+  local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold)
 
-  return file_stats.fscore
+  return file_stats.fscore, all_fps, all_fns
 end
 
 local function print_stats(logs, threshold)
@@ -196,6 +196,7 @@ F-score: %.2f
 False positive rate: %.2f %%
 False negative rate: %.2f %%
 Overall accuracy: %.2f %%
+Slowest message: %.2f (%s)
 ]]
 
   logger.message("\nStatistics at threshold: " .. threshold)
@@ -204,7 +205,9 @@ Overall accuracy: %.2f %%
       file_stats.fscore,
       file_stats.false_positive_rate,
       file_stats.false_negative_rate,
-      file_stats.overall_accuracy))
+      file_stats.overall_accuracy,
+      file_stats.slowest,
+      file_stats.slowest_file))
 
 end
 
@@ -463,6 +466,67 @@ return function (args, cfg)
   local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
       ignore_symbols)
 
+  -- Display hit frequencies
+  if opts['z'] then
+      local file_stats, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
+      local t = {}
+      for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
+      function compare_symbols (a, b) 
+          if (a.spam_overall ~= b.spam_overall) then
+              return b.spam_overall < a.spam_overall
+          end
+          if (b.spam_hits ~= a.spam_hits) then
+              return b.spam_hits < a.spam_hits
+          end
+          return b.ham_hits < a.ham_hits
+      end
+      table.sort(t, compare_symbols)
+      logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s",
+                     "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%"))
+      for _, symbol_stats in pairs(t) do
+          logger.message(
+              string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f", 
+                  symbol_stats.name, 
+                  symbol_stats.no_of_hits,
+                  symbol_stats.ham_hits,
+                  lua_util.round(symbol_stats.ham_percent,2),
+                  symbol_stats.spam_hits,
+                  lua_util.round(symbol_stats.spam_percent,2),
+                  lua_util.round(symbol_stats.spam_overall,2),
+                  lua_util.round(symbol_stats.overall, 2)
+              )
+          )
+      end
+
+      -- Print file statistics
+      print_stats(logs, threshold)
+
+      -- Work out how many symbols weren't seen in the corpus
+      local symbols_no_hits = {}
+      local total_symbols = 0
+      for sym in pairs(original_symbol_scores) do
+          total_symbols = total_symbols + 1
+          if (all_symbols_stats[sym] == nil) then
+              table.insert(symbols_no_hits, sym)
+          end
+      end
+      if (#symbols_no_hits > 0) then
+          table.sort(symbols_no_hits)
+          -- Calculate percentage of rules with no hits
+          local nhpct = lua_util.round((#symbols_no_hits/total_symbols)*100,2)
+          logger.message(
+              string.format('\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:', 
+                            #symbols_no_hits, nhpct, total_symbols
+              )
+          )
+          for _, symbol in pairs(symbols_no_hits) do
+              logger.messagex('%s', symbol)
+          end
+      end
+
+      return
+  end
+
   shuffle(logs)
   torch.setdefaulttensortype('torch.FloatTensor')
 
@@ -471,7 +535,6 @@ return function (args, cfg)
 
   local dataset = make_dataset_from_logs(train_logs, all_symbols, reject_score)
 
-
   -- Start of perceptron training
   local input_size = #all_symbols
   torch.setnumthreads(opts['threads'])
@@ -490,6 +553,8 @@ return function (args, cfg)
   local best_weights = linear_module.weight[1]:clone()
   local best_learning_rate
   local best_weight_decay
+  local all_fps
+  local all_fns
 
   for _,lr in ipairs(learning_rates) do
     for _,wd in ipairs(penalty_weights) do
@@ -502,7 +567,7 @@ return function (args, cfg)
             initial_weights)
       end
 
-      local fscore = calculate_fscore_from_weights(cv_logs,
+      local fscore, fps, fns = calculate_fscore_from_weights(cv_logs,
           all_symbols,
           linear_module.weight[1],
           threshold)
@@ -515,6 +580,8 @@ return function (args, cfg)
         best_weight_decay = wd
         best_fscore = fscore
         best_weights = linear_module.weight[1]:clone()
+        all_fps = fps
+        all_fns = fns
       end
     end
   end
@@ -533,7 +600,6 @@ return function (args, cfg)
     print_score_diff(new_symbol_scores, original_symbol_scores)
   end
 
-
   -- Pre-rescore test stats
   logger.message("\n\nPre-rescore test stats\n")
   test_logs = update_logs(test_logs, original_symbol_scores)
@@ -546,4 +612,19 @@ return function (args, cfg)
 
   logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
       best_fscore, best_learning_rate, best_weight_decay)
-end
\ No newline at end of file
+
+  -- Show all FPs/FNs, useful for corpus checking and rule creation/modification
+  if (all_fps and #all_fps > 0) then
+      logger.message("\nFalse-Positives:")
+      for _, fp in pairs(all_fps) do
+          logger.messagex('%s', fp)
+      end
+  end
+
+  if (all_fns and #all_fns > 0) then
+      logger.message("\nFalse-Negatives:")
+      for _, fn in pairs(all_fns) do
+          logger.messagex('%s', fn)
+      end
+  end
+end
index 7f3f4007847350e01e83c3cad2509bc432713e56..2a9372d4e3f12f010f1d345c874b7bafc2c1de3b 100644 (file)
@@ -11,7 +11,7 @@ function utility.get_all_symbols(logs, ignore_symbols)
 
   for _, line in pairs(logs) do
     line = lua_util.rspamd_str_split(line, " ")
-    for i=4,#line do
+    for i=4,(#line-2) do
       line[i] = line[i]:gsub("%s+", "")
       if not symbols_set[line[i]] then
         symbols_set[line[i]] = true
@@ -54,7 +54,7 @@ function utility.get_all_logs(dir_path)
     dir_path = dir_path:sub(1, #dir_path -1)
   end
 
-  local files = rspamd_util.glob(dir_path .. "/*")
+  local files = rspamd_util.glob(dir_path .. "/*.log")
   local all_logs = {}
 
   for _, file in pairs(files) do
@@ -92,10 +92,15 @@ function utility.generate_statistics_from_logs(logs, threshold)
     false_negative_rate = 0,
     false_positive_rate = 0,
     overall_accuracy = 0,
-    fscore = 0
+    fscore = 0,
+    avg_scan_time = 0,
+    slowest_file = nil,
+    slowest = 0
   }
 
   local all_symbols_stats = {}
+  local all_fps = {}
+  local all_fns = {}
 
   local false_positives = 0
   local false_negatives = 0
@@ -124,13 +129,15 @@ function utility.generate_statistics_from_logs(logs, threshold)
       true_positives = true_positives + 1
     elseif is_spam and (score < threshold) then
       false_negatives = false_negatives + 1
+      table.insert(all_fns, log[#log])
     elseif not is_spam and (score >= threshold) then
       false_positives = false_positives + 1
+      table.insert(all_fps, log[#log])
     else
       true_negatives = true_negatives + 1
     end
 
-    for i=4, #log do
+    for i=4, (#log-2) do
       if all_symbols_stats[log[i]] == nil then
         all_symbols_stats[log[i]] = {
           name = log[i],
@@ -151,6 +158,12 @@ function utility.generate_statistics_from_logs(logs, threshold)
         all_symbols_stats[log[i]].ham_hits =
         all_symbols_stats[log[i]].ham_hits + 1
       end
+
+      -- Find slowest message
+      if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then
+          file_stats.slowest = tostring(tonumber(log[#log-1]))
+          file_stats.slowest_file = log[#log]
+      end
     end
   end
 
@@ -192,7 +205,7 @@ function utility.generate_statistics_from_logs(logs, threshold)
         (symbol_stats.spam_percent + symbol_stats.ham_percent)
   end
 
-  return file_stats, all_symbols_stats
+  return file_stats, all_symbols_stats, all_fps, all_fns
 end
 
 return utility