aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r--lualib/rspamadm/classifier_test.lua62
1 files changed, 40 insertions, 22 deletions
diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua
index 44d2fc9e6..fff4be444 100644
--- a/lualib/rspamadm/classifier_test.lua
+++ b/lualib/rspamadm/classifier_test.lua
@@ -2,6 +2,8 @@ local rspamd_util = require "rspamd_util"
local lua_util = require "lua_util"
local argparse = require "argparse"
local fun = require "fun"
+local ucl = require "ucl"
+local rspamd_logger = require "rspamd_logger"
local parser = argparse()
:name "rspamadm classifier_test"
@@ -88,10 +90,20 @@ local function classify_files(files)
local result = assert(io.popen(rspamc_command))
local results = {}
for line in result:lines() do
- if string.match(line, "BAYES_SPAM") then
- table.insert(results, { result = "spam", output = line })
- elseif string.match(line, "BAYES_HAM") then
- table.insert(results, { result = "ham", output = line })
+ local ucl_parser = ucl.parser()
+ local is_good, err = ucl_parser:parse_string(line)
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ return nil
+ end
+ local obj = ucl_parser:get_object()
+ local file = obj.filename
+ local symbols = obj.symbols or {}
+
+ if symbols["BAYES_SPAM"] then
+ table.insert(results, { result = "spam", file = file })
+ elseif symbols["BAYES_HAM"] then
+ table.insert(results, { result = "ham", file = file })
end
end
@@ -99,38 +111,42 @@ local function classify_files(files)
end
-- Function to evaluate classifier performance
-local function evaluate_results(results, true_label)
- local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0
+local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files)
+ local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
for _, res in ipairs(results) do
- if res.result == true_label then
- if string.match(res.file, true_label) then
+ if res.result == spam_label then
+ if known_spam_files[res.file] then
true_positives = true_positives + 1
- else
+ elseif known_ham_files[res.file] then
false_positives = false_positives + 1
end
- else
- if string.match(res.file, true_label) then
+ total = total + 1
+ elseif res.result == ham_label then
+ if known_spam_files[res.file] then
false_negatives = false_negatives + 1
- else
+ elseif known_ham_files[res.file] then
true_negatives = true_negatives + 1
end
+ total = total + 1
end
end
- local total = #results
local accuracy = (true_positives + true_negatives) / total
local precision = true_positives / (true_positives + false_positives)
local recall = true_positives / (true_positives + false_negatives)
local f1_score = 2 * (precision * recall) / (precision + recall)
- print("True Positives:", true_positives)
- print("False Positives:", false_positives)
- print("True Negatives:", true_negatives)
- print("False Negatives:", false_negatives)
- print("Accuracy:", accuracy)
- print("Precision:", precision)
- print("Recall:", recall)
- print("F1 Score:", f1_score)
+ print(string.format("%-20s %-10s", "Metric", "Value"))
+ print(string.rep("-", 30))
+ print(string.format("%-20s %-10d", "True Positives", true_positives))
+ print(string.format("%-20s %-10d", "False Positives", false_positives))
+ print(string.format("%-20s %-10d", "True Negatives", true_negatives))
+ print(string.format("%-20s %-10d", "False Negatives", false_negatives))
+ print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
+ print(string.format("%-20s %-10.2f", "Precision", precision))
+ print(string.format("%-20s %-10.2f", "Recall", recall))
+ print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
+ print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100))
end
local function handler(args)
@@ -139,7 +155,9 @@ local function handler(args)
local spam_directory = opts['spam']
-- Get all files
local spam_files = get_files(spam_directory)
+ local known_spam_files = lua_util.list_to_hash(spam_files)
local ham_files = get_files(ham_directory)
+ local known_ham_files = lua_util.list_to_hash(ham_files)
-- Split files into training and cross-validation sets
@@ -174,7 +192,7 @@ local function handler(args)
local results = classify_files(cv_files)
-- Evaluate results
- evaluate_results(results, "spam")
+ evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files)
end