You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rescore_utility.lua 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. local lua_util = require "lua_util"
  2. local rspamd_util = require "rspamd_util"
  3. local fun = require "fun"
  4. local utility = {}
  --- Collect the distinct symbol names found in a set of log lines.
  -- Each line is trimmed and split on single spaces. Fields 4 .. (n - 1)
  -- are taken as symbol names; the first three fields and the last field
  -- (which generate_statistics_from_logs reads as the scan time) are
  -- skipped. Whitespace inside a field is removed and empty names dropped.
  -- @tparam table logs log-line prefixes, e.g. as returned by read_log_file
  -- @tparam[opt] table ignore_symbols set of names to exclude
  --   (`{ NAME = true }`); defaults to an empty set
  -- @treturn table alphabetically sorted array of symbol names
  function utility.get_all_symbols(logs, ignore_symbols)
    ignore_symbols = ignore_symbols or {}

    local symbols_set = {}
    for _, line in pairs(logs) do
      -- Trim before splitting, as generate_statistics_from_logs does.
      -- Otherwise leading/trailing blanks produce empty fields that shift
      -- which field is treated as the last (non-symbol) one.
      local fields = lua_util.rspamd_str_split(lua_util.rspamd_str_trim(line), " ")
      for i = 4, #fields - 1 do
        local symbol = fields[i]:gsub("%s+", "")
        -- Consecutive spaces yield empty fields; they are not symbols.
        if symbol ~= "" then
          symbols_set[symbol] = true
        end
      end
    end

    local all_symbols = {}
    for symbol in pairs(symbols_set) do
      if not ignore_symbols[symbol] then
        all_symbols[#all_symbols + 1] = symbol
      end
    end
    table.sort(all_symbols)
    return all_symbols
  end
  --- Read one log file and split each relevant line into data and message.
  -- A line is kept only if it contains "<basename of file>:". Everything
  -- after that marker is returned as the message. Everything before the
  -- whitespace-separated token that contains the marker is returned as the
  -- log-line prefix, so the prefix ends with the last data field. Lines
  -- without the marker are skipped.
  -- @tparam string file path of the log file; raises an error if it
  --   cannot be opened
  -- @treturn table array of log-line prefixes
  -- @treturn table array of messages, index-aligned with the first result
  function utility.read_log_file(file)
    local lines = {}
    local messages = {}
    local fd = assert(io.open(file, "r"))
    local fname = string.gsub(file, "(.*/)(.*)", "%2")
    local marker = fname .. ':'

    for line in fd:lines() do
      -- Plain (non-pattern) search: file names routinely contain pattern
      -- magic characters such as '.' and '-'.
      local start, stop = string.find(line, marker, 1, true)
      if start then
        -- Cut strictly before the marker. Then drop the rest of the path
        -- token that contains it (e.g. a directory such as "/tmp/") together
        -- with the whitespace separating it from the data fields.
        local prefix = string.sub(line, 1, start - 1):gsub("%s*%S*$", "")
        lines[#lines + 1] = prefix
        messages[#messages + 1] = string.sub(line, stop + 1)
      end
    end

    fd:close()
    return lines, messages
  end
  --- Read every log under the given path(s) and concatenate the results.
  -- A path ending in "/" is treated as a directory: the trailing slash is
  -- removed and every file matched by rspamd_util.glob("<dir>/*.log") is
  -- read. Any other path is read directly as a single log file.
  -- @tparam string|table dirs a single path or an array of paths
  -- @treturn table log-line prefixes from all files, in reading order
  -- @treturn table messages, index-aligned with the first result
  function utility.get_all_logs(dirs)
    local paths = dirs
    if type(paths) == 'string' then
      paths = { paths }
    end

    local all_logs, all_messages = {}, {}

    -- Append one file's parsed lines and messages to the accumulators.
    local function collect(path)
      local logs, messages = utility.read_log_file(path)
      for idx = 1, #logs do
        all_logs[#all_logs + 1] = logs[idx]
        all_messages[#all_messages + 1] = messages[idx]
      end
    end

    for _, path in ipairs(paths) do
      if path:sub(-1, -1) ~= "/" then
        collect(path)
      else
        local pattern = path:sub(1, -2) .. "/*.log"
        for _, file in pairs(rspamd_util.glob(pattern)) do
          collect(file)
        end
      end
    end

    return all_logs, all_messages
  end
  --- Map each configured symbol name to its score, skipping ignored ones.
  -- @param conf configuration object; the result of its
  --   `get_symbols_scores()` is iterated as name -> record, and each
  --   record's `score` field is read
  -- @tparam[opt] table ignore_symbols set of symbol names to skip
  --   (`{ NAME = true }`); defaults to an empty set
  -- @treturn table map of symbol name -> score
  function utility.get_all_symbol_scores(conf, ignore_symbols)
    ignore_symbols = ignore_symbols or {}
    local symbols = conf:get_symbols_scores()

    local kept = fun.filter(function(name, _)
      return not ignore_symbols[name]
    end, symbols)

    return fun.tomap(fun.map(function(name, elt)
      return name, elt['score']
    end, kept))
  end
  -- Return `part` as a percentage of `whole`, or 0 when `whole` is not
  -- positive, so empty categories report 0 instead of NaN or inf.
  local function safe_percent(part, whole)
    if whole > 0 then
      return part * 100 / whole
    end
    return 0
  end

  --- Compute corpus-wide and per-symbol statistics from parsed log lines.
  -- Each log line is trimmed and split on single spaces:
  --   * field 1 is the verdict (only "SPAM" counts as spam; anything else
  --     is ham);
  --   * field 2 is the score;
  --   * fields 4 .. (n - 1) are symbol names;
  --   * field n is read as the scan time;
  --   * field 3 is not used here.
  -- A message is flagged when its score is >= threshold.
  -- @tparam table logs log-line prefixes, e.g. as returned by read_log_file
  -- @tparam table messages messages index-aligned with `logs`
  -- @tparam number threshold score at or above which a message is flagged
  -- @treturn table file_stats: counts, percentages, error rates, fscore,
  --   avg_scan_time and the slowest message
  -- @treturn table per-symbol stats keyed by symbol name
  -- @treturn table false-positive messages (ham that was flagged)
  -- @treturn table false-negative messages (spam that was not flagged)
  -- @raise error if a line's score field is not a number
  function utility.generate_statistics_from_logs(logs, messages, threshold)
    local file_stats = {
      no_of_emails = 0,
      no_of_spam = 0,
      no_of_ham = 0,
      spam_percent = 0,
      ham_percent = 0,
      true_positives = 0,
      true_negatives = 0,
      false_negative_rate = 0,
      false_positive_rate = 0,
      overall_accuracy = 0,
      fscore = 0,
      avg_scan_time = 0,
      slowest_file = nil,
      slowest = 0
    }

    local all_symbols_stats = {}
    local all_fps = {}
    local all_fns = {}
    local false_positives, false_negatives = 0, 0
    local true_positives, true_negatives = 0, 0
    local no_of_emails, no_of_spam, no_of_ham = 0, 0, 0
    local total_scan_time, timed_messages = 0, 0

    for i, line in ipairs(logs) do
      local fields = lua_util.rspamd_str_split(lua_util.rspamd_str_trim(line), " ")
      local message = messages[i]
      local is_spam = (fields[1] == "SPAM")
      local score = tonumber(fields[2])
      if score == nil then
        error(string.format("cannot parse score %q in log line %d",
            tostring(fields[2]), i))
      end
      local flagged = (score >= threshold)

      -- Confusion-matrix bookkeeping.
      no_of_emails = no_of_emails + 1
      if is_spam then
        no_of_spam = no_of_spam + 1
        if flagged then
          true_positives = true_positives + 1
        else
          false_negatives = false_negatives + 1
          all_fns[#all_fns + 1] = message
        end
      else
        no_of_ham = no_of_ham + 1
        if flagged then
          false_positives = false_positives + 1
          all_fps[#all_fps + 1] = message
        else
          true_negatives = true_negatives + 1
        end
      end

      -- Per-symbol hit counts.
      for j = 4, #fields - 1 do
        local sym = fields[j]
        -- Consecutive spaces yield empty fields; they are not symbols.
        if sym ~= "" then
          local stats = all_symbols_stats[sym]
          if stats == nil then
            -- NOTE(review): `name` holds the first message the symbol was
            -- seen in, not the symbol name; kept as-is for existing callers.
            stats = {
              name = message,
              no_of_hits = 0,
              spam_hits = 0,
              ham_hits = 0,
              spam_overall = 0
            }
            all_symbols_stats[sym] = stats
          end
          stats.no_of_hits = stats.no_of_hits + 1
          if is_spam then
            stats.spam_hits = stats.spam_hits + 1
          else
            stats.ham_hits = stats.ham_hits + 1
          end
        end
      end

      -- Scan time is tracked once per message, outside the symbol loop, so
      -- messages without any symbols are considered too. It needs the three
      -- leading fields plus a trailing one.
      if #fields >= 4 then
        local scan_time = tonumber(fields[#fields])
        if scan_time then
          total_scan_time = total_scan_time + scan_time
          timed_messages = timed_messages + 1
          if scan_time > file_stats.slowest then
            file_stats.slowest = scan_time
            file_stats.slowest_file = message
          end
        end
      end
    end

    -- Corpus-wide stats; every ratio is guarded against empty categories.
    file_stats.no_of_ham = no_of_ham
    file_stats.no_of_spam = no_of_spam
    file_stats.no_of_emails = no_of_emails
    file_stats.true_positives = true_positives
    file_stats.true_negatives = true_negatives
    file_stats.spam_percent = safe_percent(no_of_spam, no_of_emails)
    file_stats.ham_percent = safe_percent(no_of_ham, no_of_emails)
    file_stats.overall_accuracy =
        safe_percent(true_positives + true_negatives, no_of_emails)
    file_stats.false_positive_rate = safe_percent(false_positives, no_of_ham)
    file_stats.false_negative_rate = safe_percent(false_negatives, no_of_spam)

    local fscore_denominator = 2 * true_positives + false_positives + false_negatives
    if fscore_denominator > 0 then
      file_stats.fscore = 2 * true_positives / fscore_denominator
    end
    if timed_messages > 0 then
      file_stats.avg_scan_time = total_scan_time / timed_messages
    end

    -- Per-symbol stats. Only existing entries are modified during pairs().
    for _, stats in pairs(all_symbols_stats) do
      stats.spam_percent = safe_percent(stats.spam_hits, no_of_spam)
      stats.ham_percent = safe_percent(stats.ham_hits, no_of_ham)
      stats.overall = safe_percent(stats.no_of_hits, no_of_emails)
      local combined = stats.spam_percent + stats.ham_percent
      if combined > 0 then
        stats.spam_overall = stats.spam_percent / combined
      end
    end

    return file_stats, all_symbols_stats, all_fps, all_fns
  end
  184. return utility