local rspamd_util = require "rspamd_util"
local lua_util = require "lua_util"
local argparse = require "argparse"
local ucl = require "ucl"
local rspamd_logger = require "rspamd_logger"
-- Command line interface of `rspamadm classifier_test`.
local parser = argparse()
    :name "rspamadm classifier_test"
    :description "Learn bayes classifier and evaluate its performance"
    :help_description_margin(32)

parser:option "-H --ham"
      :description("Ham directory")
      :argname("<dir>")
parser:option "-S --spam"
      :description("Spam directory")
      :argname("<dir>")
parser:flag "-n --no-learning"
      :description("Do not learn classifier")
parser:option "--nconns"
      :description("Number of parallel connections")
      :argname("<num>")
      :convert(tonumber)
      :default(10)
parser:option "-t --timeout"
      :description("Timeout for client connections")
      :argname("<sec>")
      :convert(tonumber)
      :default(60)
parser:option "-c --connect"
      :description("Connect to specific host")
      :argname("<host>")
      :default('localhost:11334')
parser:option "-r --rspamc"
      :description("Use specific rspamc path")
      :argname("<path>")
      :default('rspamc')
-- The short alias `-c` already belongs to --connect above, so the
-- cross-validation fraction is only exposed under its long name to avoid
-- two options fighting over the same flag.
parser:option "--cv-fraction"
      :description("Use specific fraction for cross-validation")
      :argname("<fraction>")
      :convert(tonumber)
      :default('0.7')
parser:option "--spam-symbol"
      :description("Use specific spam symbol (instead of BAYES_SPAM)")
      :argname("<symbol>")
      :default('BAYES_SPAM')
parser:option "--ham-symbol"
      :description("Use specific ham symbol (instead of BAYES_HAM)")
      :argname("<symbol>")
      :default('BAYES_HAM')

-- Parsed command line options: assigned in handler() and read by the
-- rspamc helpers (train_classifier / classify_files).
local opts
--- Randomly split a list into two parts.
-- A shuffled copy of `t` is made (the input is left untouched). The first
-- floor(#t * fraction) elements of the copy form the first part and the rest
-- form the second. The split point is clamped to [0, #t]. Calling this with
-- fraction = 1 therefore just returns a shuffled copy.
-- @param t sequence to split
-- @param fraction number or numeric string, normally in [0, 1]
-- @return first part, second part (both plain sequences)
local function split_table(t, fraction)
  local ratio = tonumber(fraction)
  assert(ratio, "invalid split fraction: " .. tostring(fraction))

  -- Copy, then Fisher-Yates shuffle in place: O(n). The previous approach
  -- (table.insert at a random position) shifts elements on every insert.
  local shuffled = {}
  for i, v in ipairs(t) do
    shuffled[i] = v
  end
  local n = #shuffled
  for i = n, 2, -1 do
    local j = math.random(1, i)
    shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
  end

  local split_point = math.floor(n * ratio)
  if split_point < 0 then
    split_point = 0
  elseif split_point > n then
    split_point = n
  end

  -- Build both parts with plain loops. unpack() is limited by the C stack
  -- size (e.g. ~8000 values on Lua 5.1/LuaJIT), which large mail corpora exceed.
  local part1, part2 = {}, {}
  for i = 1, split_point do
    part1[i] = shuffled[i]
  end
  for i = split_point + 1, n do
    part2[i - split_point] = shuffled[i]
  end
  return part1, part2
end
--- Collect the paths inside directory `dir`.
-- Builds the glob pattern `<dir>/*` and returns whatever rspamd_util.glob
-- yields for it.
-- @param dir directory path (string)
local function get_files(dir)
  local pattern = dir .. '/*'
  return rspamd_util.glob(pattern)
end
--- Write the elements of a sequence to a file, one element per line.
-- The file is created or truncated. A failure to open it raises an error.
-- @param list sequence of strings (or numbers) to write
-- @param fname destination file path
local function list_to_file(list, fname)
  local fd = assert(io.open(fname, "w"))
  for _, entry in ipairs(list) do
    fd:write(entry, "\n")
  end
  fd:close()
end
--- Train the classifier by feeding message files to rspamc.
-- The file list goes to a temporary file passed via --files-list. The
-- temporary file is always removed. rspamc's output is read to EOF, so this
-- blocks until rspamc closes its stdout. The pipe is then closed explicitly.
-- @param files sequence of message file paths
-- @param command rspamc command to run, e.g. "learn_spam" or "learn_ham"
-- @return the raw rspamc output (callers may ignore it)
local function train_classifier(files, command)
  local fname = os.tmpname()
  list_to_file(files, fname)
  local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s --files-list=%s",
      opts.rspamc, opts.connect, opts.nconns, opts.timeout, command, fname)
  local pipe, popen_err = io.popen(rspamc_command)
  if not pipe then
    -- Do not leave the temporary list file behind when rspamc cannot start.
    os.remove(fname)
    error(string.format("cannot run '%s': %s", rspamc_command, tostring(popen_err)))
  end
  local output = pipe:read("*all")
  -- Close the pipe so the handle is released and the child process is reaped
  -- now rather than whenever the garbage collector runs.
  pipe:close()
  os.remove(fname)
  return output
end
--- Classify message files with rspamc and collect Bayes verdicts.
-- Only opts.spam_symbol and opts.ham_symbol are enabled through a Settings
-- header. Each line of rspamc output is parsed as a UCL object.
--   * A message with the spam symbol is recorded as "spam".
--   * A message with only the ham symbol is recorded as "ham".
--   * A message with neither symbol is omitted.
-- Known ham reported as spam (FP) and known spam reported as ham (FN) are
-- logged as they are seen.
-- @param files sequence of message file paths to classify
-- @param known_spam_files lookup table keyed by spam file path
-- @param known_ham_files lookup table keyed by ham file path
-- @return sequence of { result = "spam"|"ham", file = path }, or nil when a
--   line of rspamc output cannot be parsed (the error is logged)
local function classify_files(files, known_spam_files, known_ham_files)
  local fname = os.tmpname()
  list_to_file(files, fname)
  local settings_header = string.format('--header Settings=\"{symbols_enabled=[%s, %s]}\"',
      opts.spam_symbol, opts.ham_symbol)
  local rspamc_command = string.format("%s %s --connect %s --compact -n %s -t %.3f --files-list=%s",
      opts.rspamc,
      settings_header,
      opts.connect,
      opts.nconns,
      opts.timeout, fname)
  local pipe, popen_err = io.popen(rspamc_command)
  if not pipe then
    -- Do not leave the temporary list file behind when rspamc cannot start.
    os.remove(fname)
    error(string.format("cannot run '%s': %s", rspamc_command, tostring(popen_err)))
  end
  local results = {}
  for line in pipe:lines() do
    local ucl_parser = ucl.parser()
    local is_good, err = ucl_parser:parse_string(line)
    if not is_good then
      rspamd_logger.errx("Parser error: %1", err)
      -- Release the pipe on the early-exit path as well.
      pipe:close()
      os.remove(fname)
      return nil
    end
    local obj = ucl_parser:get_object()
    local file = obj.filename
    local symbols = obj.symbols or {}
    if symbols[opts.spam_symbol] then
      table.insert(results, { result = "spam", file = file })
      if known_ham_files[file] then
        rspamd_logger.message("FP: %s is classified as spam but is known ham", file)
      end
    elseif symbols[opts.ham_symbol] then
      if known_spam_files[file] then
        rspamd_logger.message("FN: %s is classified as ham but is known spam", file)
      end
      table.insert(results, { result = "ham", file = file })
    end
  end
  pipe:close()
  os.remove(fname)
  return results
end
--- Evaluate classifier performance and print a metrics table to stdout.
-- Spam is the positive class. For a result labelled `spam_label`:
--   * known spam counts as a true positive;
--   * known ham counts as a false positive.
-- For a result labelled `ham_label`:
--   * known spam counts as a false negative;
--   * known ham counts as a true negative.
-- Results carrying any other label are ignored.
-- A ratio whose denominator is zero is reported as 0 rather than NaN. This
-- happens, for example, when nothing was classified or no spam was predicted.
-- @param results sequence of { result = label, file = path }; nil is
--   treated as an empty list
-- @param spam_label label that marks a spam verdict
-- @param ham_label label that marks a ham verdict
-- @param known_spam_files lookup table keyed by spam file path
-- @param known_ham_files lookup table keyed by ham file path
-- @param total_cv_files number of messages submitted for classification
-- @param elapsed classification time in seconds
local function evaluate_results(results, spam_label, ham_label,
                                known_spam_files, known_ham_files, total_cv_files, elapsed)
  -- Quotient that yields 0 instead of NaN/inf for a zero denominator.
  local function safe_div(num, den)
    if den == 0 then
      return 0
    end
    return num / den
  end

  local true_positives, false_positives, true_negatives, false_negatives, total = 0, 0, 0, 0, 0
  for _, res in ipairs(results or {}) do
    if res.result == spam_label then
      if known_spam_files[res.file] then
        true_positives = true_positives + 1
      elseif known_ham_files[res.file] then
        false_positives = false_positives + 1
      end
      total = total + 1
    elseif res.result == ham_label then
      if known_spam_files[res.file] then
        false_negatives = false_negatives + 1
      elseif known_ham_files[res.file] then
        true_negatives = true_negatives + 1
      end
      total = total + 1
    end
  end
  local accuracy = safe_div(true_positives + true_negatives, total)
  local precision = safe_div(true_positives, true_positives + false_positives)
  local recall = safe_div(true_positives, true_positives + false_negatives)
  local f1_score = safe_div(2 * (precision * recall), precision + recall)
  local classified_pct = safe_div(total, total_cv_files) * 100
  print(string.format("%-20s %-10s", "Metric", "Value"))
  print(string.rep("-", 30))
  print(string.format("%-20s %-10d", "True Positives", true_positives))
  print(string.format("%-20s %-10d", "False Positives", false_positives))
  print(string.format("%-20s %-10d", "True Negatives", true_negatives))
  print(string.format("%-20s %-10d", "False Negatives", false_negatives))
  print(string.format("%-20s %-10.2f", "Accuracy", accuracy))
  print(string.format("%-20s %-10.2f", "Precision", precision))
  print(string.format("%-20s %-10.2f", "Recall", recall))
  print(string.format("%-20s %-10.2f", "F1 Score", f1_score))
  print(string.format("%-20s %-10.2f", "Classified (%)", classified_pct))
  print(string.format("%-20s %-10.2f", "Elapsed time (seconds)", elapsed))
end
--- Entry point of `rspamadm classifier_test`.
-- Steps:
--   1. Parse the command line into the file-level `opts`.
--   2. Split both corpora so that opts.cv_fraction of each goes to training
--      and the rest to cross-validation.
--   3. Unless --no-learning is set, learn spam and then ham through rspamc.
--   4. Classify the shuffled cross-validation set.
--   5. Print the quality metrics.
-- @param args command line arguments for the argparse parser
local function handler(args)
  opts = parser:parse(args)
  local ham_directory = opts['ham']
  local spam_directory = opts['spam']
  -- get_files() concatenates the directory name, so report a proper usage
  -- error instead of crashing with "attempt to concatenate a nil value".
  if not ham_directory or not spam_directory then
    parser:error("both --ham and --spam directories must be specified")
  end
  -- Get all files
  local spam_files = get_files(spam_directory)
  local known_spam_files = lua_util.list_to_hash(spam_files)
  local ham_files = get_files(ham_directory)
  local known_ham_files = lua_util.list_to_hash(ham_files)
  -- Split files into training and cross-validation sets
  local train_spam, cv_spam = split_table(spam_files, opts.cv_fraction)
  local train_ham, cv_ham = split_table(ham_files, opts.cv_fraction)
  print(string.format("Spam: %d train files, %d cv files; ham: %d train files, %d cv files",
      #train_spam, #cv_spam, #train_ham, #cv_ham))
  if not opts.no_learning then
    -- Train classifier
    local t, train_spam_time, train_ham_time
    print(string.format("Start learn spam, %d messages, %d connections", #train_spam, opts.nconns))
    t = rspamd_util.get_time()
    train_classifier(train_spam, "learn_spam")
    train_spam_time = rspamd_util.get_time() - t
    print(string.format("Start learn ham, %d messages, %d connections", #train_ham, opts.nconns))
    t = rspamd_util.get_time()
    train_classifier(train_ham, "learn_ham")
    train_ham_time = rspamd_util.get_time() - t
    print(string.format("Learning done: %d spam messages in %.2f seconds, %d ham messages in %.2f seconds",
        #train_spam, train_spam_time, #train_ham, train_ham_time))
  end
  -- Classify cross-validation files
  local cv_files = {}
  for _, file in ipairs(cv_spam) do
    table.insert(cv_files, file)
  end
  for _, file in ipairs(cv_ham) do
    table.insert(cv_files, file)
  end
  -- Shuffle cross-validation files (a split with fraction 1 keeps everything)
  cv_files = split_table(cv_files, 1)
  print(string.format("Start cross validation, %d messages, %d connections", #cv_files, opts.nconns))
  -- Get classification results
  local t = rspamd_util.get_time()
  local results = classify_files(cv_files, known_spam_files, known_ham_files)
  local elapsed = rspamd_util.get_time() - t
  -- classify_files() returns nil when rspamc output cannot be parsed;
  -- there is nothing meaningful to evaluate in that case.
  if not results then
    rspamd_logger.errx("Cannot evaluate classifier: rspamc output could not be parsed")
    return
  end
  -- Evaluate results
  evaluate_results(results, "spam", "ham",
      known_spam_files,
      known_ham_files,
      #cv_files,
      elapsed)
end
-- Command descriptor handed back to rspamadm: primary name, alternate
-- spelling, entry point and the parser's description text.
local command = {
  name = 'classifiertest',
  aliases = { 'classifier_test' },
  handler = handler,
  description = parser._description,
}

return command