summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-06-11 14:32:13 +0100
committerVsevolod Stakhov <vsevolod@rspamd.com>2024-06-11 14:32:13 +0100
commit44f9078de3ea7698a7d10373b4821877a1e5fe23 (patch)
tree1828480ba42e3e94ae5e1dbaaab114d018358d16
parent5e780446ad1b037eaad7acef9d073be0e65dada4 (diff)
downloadrspamd-44f9078de3ea7698a7d10373b4821877a1e5fe23.tar.gz
rspamd-44f9078de3ea7698a7d10373b4821877a1e5fe23.zip
[Minor] Add timings
-rw-r--r--lualib/rspamadm/classifier_test.lua22
1 files changed, 18 insertions, 4 deletions
diff --git a/lualib/rspamadm/classifier_test.lua b/lualib/rspamadm/classifier_test.lua
index fff4be444..7bb9a22e6 100644
--- a/lualib/rspamadm/classifier_test.lua
+++ b/lualib/rspamadm/classifier_test.lua
@@ -111,7 +111,8 @@ local function classify_files(files)
end
-- Function to evaluate classifier performance
-local function evaluate_results(results, spam_label, ham_label, known_spam_files, known_ham_files, total_cv_files)
+local function evaluate_results(results, spam_label, ham_label,
+ known_spam_files, known_ham_files, total_cv_files, elapsed)
local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
for _, res in ipairs(results) do
if res.result == spam_label then
@@ -146,7 +147,8 @@ local function evaluate_results(results, spam_label, ham_label, known_spam_files
print(string.format("%-20s %-10.2f", "Precision", precision))
print(string.format("%-20s %-10.2f", "Recall", recall))
print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
- print(string.format("%-20s %-10.2f%%", "Classified", total / total_cv_files * 100))
+ print(string.format("%-20s %-10.2f", "Classified (%)", total / total_cv_files * 100))
+ print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
end
local function handler(args)
@@ -168,11 +170,17 @@ local function handler(args)
#train_spam, #cv_spam, #train_ham, #cv_ham))
if not opts.no_learning then
-- Train classifier
+ local t, train_spam_time, train_ham_time
print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
+ t = rspamd_util.get_time()
train_classifier(train_spam, "learn_spam")
+ train_spam_time = rspamd_util.get_time() - t
print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
+ t = rspamd_util.get_time()
train_classifier(train_ham, "learn_ham")
- print("Learning done")
+ train_ham_time = rspamd_util.get_time() - t
+ print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
+ #train_spam, train_spam_time, #train_ham, train_ham_time))
end
-- Classify cross-validation files
@@ -189,10 +197,16 @@ local function handler(args)
print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
-- Get classification results
+ local t = rspamd_util.get_time()
local results = classify_files(cv_files)
+ local elapsed = rspamd_util.get_time() - t
-- Evaluate results
- evaluate_results(results, "spam", "ham", known_spam_files, known_ham_files, #cv_files)
+ evaluate_results(results, "spam", "ham",
+ known_spam_files,
+ known_ham_files,
+ #cv_files,
+ elapsed)
end