diff options
Diffstat (limited to 'lualib/rspamadm/classifier_test.lua')
-rw-r--r-- | lualib/rspamadm/classifier_test.lua | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua index 44d2fc9e6..fff4be444 100644 --- a/lualib/rspamadm/classifier_test.lua +++ b/lualib/rspamadm/classifier_test.lua @@ -2,6 +2,8 @@ local rspamd_util = require "rspamd_util" local lua_util = require "lua_util" local argparse = require "argparse" local fun = require "fun" +local ucl = require "ucl" +local rspamd_logger = require "rspamd_logger" local parser = argparse() :name "rspamadm classifier_test" @@ -88,10 +90,20 @@ local function classify_files(files) local result = assert(io.popen(rspamc_command)) local results = {} for line in result:lines() do - if string.match(line, "BAYES_SPAM") then - table.insert(results, { result = "spam", output = line }) - elseif string.match(line, "BAYES_HAM") then - table.insert(results, { result = "ham", output = line }) + local ucl_parser = ucl.parser() + local is_good, err = ucl_parser:parse_string(line) + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + return nil + end + local obj = ucl_parser:get_object() + local file = obj.filename + local symbols = obj.symbols or {} + + if symbols["BAYES_SPAM"] then + table.insert(results, { result = "spam", file = file }) + elseif symbols["BAYES_HAM"] then + table.insert(results, { result = "ham", file = file }) end end @@ -99,38 +111,42 @@ local function classify_files(files) end -- Function to evaluate classifier performance -local function evaluate_results(results, true_label) - local true_positives, false_positives, true_negatives, false_negatives = 0, 0, 0, 0 +local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files) + local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0 for _, res in ipairs(results) do - if res.result == true_label then - if string.match(res.file, true_label) then + if res.result == spam_label then + if known_spam_files[res.file] then true_positives = true_positives + 1 - else + elseif known_ham_files[res.file] then false_positives = false_positives + 1 end - else - if string.match(res.file, true_label) then + total = total + 1 + elseif res.result == ham_label then + if known_spam_files[res.file] then false_negatives = false_negatives + 1 - else + elseif known_ham_files[res.file] then true_negatives = true_negatives + 1 end + total = total + 1 end end - local total = #results local accuracy = (true_positives + true_negatives) / total local precision = true_positives / (true_positives + false_positives) local recall = true_positives / (true_positives + false_negatives) local f1_score = 2 * (precision * recall) / (precision + recall) - print("True Positives:", true_positives) - print("False Positives:", false_positives) - print("True Negatives:", true_negatives) - print("False Negatives:", false_negatives) - print("Accuracy:", accuracy) - print("Precision:", precision) - print("Recall:", recall) - print("F1 Score:", f1_score) + print(string.format("%-20s %-10s", "Metric", "Value")) + print(string.rep("-", 30)) + print(string.format("%-20s %-10d", "True Positives", true_positives)) + print(string.format("%-20s %-10d", "False Positives", false_positives)) + print(string.format("%-20s %-10d", "True Negatives", true_negatives)) + print(string.format("%-20s %-10d", "False Negatives", false_negatives)) + print(string.format("%-20s %-10.2f", "Accuracy", accuracy)) + print(string.format("%-20s %-10.2f", "Precision", precision)) + print(string.format("%-20s %-10.2f", "Recall", recall)) + print(string.format("%-20s %-10.2f", "F1 Score", f1_score)) + print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100)) end local function handler(args) @@ -139,7 +155,9 @@ local function handler(args) local spam_directory = opts['spam'] -- Get all files local spam_files = get_files(spam_directory) + local known_spam_files = lua_util.list_to_hash(spam_files) local ham_files = get_files(ham_directory) + local known_ham_files = lua_util.list_to_hash(ham_files) -- Split files into training and cross-validation sets @@ -174,7 +192,7 @@ local function handler(args) local results = classify_files(cv_files) -- Evaluate results - evaluate_results(results, "spam") + evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files) end |