]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add a tool to perform statistical analysis of classifiers
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 10 Jun 2024 15:10:26 +0000 (16:10 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 10 Jun 2024 15:10:26 +0000 (16:10 +0100)
lualib/rspamadm/classifier_test.lua [new file with mode: 0644]

diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua
new file mode 100644 (file)
index 0000000..44d2fc9
--- /dev/null
@@ -0,0 +1,186 @@
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local argparse = require "argparse"
+local fun = require "fun"
+
+local parser = argparse()
+    :name "rspamadm classifier_test"
+    :description "Learn bayes classifier and evaluate its performance"
+    :help_description_margin(32)
+
+parser:option "-H --ham"
+      :description("Ham directory")
+      :argname("<dir>")
+parser:option "-S --spam"
+      :description("Spam directory")
+      :argname("<dir>")
+parser:flag "-n --no-learning"
+      :description("Do not learn classifier")
+parser:option "--nconns"
+      :description("Number of parallel connections")
+      :argname("<N>")
+      :convert(tonumber)
+      :default(10)
+parser:option "-t --timeout"
+      :description("Timeout for client connections")
+      :argname("<sec>")
+      :convert(tonumber)
+      :default(60)
+parser:option "-c --connect"
+      :description("Connect to specific host")
+      :argname("<host>")
+      :default('localhost:11334')
+parser:option "-r --rspamc"
+      :description("Use specific rspamc path")
+      :argname("<path>")
+      :default('rspamc')
+parser:option "-c --cv-fraction"
+      :description("Use specific fraction for cross-validation")
+      :argname("<fraction>")
+      :convert(tonumber)
+      :default('0.7')
+
+local opts
+
+-- Utility function to split a table into two parts randomly
+local function split_table(t, fraction)
+  local shuffled = {}
+  for _, v in ipairs(t) do
+    local pos = math.random(1, #shuffled + 1)
+    table.insert(shuffled, pos, v)
+  end
+  local split_point = math.floor(#shuffled * tonumber(fraction))
+  local part1 = { lua_util.unpack(shuffled, 1, split_point) }
+  local part2 = { lua_util.unpack(shuffled, split_point + 1) }
+  return part1, part2
+end
+
+local function shell_quote(argument)
+  if argument:match('^[%w%+%-%.,:/=@_]+$') then
+    return argument
+  end
+  argument = argument:gsub('[$`"\\]', '\\%0')
+  return '"' .. argument .. '"'
+end
+
+-- Utility function to get all files in a directory
+local function get_files(dir)
+  return fun.totable(fun.map(shell_quote, rspamd_util.glob(dir .. '/*')))
+end
+
+-- Function to train the classifier with given files
+local function train_classifier(files, command, connections)
+  local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s %s",
+      opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, table.concat(files, " "))
+  local result = assert(io.popen(rspamc_command))
+  result = result:read("*all")
+end
+
+-- Function to classify files and return results
+local function classify_files(files)
+  local settings_header = '--header Settings=\"{symbols_enabled=[BAYES_SPAM, BAYES_HAM]}\"'
+  local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f %s",
+      opts.rspamc,
+      settings_header,
+      opts.connect,
+      opts.nconns,
+      opts.timeout, table.concat(files, " "))
+  local result = assert(io.popen(rspamc_command))
+  local results = {}
+  for line in result:lines() do
+    if string.match(line, "BAYES_SPAM") then
+      table.insert(results, { result = "spam", output = line })
+    elseif string.match(line, "BAYES_HAM") then
+      table.insert(results, { result = "ham", output = line })
+    end
+  end
+
+  return results
+end
+
+-- Function to evaluate classifier performance
+local function evaluate_results(results, true_label)
+  local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0
+  for _, res in ipairs(results) do
+    if res.result == true_label then
+      if string.match(res.file, true_label) then
+        true_positives = true_positives + 1
+      else
+        false_positives = false_positives + 1
+      end
+    else
+      if string.match(res.file, true_label) then
+        false_negatives = false_negatives + 1
+      else
+        true_negatives = true_negatives + 1
+      end
+    end
+  end
+
+  local total = #results
+  local accuracy = (true_positives + true_negatives) / total
+  local precision = true_positives / (true_positives + false_positives)
+  local recall = true_positives / (true_positives + false_negatives)
+  local f1_score = 2 * (precision * recall) / (precision + recall)
+
+  print("True Positives:", true_positives)
+  print("False Positives:", false_positives)
+  print("True Negatives:", true_negatives)
+  print("False Negatives:", false_negatives)
+  print("Accuracy:", accuracy)
+  print("Precision:", precision)
+  print("Recall:", recall)
+  print("F1 Score:", f1_score)
+end
+
+local function handler(args)
+  opts = parser:parse(args)
+  local ham_directory = opts['ham']
+  local spam_directory = opts['spam']
+  -- Get all files
+  local spam_files = get_files(spam_directory)
+  local ham_files = get_files(ham_directory)
+
+  -- Split files into training and cross-validation sets
+
+  local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction)
+  local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction)
+
+  print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files",
+      #train_spam, #cv_spam, #train_ham, #cv_ham))
+  if not opts.no_learning then
+    -- Train classifier
+    print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+    train_classifier(train_spam, "learn_spam")
+    print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+    train_classifier(train_ham, "learn_ham")
+    print("Learning done")
+  end
+
+  -- Classify cross-validation files
+  local cv_files = {}
+  for _, file in ipairs(cv_spam) do
+    table.insert(cv_files, file)
+  end
+  for _, file in ipairs(cv_ham) do
+    table.insert(cv_files, file)
+  end
+
+  -- Shuffle cross-validation files
+  cv_files = split_table(cv_files, 1)
+
+  print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
+  -- Get classification results
+  local results = classify_files(cv_files)
+
+  -- Evaluate results
+  evaluate_results(results, "spam")
+
+end
+
+return {
+  name = 'classifiertest',
+  aliases = { 'classifier_test' },
+  handler = handler,
+  description = parser._description
+}
\ No newline at end of file