]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Improve stats processing
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 10 Jun 2024 18:04:37 +0000 (19:04 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 10 Jun 2024 18:04:37 +0000 (19:04 +0100)
lualib/rspamadm/classifier_test.lua

index 44d2fc9e686460037108686d05296a7477fd4239..fff4be444035c4d3bfd026b7da37fbd8a8f74ece 100644 (file)
@@ -2,6 +2,8 @@ local rspamd_util = require "rspamd_util"
 local lua_util = require "lua_util"
 local argparse = require "argparse"
 local fun = require "fun"
+local ucl = require "ucl"
+local rspamd_logger = require "rspamd_logger"
 
 local parser = argparse()
     :name "rspamadm classifier_test"
@@ -88,10 +90,20 @@ local function classify_files(files)
   local result = assert(io.popen(rspamc_command))
   local results = {}
   for line in result:lines() do
-    if string.match(line, "BAYES_SPAM") then
-      table.insert(results, { result = "spam", output = line })
-    elseif string.match(line, "BAYES_HAM") then
-      table.insert(results, { result = "ham", output = line })
+    local ucl_parser = ucl.parser()
+    local is_good, err = ucl_parser:parse_string(line)
+    if not is_good then
+      rspamd_logger.errx("Parser error: %1", err)
+      return nil
+    end
+    local obj = ucl_parser:get_object()
+    local file = obj.filename
+    local symbols = obj.symbols or {}
+
+    if symbols["BAYES_SPAM"] then
+      table.insert(results, { result = "spam", file = file })
+    elseif symbols["BAYES_HAM"] then
+      table.insert(results, { result = "ham", file = file })
     end
   end
 
@@ -99,38 +111,42 @@ local function classify_files(files)
 end
 
 -- Function to evaluate classifier performance
-local function evaluate_results(results, true_label)
-  local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0
+local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files)
+  local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
   for _, res in ipairs(results) do
-    if res.result == true_label then
-      if string.match(res.file, true_label) then
+    if res.result == spam_label then
+      if known_spam_files[res.file] then
         true_positives = true_positives + 1
-      else
+      elseif known_ham_files[res.file] then
         false_positives = false_positives + 1
       end
-    else
-      if string.match(res.file, true_label) then
+      total = total + 1
+    elseif res.result == ham_label then
+      if known_spam_files[res.file] then
         false_negatives = false_negatives + 1
-      else
+      elseif known_ham_files[res.file] then
         true_negatives = true_negatives + 1
       end
+      total = total + 1
     end
   end
 
-  local total = #results
   local accuracy = (true_positives + true_negatives) / total
   local precision = true_positives / (true_positives + false_positives)
   local recall = true_positives / (true_positives + false_negatives)
   local f1_score = 2 * (precision * recall) / (precision + recall)
 
-  print("True Positives:", true_positives)
-  print("False Positives:", false_positives)
-  print("True Negatives:", true_negatives)
-  print("False Negatives:", false_negatives)
-  print("Accuracy:", accuracy)
-  print("Precision:", precision)
-  print("Recall:", recall)
-  print("F1 Score:", f1_score)
+  print(string.format("%-20s %-10s", "Metric", "Value"))
+  print(string.rep("-", 30))
+  print(string.format("%-20s %-10d", "True Positives", true_positives))
+  print(string.format("%-20s %-10d", "False Positives", false_positives))
+  print(string.format("%-20s %-10d", "True Negatives", true_negatives))
+  print(string.format("%-20s %-10d", "False Negatives", false_negatives))
+  print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
+  print(string.format("%-20s %-10.2f", "Precision", precision))
+  print(string.format("%-20s %-10.2f", "Recall", recall))
+  print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
+  print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100))
 end
 
 local function handler(args)
@@ -139,7 +155,9 @@ local function handler(args)
   local spam_directory = opts['spam']
   -- Get all files
   local spam_files = get_files(spam_directory)
+  local known_spam_files = lua_util.list_to_hash(spam_files)
   local ham_files = get_files(ham_directory)
+  local known_ham_files = lua_util.list_to_hash(ham_files)
 
   -- Split files into training and cross-validation sets
 
@@ -174,7 +192,7 @@ local function handler(args)
   local results = classify_files(cv_files)
 
   -- Evaluate results
-  evaluate_results(results, "spam")
+  evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files)
 
 end