You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

corpus_test.lua 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. local rspamd_logger = require "rspamd_logger"
  2. local ucl = require "ucl"
  3. local lua_util = require "lua_util"
  4. local argparse = require "argparse"
  -- Command-line definition for the corpus test command.
  local parser = argparse()
      :name("rspamadm corpus_test")
      :description("Create logs files from email corpus")
      :help_description_margin(32)

  -- Corpus locations; the handler skips whichever one is not supplied.
  parser:option("-H --ham")
      :description("Ham directory")
      :argname("<dir>")
  parser:option("-S --spam")
      :description("Spam directory")
      :argname("<dir>")

  -- Scanning parameters that are interpolated into the rspamc command line.
  parser:option("-n --conns")
      :description("Number of parallel connections")
      :argname("<N>")
      :convert(tonumber)
      :default(10)
  parser:option("-o --output")
      :description("Output file")
      :argname("<file>")
      :default('results.log')
  parser:option("-t --timeout")
      :description("Timeout for client connections")
      :argname("<sec>")
      :convert(tonumber)
      :default(60)
  parser:option("-c --connect")
      :description("Connect to specific host")
      :argname("<host>")
      :default('localhost:11334')
  parser:option("-r --rspamc")
      :description("Use specific rspamc path")
      :argname("<path>")
      :default('rspamc')

  -- Labels written as the first field of every log line.
  local HAM = "HAM"
  local SPAM = "SPAM"

  -- Parsed command-line options; assigned by handler() before any scan runs.
  local opts
  --- Run rspamc over `path` and return everything it printed on stdout.
  -- Reads `opts.rspamc` and `opts.connect` from the parsed command line.
  -- @param n_parallel value passed to rspamc's `-n` flag
  -- @param path value passed as the final rspamc argument
  -- @param timeout number passed to rspamc's `-t` flag (formatted with %.3f)
  -- @return string with the complete output of the command
  -- Raises (via assert) if the pipe cannot be opened.
  local function scan_email(n_parallel, path, timeout)
    -- NOTE(review): the arguments are interpolated into a shell command
    -- without quoting, so values containing spaces or shell metacharacters
    -- are interpreted by the shell.
    local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
        opts.rspamc, opts.connect, n_parallel, timeout, path)
    local pipe = assert(io.popen(rspamc_command))
    local output = pipe:read("*all")
    -- Close the pipe explicitly instead of leaving it to the garbage
    -- collector; the previous code dropped the handle without closing it.
    pipe:close()
    return output
  end
  --- Write scan results to `file`, one CRLF-terminated line per result.
  -- Line layout:
  --   <type> <score %.2f> <action> [<symbol> ...] <scan_time> <file>:<filename>
  -- @param results sequence of tables with the fields type, score, action,
  --   symbols (sequence of strings), scan_time and filename
  -- @param file path of the output file; it is truncated if it exists
  -- Raises (via assert) with the path and system message if `file` cannot be
  -- opened for writing.
  local function write_results(results, file)
    -- Fail with a readable message instead of a nil-index error later on.
    local f = assert(io.open(file, 'w'))
    -- ipairs keeps the output in insertion order.
    for _, result in ipairs(results) do
      -- Collect the fields and join once rather than growing a string.
      local fields = {
        string.format("%s %.2f %s", result.type, result.score, result.action),
      }
      for _, sym in ipairs(result.symbols) do
        fields[#fields + 1] = sym
      end
      fields[#fields + 1] = result.scan_time .. " " .. file .. ':' .. result.filename
      f:write(table.concat(fields, " "), "\r\n")
    end
    f:close()
  end
  --- Parse one encoded scan result and keep only the fields used for logging.
  -- @param result string parsed with the UCL parser
  -- @return table with score, action (whitespace runs replaced by "_"),
  --   symbols (list of symbol names), filename and scan_time; or nil when the
  --   input cannot be parsed or has no `action` field (an error is logged in
  --   both cases)
  local function encoded_json_to_log(result)
    local ucl_parser = ucl.parser()
    local is_good, err = ucl_parser:parse_string(result)
    if not is_good then
      rspamd_logger.errx("Parser error: %1", err)
      return nil
    end

    local parsed = ucl_parser:get_object()
    -- Reject anything that is not an object before indexing into it.
    if type(parsed) ~= 'table' or not parsed.action then
      rspamd_logger.errx("Bad JSON: %1", parsed)
      return nil
    end

    -- Only the symbol names are kept. A result without a `symbols` field
    -- yields an empty list instead of raising inside pairs().
    local symbols = {}
    for sym, _ in pairs(parsed.symbols or {}) do
      symbols[#symbols + 1] = sym
    end

    return {
      score = parsed.score,
      -- Parentheses drop gsub's second return value (the match count).
      action = (parsed.action:gsub("%s+", "_")),
      symbols = symbols,
      filename = parsed.filename,
      scan_time = parsed.scan_time,
    }
  end
  --- Convert raw scanner output into a list of log entries.
  -- The output is split on newlines; each piece is decoded with
  -- encoded_json_to_log and tagged with `actual_email_type`. Pieces that fail
  -- to decode are dropped.
  -- @param results string holding the raw output, one record per line
  -- @param actual_email_type label stored in the `type` field of every entry
  -- @return list of decoded entries
  local function scan_results_to_logs(results, actual_email_type)
    local lines = lua_util.rspamd_str_split(results, "\n")

    -- Drop the empty piece left by a trailing newline, if there is one.
    local last = #lines
    if lines[last] == "" then
      lines[last] = nil
    end

    local logs = {}
    for _, line in pairs(lines) do
      local entry = encoded_json_to_log(line)
      if entry then
        entry.type = actual_email_type
        table.insert(logs, entry)
      end
    end

    return logs
  end
  --- Scan one corpus directory and append its log entries to `results`.
  -- @param directory corpus directory, or nil to skip the scan
  -- @param email_type label given to every entry from this corpus
  -- @param banner progress message logged before scanning starts
  -- @param results list that receives the new entries (modified in place)
  -- @return number of entries appended (0 when `directory` is nil)
  local function scan_corpus(directory, email_type, banner, results)
    if not directory then
      return 0
    end
    rspamd_logger.messagex(banner)
    local raw = scan_email(opts.conns, directory, opts.timeout)
    local logs = scan_results_to_logs(raw, email_type)
    for _, entry in ipairs(logs) do
      results[#results + 1] = entry
    end
    return #logs
  end

  --- Entry point of the command: scan the ham and spam corpora, write the
  -- combined results to the output file and log summary statistics.
  -- @param args command-line arguments, parsed with `parser`
  local function handler(args)
    -- `opts` is the file-level variable also read by scan_email.
    opts = parser:parse(args)
    local output = opts.output
    local results = {}
    local start_time = os.time()

    local no_of_ham = scan_corpus(opts.ham, HAM, "Scanning ham corpus...", results)
    local no_of_spam = scan_corpus(opts.spam, SPAM, "Scanning spam corpus...", results)

    rspamd_logger.messagex("Writing results to %s", output)
    write_results(results, output)

    rspamd_logger.messagex("Stats: ")
    local elapsed_time = os.time() - start_time
    local total_msgs = no_of_ham + no_of_spam
    rspamd_logger.messagex("Elapsed time: %ss", elapsed_time)
    rspamd_logger.messagex("No of ham: %s", no_of_ham)
    rspamd_logger.messagex("No of spam: %s", no_of_spam)
    -- os.time() usually counts whole seconds, so a quick run can report an
    -- elapsed time of 0. Clamp the divisor to 1 so the rate is a finite
    -- number instead of inf or nan.
    rspamd_logger.messagex("Messages/sec: %s", (total_msgs / math.max(elapsed_time, 1)))
  end
  -- Command descriptor returned to whatever loads this module.
  local command = {
    handler = handler,
    description = parser._description,
    name = 'corpustest',
    aliases = { 'corpus_test', 'corpus' },
  }

  return command