From 44f9078de3ea7698a7d10373b4821877a1e5fe23 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Tue, 11 Jun 2024 14:32:13 +0100 Subject: [Minor] Add timings --- lualib/rspamadm/classifier_test.lua | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) (limited to 'lualib/rspamadm') diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua index fff4be444..7bb9a22e6 100644 --- a/lualib/rspamadm/classifier_test.lua +++ b/lualib/rspamadm/classifier_test.lua @@ -111,7 +111,8 @@ local function classify_files(files) end -- Function to evaluate classifier performance -local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files) +local function evaluate_results(results, spam_label, ham_label, + known_spam_files, known_ham_files, total_cv_files, elapsed) local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0 for _, res in ipairs(results) do if res.result == spam_label then @@ -146,7 +147,8 @@ local function evaluate_results(results, spam_label, ham_label, known_spam_files print(string.format("%-20s %-10.2f", "Precision", precision)) print(string.format("%-20s %-10.2f", "Recall", recall)) print(string.format("%-20s %-10.2f", "F1 Score", f1_score)) - print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100)) + print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100)) + print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed)) end local function handler(args) @@ -168,11 +170,17 @@ local function handler(args) #train_spam, #cv_spam, #train_ham, #cv_ham)) if not opts.no_learning then -- Train classifier + local t, train_spam_time, train_ham_time print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns)) + t = rspamd_util.get_time() train_classifier(train_spam, "learn_spam") + train_spam_time = rspamd_util.get_time() - t print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns)) + t = rspamd_util.get_time() train_classifier(train_ham, "learn_ham") - print("Learning done") + train_ham_time = rspamd_util.get_time() - t + print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds", + #train_spam, train_spam_time, #train_ham, train_ham_time)) end -- Classify cross-validation files @@ -189,10 +197,16 @@ local function handler(args) print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns)) -- Get classification results + local t = rspamd_util.get_time() local results = classify_files(cv_files) + local elapsed = rspamd_util.get_time() - t -- Evaluate results - evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files) + evaluate_results(results, "spam", "ham", + known_spam_files, + known_ham_files, + #cv_files, + elapsed) end -- cgit v1.2.3