local rspamd_logger = require "rspamd_logger" local ucl = require "ucl" local lua_util = require "lua_util" local argparse = require "argparse" local parser = argparse() :name "rspamadm corpus_test" :description "Create logs files from email corpus" :help_description_margin(32) parser:option "-H --ham" :description("Ham directory") :argname("") parser:option "-S --spam" :description("Spam directory") :argname("") parser:option "-n --conns" :description("Number of parallel connections") :argname("") :convert(tonumber) :default(10) parser:option "-o --output" :description("Output file") :argname("") :default('results.log') parser:option "-t --timeout" :description("Timeout for client connections") :argname("") :convert(tonumber) :default(60) parser:option "-c --connect" :description("Connect to specific host") :argname("") :default('localhost:11334') parser:option "-r --rspamc" :description("Use specific rspamc path") :argname("") :default('rspamc') local HAM = "HAM" local SPAM = "SPAM" local opts local function scan_email(n_parallel, path, timeout) local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s", opts.rspamc, opts.connect, n_parallel, timeout, path) local result = assert(io.popen(rspamc_command)) result = result:read("*all") return result end local function write_results(results, file) local f = io.open(file, 'w') for _, result in pairs(results) do local log_line = string.format("%s %.2f %s", result.type, result.score, result.action) for _, sym in pairs(result.symbols) do log_line = log_line .. " " .. sym end log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename log_line = log_line .. "\r\n" f:write(log_line) end f:close() end local function encoded_json_to_log(result) -- Returns table containing score, action, list of symbols local filtered_result = {} local ucl_parser = ucl.parser() local is_good, err = ucl_parser:parse_string(result) if not is_good then rspamd_logger.errx("Parser error: %1", err) return nil end result = ucl_parser:get_object() filtered_result.score = result.score if not result.action then rspamd_logger.errx("Bad JSON: %1", result) return nil end local action = result.action:gsub("%s+", "_") filtered_result.action = action filtered_result.symbols = {} for sym, _ in pairs(result.symbols) do table.insert(filtered_result.symbols, sym) end filtered_result.filename = result.filename filtered_result.scan_time = result.scan_time return filtered_result end local function scan_results_to_logs(results, actual_email_type) local logs = {} results = lua_util.rspamd_str_split(results, "\n") if results[#results] == "" then results[#results] = nil end for _, result in pairs(results) do result = encoded_json_to_log(result) if result then result['type'] = actual_email_type table.insert(logs, result) end end return logs end local function handler(args) opts = parser:parse(args) local ham_directory = opts['ham'] local spam_directory = opts['spam'] local connections = opts["conns"] local output = opts["output"] local results = {} local start_time = os.time() local no_of_ham = 0 local no_of_spam = 0 if ham_directory then rspamd_logger.messagex("Scanning ham corpus...") local ham_results = scan_email(connections, ham_directory, opts["timeout"]) ham_results = scan_results_to_logs(ham_results, HAM) no_of_ham = #ham_results for _, result in pairs(ham_results) do table.insert(results, result) end end if spam_directory then rspamd_logger.messagex("Scanning spam corpus...") local spam_results = scan_email(connections, spam_directory, opts.timeout) spam_results = scan_results_to_logs(spam_results, SPAM) no_of_spam = #spam_results for _, result in pairs(spam_results) do table.insert(results, result) end end rspamd_logger.messagex("Writing results to %s", output) write_results(results, output) rspamd_logger.messagex("Stats: ") local elapsed_time = os.time() - start_time local total_msgs = no_of_ham + no_of_spam rspamd_logger.messagex("Elapsed time: %ss", elapsed_time) rspamd_logger.messagex("No of ham: %s", no_of_ham) rspamd_logger.messagex("No of spam: %s", no_of_spam) rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time)) end return { name = 'corpus_test', handler = handler, description = parser._description }