about summary refs log tree commit diff stats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2018-04-24 10:48:44 +0100
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2018-04-24 10:48:44 +0100
commit d173e8d13ed53e9562f472c4f670797e9cfc78b9 (patch)
tree a0b5db4583aed6d3f75fbab85a4388c139a0d391 /lualib/rspamadm
parent 38f05a24a5da2c3692fa41a28aad6bc767a08b92 (diff)
download rspamd-d173e8d13ed53e9562f472c4f670797e9cfc78b9.tar.gz
rspamd-d173e8d13ed53e9562f472c4f670797e9cfc78b9.zip
[Minor] Various improvements to corpus_test script
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- lualib/rspamadm/corpus_test.lua 184
1 files changed, 95 insertions, 89 deletions
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
index b71f96e9e..cd9f66155 100644
--- a/lualib/rspamadm/corpus_test.lua
+++ b/lualib/rspamadm/corpus_test.lua
@@ -1,141 +1,147 @@
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
local lua_util = require "lua_util"
+local getopt = require "rspamadm/getopt"
local HAM = "HAM"
local SPAM = "SPAM"
+local opts
+local default_opts = {
+ connect = 'localhost:11334',
+}
local function scan_email(n_parallel, path, timeout)
- local rspamc_command = string.format("rspamc -j --compact -n %s -t %.3f %s",
- n_parallel, timeout, path)
- local result = assert(io.popen(rspamc_command))
- result = result:read("*all")
- return result
-end
+ local rspamc_command = string.format("rspamc --connect %s -j --compact -n %s -t %.3f %s",
+ opts.connect, n_parallel, timeout, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
local function write_results(results, file)
- local f = io.open(file, 'w')
+ local f = io.open(file, 'w')
- for _, result in pairs(results) do
- local log_line = string.format("%s %.2f %s", result.type, result.score, result.action)
+ for _, result in pairs(results) do
+ local log_line = string.format("%s %.2f %s", result.type, result.score, result.action)
- for _, sym in pairs(result.symbols) do
- log_line = log_line .. " " .. sym
- end
+ for _, sym in pairs(result.symbols) do
+ log_line = log_line .. " " .. sym
+ end
- log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename
+ log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename
- log_line = log_line .. "\r\n"
+ log_line = log_line .. "\r\n"
- f:write(log_line)
- end
+ f:write(log_line)
+ end
- f:close()
+ f:close()
end
local function encoded_json_to_log(result)
- -- Returns table containing score, action, list of symbols
+ -- Returns table containing score, action, list of symbols
- local filtered_result = {}
- local parser = ucl.parser()
+ local filtered_result = {}
+ local parser = ucl.parser()
- local is_good, err = parser:parse_string(result)
+ local is_good, err = parser:parse_string(result)
- if not is_good then
- io.stderr:write(rspamd_logger.slog("Parser error: %1\n", err))
- return nil
- end
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ return nil
+ end
- result = parser:get_object()
+ result = parser:get_object()
- filtered_result.score = result.score
- if not result.action then
- io.stderr:write(rspamd_logger.slog("Bad JSON: %1\n", result))
- return nil
- end
- local action = result.action:gsub("%s+", "_")
- filtered_result.action = action
+ filtered_result.score = result.score
+ if not result.action then
+ rspamd_logger.errx("Bad JSON: %1", result)
+ return nil
+ end
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
- filtered_result.symbols = {}
+ filtered_result.symbols = {}
- for sym, _ in pairs(result.symbols) do
- table.insert(filtered_result.symbols, sym)
- end
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
- filtered_result.filename = result.filename
- filtered_result.scan_time = result.scan_time
+ filtered_result.filename = result.filename
+ filtered_result.scan_time = result.scan_time
- return filtered_result
+ return filtered_result
end
local function scan_results_to_logs(results, actual_email_type)
- local logs = {}
+ local logs = {}
- results = lua_util.rspamd_str_split(results, "\n")
+ results = lua_util.rspamd_str_split(results, "\n")
- if results[#results] == "" then
- results[#results] = nil
- end
+ if results[#results] == "" then
+ results[#results] = nil
+ end
- for _, result in pairs(results) do
- result = encoded_json_to_log(result)
- if result then
- result['type'] = actual_email_type
- table.insert(logs, result)
- end
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ if result then
+ result['type'] = actual_email_type
+ table.insert(logs, result)
end
+ end
- return logs
+ return logs
end
-return function (_, res)
-
- local ham_directory = res['ham_directory']
- local spam_directory = res['spam_directory']
- local connections = res["connections"]
- local output = res["output_location"]
+return function(args, res)
+ opts = default_opts
+ opts = lua_util.override_defaults(opts, getopt.getopt(args, ''))
+ local ham_directory = res['ham_directory']
+ local spam_directory = res['spam_directory']
+ local connections = res["connections"]
+ local output = res["output_location"]
- local results = {}
+ local results = {}
- local start_time = os.time()
- local no_of_ham = 0
- local no_of_spam = 0
+ local start_time = os.time()
+ local no_of_ham = 0
+ local no_of_spam = 0
- if ham_directory then
- io.write("Scanning ham corpus...\n")
- local ham_results = scan_email(connections, ham_directory, res["timeout"])
- ham_results = scan_results_to_logs(ham_results, HAM)
+ if ham_directory then
+ rspamd_logger.messagex("Scanning ham corpus...")
+ local ham_results = scan_email(connections, ham_directory, res["timeout"])
+ ham_results = scan_results_to_logs(ham_results, HAM)
- no_of_ham = #ham_results
+ no_of_ham = #ham_results
- for _, result in pairs(ham_results) do
- table.insert(results, result)
- end
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
end
+ end
- if spam_directory then
- io.write("Scanning spam corpus...\n")
- local spam_results = scan_email(connections, spam_directory, res.timeout)
- spam_results = scan_results_to_logs(spam_results, SPAM)
+ if spam_directory then
+ rspamd_logger.messagex("Scanning spam corpus...")
+ local spam_results = scan_email(connections, spam_directory, res.timeout)
+ spam_results = scan_results_to_logs(spam_results, SPAM)
- no_of_spam = #spam_results
+ no_of_spam = #spam_results
- for _, result in pairs(spam_results) do
- table.insert(results, result)
- end
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
end
-
- io.write(string.format("Writing results to %s\n", output))
- write_results(results, output)
-
- io.write("\nStats: \n")
- local elapsed_time = os.time() - start_time
- local total_msgs = no_of_ham + no_of_spam
- io.write(string.format("Elapsed time: %ds\n", os.time() - start_time))
- io.write(string.format("No of ham: %d\n", no_of_ham))
- io.write(string.format("No of spam: %d\n", no_of_spam))
- io.write(string.format("Messages/sec: %-.2f\n", (total_msgs/elapsed_time)))
+ end
+
+ rspamd_logger.messagex("Writing results to %s", output)
+ write_results(results, output)
+
+ rspamd_logger.messagex("Stats: ")
+ local elapsed_time = os.time() - start_time
+ local total_msgs = no_of_ham + no_of_spam
+ rspamd_logger.messagex("Elapsed time: %ss", elapsed_time)
+ rspamd_logger.messagex("No of ham: %s", no_of_ham)
+ rspamd_logger.messagex("No of spam: %s", no_of_spam)
+ rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time))
end