Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. --[[
  2. Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[
  14. local lua_util = require "lua_util"
  15. local ucl = require "ucl"
  16. local logger = require "rspamd_logger"
  17. local rspamd_util = require "rspamd_util"
  18. local argparse = require "argparse"
  19. local rescore_utility = require "rescore_utility"
  -- Parsed command-line options: assigned by handler(), read by get_threshold().
  local opts

  -- Symbols skipped when building datasets and symbol lists; handler()
  -- adds the values passed via --ignore-symbol to this set.
  local ignore_symbols = {
    DATE_IN_PAST = true,
    DATE_IN_FUTURE = true,
  }
  -- Command-line interface of `rspamadm rescore`.
  -- --decay and --optim are long-only: their former short aliases -d / -o
  -- duplicated those of --diff and --output, so each of those short flags
  -- named two different options and could only ever select one of them.
  local parser = argparse()
      :name "rspamadm rescore"
      :description "Estimate optimal symbol weights from log files"
      :help_description_margin(37)

  parser:option "-l --log"
      :description "Log file or files (from rescore)"
      :argname("<log>")
      :args "*"
  parser:option "-c --config"
      :description "Path to config file"
      :argname("<file>")
      :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
  parser:option "-o --output"
      :description "Output file"
      :argname("<file>")
      :default("new.scores")
  parser:flag "-d --diff"
      :description "Show differences in scores"
  parser:flag "-v --verbose"
      :description "Verbose output"
  parser:flag "-z --freq"
      :description "Display hit frequencies"
  parser:option "-i --iters"
      :description "Learn iterations"
      :argname("<n>")
      :convert(tonumber)
      :default(10)
  parser:option "-b --batch"
      :description "Batch size"
      :argname("<n>")
      :convert(tonumber)
      :default(100)
  parser:option "--decay"
      :description "Decay rate"
      :argname("<n>")
      :convert(tonumber)
      :default(0.001)
  parser:option "-m --momentum"
      :description "Learn momentum"
      :argname("<n>")
      :convert(tonumber)
      :default(0.1)
  parser:option "-t --threads"
      :description "Number of threads to use"
      :argname("<n>")
      :convert(tonumber)
      :default(1)
  parser:option "--optim"
      :description "Optimisation algorithm"
      :argname("<alg>")
      :convert {
        LBFGS = "LBFGS",
        ADAM = "ADAM",
        ADAGRAD = "ADAGRAD",
        SGD = "SGD",
        NAG = "NAG"
      }
      :default "ADAM"
  parser:option "--ignore-symbol"
      :description "Ignore symbol from logs"
      :argname("<sym>")
      :args "*"
  parser:option "--penalty-weight"
      :description "Add new penalty weight to test"
      :argname("<n>")
      :convert(tonumber)
      :args "*"
  parser:option "--learning-rate"
      :description "Add new learning rate to test"
      :argname("<n>")
      :convert(tonumber)
      :args "*"
  parser:option "--spam_action"
      :description "Spam action"
      :argname("<act>")
      :default("reject")
  parser:option "--learning_rate_decay"
      :description "Learn rate decay (for some algs)"
      :argname("<n>")
      :convert(tonumber)
      :default(0.0)
  parser:option "--weight_decay"
      :description "Weight decay (for some algs)"
      :argname("<n>")
      :convert(tonumber)
      :default(0.0)
  parser:option "--l1"
      :description "L1 regularization penalty"
      :argname("<n>")
      :convert(tonumber)
      :default(0.0)
  parser:option "--l2"
      :description "L2 regularization penalty"
      :argname("<n>")
      :convert(tonumber)
      :default(0.0)
  -- Turns rescore log lines into 0/1 feature vectors and 0/1 labels.
  -- Each line is split with lua_util.rspamd_str_split on " ". Field 1 is the
  -- verdict: "SPAM" gives label 1, anything else 0. Fields 4..n are symbol
  -- names; fields 2 and 3 are not used here.
  -- The input vector for a line has one entry per key of all_symbols: 1 if
  -- that symbol occurs on the line and is not in ignore_symbols, else 0.
  -- NOTE(review): spam_score is accepted but unused.
  -- Returns: inputs (list of vectors), outputs (list of labels).
  local function make_dataset_from_logs(logs, all_symbols, spam_score)
    local inputs, outputs = {}, {}

    for _, line in pairs(logs) do
      local fields = lua_util.rspamd_str_split(line, " ")
      outputs[#outputs + 1] = (fields[1] == "SPAM") and 1 or 0

      local present = {}
      for pos = 4, #fields do
        local sym = fields[pos]
        if not ignore_symbols[sym] then
          present[sym] = true
        end
      end

      local vec = {}
      for idx, sym in pairs(all_symbols) do
        vec[idx] = present[sym] and 1 or 0
      end
      inputs[#inputs + 1] = vec
    end

    return inputs, outputs
  end
  -- Placeholder for the weight-initialisation step.
  -- NOTE(review): the body is empty, so this returns no value, yet handler()
  -- assigns its result to linear_module.weight[1]. Presumably it should build
  -- a weight vector from original_symbol_scores, ordered like all_symbols —
  -- confirm before re-enabling this tool.
  local function init_weights(all_symbols, original_symbol_scores)
    return
  end
  -- Shuffles logs in place with the Fisher-Yates algorithm and applies the
  -- same permutation to messages, so logs[i] and messages[i] stay paired.
  -- Slot i is swapped with a position drawn from 1..i. The previous version
  -- drew from the whole array on every step, which does not give a uniform
  -- permutation. Uses math.random, so results depend on the current seed.
  local function shuffle(logs, messages)
    for i = #logs, 2, -1 do
      local j = math.random(i)
      logs[i], logs[j] = logs[j], logs[i]
      messages[i], messages[j] = messages[j], messages[i]
    end
  end
  -- Splits the parallel lists logs/messages in two, keeping their order.
  -- The first floor(#logs * split_percent / 100) entries go to the first
  -- part (split_percent defaults to 60); the rest go to the second part.
  -- Returns {first_logs, first_messages}, {rest_logs, rest_messages}.
  local function split_logs(logs, messages, split_percent)
    local pct = split_percent or 60
    local cut = math.floor(#logs * pct / 100)

    local head_logs, head_msgs = {}, {}
    local tail_logs, tail_msgs = {}, {}

    for i = 1, #logs do
      if i <= cut then
        head_logs[#head_logs + 1] = logs[i]
        head_msgs[#head_msgs + 1] = messages[i]
      else
        tail_logs[#tail_logs + 1] = logs[i]
        tail_msgs[#tail_msgs + 1] = messages[i]
      end
    end

    return { head_logs, head_msgs }, { tail_logs, tail_msgs }
  end
  -- Maps positional scores back to symbol names: for each (idx, symbol) in
  -- all_symbols the result holds result[symbol] = new_scores[idx].
  -- new_scores only has to be indexable by the keys of all_symbols.
  local function stitch_new_scores(all_symbols, new_scores)
    local by_symbol = {}
    for pos, name in pairs(all_symbols) do
      local score = new_scores[pos]
      by_symbol[name] = score
    end
    return by_symbol
  end
  -- Recomputes the score column of every log line from symbol_scores.
  -- Each line is split with lua_util.rspamd_str_split on " ". Fields 4..n are
  -- symbol names: all whitespace is stripped, and their scores are summed,
  -- with unknown symbols counting as 0. Field 2 is replaced by
  -- lua_util.round(sum, 2), and the fields are re-joined with single spaces.
  -- Rewrites logs in place and also returns it.
  local function update_logs(logs, symbol_scores)
    for idx, line in ipairs(logs) do
      local fields = lua_util.rspamd_str_split(line, " ")
      local total = 0
      for k = 4, #fields do
        local name = fields[k]:gsub("%s+", "")
        fields[k] = name
        total = total + (symbol_scores[name] or 0)
      end
      fields[2] = lua_util.round(total, 2)
      logs[idx] = table.concat(fields, " ")
    end
    return logs
  end
  -- Serialises new_symbol_scores with ucl.to_format(…, "ucl") and writes the
  -- text to file_path, replacing any previous content.
  -- The text is produced before the file is opened. A serialisation failure
  -- therefore neither truncates an existing file nor leaks an open handle.
  -- Raises (via assert) if serialisation yields nothing, or if the file
  -- cannot be opened or written. The handle is closed before a write error
  -- is raised.
  local function write_scores(new_symbol_scores, file_path)
    local new_scores_ucl = ucl.to_format(new_symbol_scores, "ucl")
    assert(new_scores_ucl, "cannot serialise scores to UCL")

    local file = assert(io.open(file_path, "w"))
    local ok, err = file:write(new_scores_ucl)
    file:close()
    assert(ok, err)
  end
  -- Logs an OLD_SCORE / NEW_SCORE row for every symbol in new_symbol_scores.
  -- A missing old score is shown as 0; the new score goes through
  -- lua_util.round(…, 2). A "Class changes" section follows, listing the
  -- symbols whose known old score and new score have opposite signs.
  local function print_score_diff(new_symbol_scores, original_symbol_scores)
    local row_fmt = "%-35s %-10s %-10s"

    local function emit_row(name, old_score, new_score)
      logger.message(string.format(row_fmt, name, old_score or 0,
          lua_util.round(new_score, 2)))
    end

    logger.message(string.format(row_fmt, "SYMBOL", "OLD_SCORE", "NEW_SCORE"))
    for sym, new_score in pairs(new_symbol_scores) do
      emit_row(sym, original_symbol_scores[sym], new_score)
    end

    logger.message("\nClass changes \n")
    for sym, new_score in pairs(new_symbol_scores) do
      local old_score = original_symbol_scores[sym]
      if old_score ~= nil then
        local flipped = (old_score > 0 and new_score < 0) or
            (old_score < 0 and new_score > 0)
        if flipped then
          emit_row(sym, old_score, new_score)
        end
      end
    end
  end
  -- Evaluates a candidate weight vector on (logs, messages).
  -- weights must support :clone(). The clone is mapped to symbol names with
  -- stitch_new_scores(), and the log scores are recomputed with
  -- update_logs(), which rewrites logs in place. Statistics are then taken
  -- at threshold via rescore_utility.generate_statistics_from_logs.
  -- Returns: the F-score, then the false-positive and false-negative
  -- collections reported by that function.
  local function calculate_fscore_from_weights(logs, messages,
                                               all_symbols,
                                               weights,
                                               threshold)
    local scores_by_symbol = stitch_new_scores(all_symbols, weights:clone())
    local rescored = update_logs(logs, scores_by_symbol)
    local stats, _, fps, fns =
        rescore_utility.generate_statistics_from_logs(rescored, messages, threshold)
    return stats.fscore, fps, fns
  end
  -- Logs the summary statistics for (logs, messages) at the given threshold,
  -- as produced by rescore_utility.generate_statistics_from_logs: F-score,
  -- FP/FN rates, overall accuracy and the slowest message.
  local function print_stats(logs, messages, threshold)
    local stats = rescore_utility.generate_statistics_from_logs(logs,
        messages, threshold)

    local report_format = "F-score: %.2f\n" ..
        "False positive rate: %.2f %%\n" ..
        "False negative rate: %.2f %%\n" ..
        "Overall accuracy: %.2f %%\n" ..
        "Slowest message: %.2f (%s)\n"

    logger.message("\nStatistics at threshold: " .. threshold)
    logger.message(string.format(report_format,
        stats.fscore,
        stats.false_positive_rate,
        stats.false_negative_rate,
        stats.overall_accuracy,
        stats.slowest,
        stats.slowest_file))
  end
  -- Placeholder for one training epoch.
  -- NOTE(review): the body is empty, so a call has no effect and returns no
  -- value. handler() nevertheless calls it opts.iters times for every
  -- (learning rate, penalty weight) pair. The actual training code is
  -- missing here.
  local function train(dataset, opt, model, criterion, epoch,
                       all_symbols, spam_threshold, initial_weights)
    return
  end
  -- Default hyper-parameter grids searched by handler(); it replaces them
  -- with the values of --learning-rate / --penalty-weight when those are set.
  local learning_rates = { 0.01 }
  local penalty_weights = { 0 }
  -- Returns two values: the score used as the spam threshold, and the score
  -- of the 'reject' action.
  -- If opts['spam-action'] is set, the threshold is that action's score, or
  -- 0 when the action is unknown. Otherwise it falls back to 'add header',
  -- then 'rewrite subject', then 'reject'.
  -- NOTE(review): the CLI option is declared as --spam_action, whose parsed
  -- key is presumably 'spam_action', not 'spam-action' — confirm. If so, the
  -- lookup never matches and the fallback chain is always used.
  local function get_threshold()
    local actions = rspamd_config:get_all_actions()
    local reject_score = actions['reject']
    local chosen = opts['spam-action']

    if not chosen then
      local fallback = actions['add header'] or actions['rewrite subject']
          or reject_score
      return fallback, reject_score
    end

    return (actions[chosen] or 0), reject_score
  end
  -- Returns the parsed value of a CLI option.
  -- NOTE(review): argparse typically derives an option's key from its long
  -- name with dashes replaced by underscores ('--ignore-symbol' ->
  -- 'ignore_symbol') — confirm for the bundled version. The hyphenated key
  -- the original code used is tried first, then the underscore form.
  local function get_cli_opt(parsed, name)
    local value = parsed[name]
    if value == nil then
      value = parsed[(name:gsub('-', '_'))]
    end
    return value
  end

  -- Calls fn for each value of an option that holds either a single value
  -- or a list of values (options declared with :args "*").
  local function for_each_opt_value(value, fn)
    if type(value) == 'table' then
      for _, v in ipairs(value) do
        fn(v)
      end
    else
      fn(value)
    end
  end

  -- Collects the numeric entries of a scalar-or-list option into a new list;
  -- entries rejected by tonumber() are skipped.
  local function collect_numbers(value)
    local res = {}
    for_each_opt_value(value, function(v)
      local n = tonumber(v)
      if n then
        res[#res + 1] = n
      end
    end)
    return res
  end

  -- Logs per-symbol hit statistics (sorted by S/O, then spam hits, then ham
  -- hits, all descending), then the overall statistics at threshold, then
  -- the configured symbols that never fired in the corpus.
  local function print_hit_frequencies(logs, messages, threshold,
                                       original_symbol_scores)
    local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs,
        messages, threshold)

    local t = {}
    for _, symbol_stats in pairs(all_symbols_stats) do
      table.insert(t, symbol_stats)
    end
    local function compare_symbols(a, b)
      if a.spam_overall ~= b.spam_overall then
        return b.spam_overall < a.spam_overall
      end
      if b.spam_hits ~= a.spam_hits then
        return b.spam_hits < a.spam_hits
      end
      return b.ham_hits < a.ham_hits
    end
    table.sort(t, compare_symbols)

    logger.message(string.format("%-40s %6s %6s %6s %6s %6s %6s %6s",
        "NAME", "HITS", "HAM", "HAM%", "SPAM", "SPAM%", "S/O", "OVER%"))
    -- ipairs (not pairs) so the rows come out in sorted order
    for _, symbol_stats in ipairs(t) do
      logger.message(string.format("%-40s %6d %6d %6.2f %6d %6.2f %6.2f %6.2f",
          symbol_stats.name,
          symbol_stats.no_of_hits,
          symbol_stats.ham_hits,
          lua_util.round(symbol_stats.ham_percent, 2),
          symbol_stats.spam_hits,
          lua_util.round(symbol_stats.spam_percent, 2),
          lua_util.round(symbol_stats.spam_overall, 2),
          lua_util.round(symbol_stats.overall, 2)))
    end

    -- Print file statistics
    print_stats(logs, messages, threshold)

    -- Work out how many configured symbols weren't seen in the corpus
    local symbols_no_hits = {}
    local total_symbols = 0
    for sym in pairs(original_symbol_scores) do
      total_symbols = total_symbols + 1
      if all_symbols_stats[sym] == nil then
        table.insert(symbols_no_hits, sym)
      end
    end

    if #symbols_no_hits > 0 then
      table.sort(symbols_no_hits)
      -- Percentage of rules with no hits
      local nhpct = lua_util.round((#symbols_no_hits / total_symbols) * 100, 2)
      logger.message(string.format(
          '\nFound %s (%-.2f%%) symbols out of %s with no hits in corpus:',
          #symbols_no_hits, nhpct, total_symbols))
      for _, symbol in ipairs(symbols_no_hits) do
        logger.messagex('%s', symbol)
      end
    end
  end

  -- Logs heading followed by every entry of a misclassification list;
  -- does nothing for a nil or empty list.
  local function print_misclassified(heading, entries)
    if entries and #entries > 0 then
      logger.message(heading)
      for _, entry in ipairs(entries) do
        logger.messagex('%s', entry)
      end
    end
  end

  -- Entry point of `rspamadm rescore`.
  -- Flow:
  -- 1. Parse args and load the rspamd configuration named by --config.
  --    Exits with status 1 if it cannot be loaded or processed.
  -- 2. Read the logs given with --log.
  -- 3. Either print hit frequencies (--freq) and return, or search the
  --    learning_rates x penalty_weights grid on a cross-validation split.
  -- 4. Write the best scores to --output and optionally print a diff (--diff).
  -- 5. Report test-set statistics before and after rescoring, then list the
  --    false positives and false negatives.
  -- NOTE(review): relies on the torch `nn` global and on init_weights() /
  -- train(), which have empty bodies in this file, so it cannot run as-is.
  local function handler(args)
    opts = parser:parse(args)
    if not opts['log'] then
      parser:error('no log specified')
    end

    local _r, err = rspamd_config:load_ucl(opts['config'])
    if not _r then
      logger.errx('cannot parse %s: %s', opts['config'], err)
      os.exit(1)
    end
    _r, err = rspamd_config:parse_rcl({'logging', 'worker'})
    if not _r then
      logger.errx('cannot process %s: %s', opts['config'], err)
      os.exit(1)
    end

    local threshold, reject_score = get_threshold()
    local logs, messages = rescore_utility.get_all_logs(opts['log'])

    local ignore_opt = get_cli_opt(opts, 'ignore-symbol')
    if ignore_opt then
      for_each_opt_value(ignore_opt, function(s)
        ignore_symbols[s] = true
      end)
    end

    local rates_opt = get_cli_opt(opts, 'learning-rate')
    if rates_opt then
      learning_rates = collect_numbers(rates_opt)
    end

    local weights_opt = get_cli_opt(opts, 'penalty-weight')
    if weights_opt then
      penalty_weights = collect_numbers(weights_opt)
    end

    local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
    local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
        ignore_symbols)

    -- Display hit frequencies and stop
    if opts['freq'] then
      print_hit_frequencies(logs, messages, threshold, original_symbol_scores)
      return
    end

    shuffle(logs, messages)
    local train_logs, validation_logs = split_logs(logs, messages, 70)
    local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
    -- NOTE(review): make_dataset_from_logs returns (inputs, outputs), but only
    -- inputs are kept, so train() never sees the labels. Confirm what train()
    -- expects before re-enabling this tool.
    local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)

    -- Start of perceptron training
    local input_size = #all_symbols
    local linear_module = nn.Linear(input_size, 1, false)
    local activation = nn.Sigmoid()
    local perceptron = nn.Sequential()
    perceptron:add(linear_module)
    perceptron:add(activation)
    local criterion = nn.MSECriterion()
    --criterion.sizeAverage = false

    local best_fscore = -math.huge
    local best_weights = linear_module.weight[1]:clone()
    local best_learning_rate
    local best_weight_decay
    local all_fps
    local all_fns

    for _, lr in ipairs(learning_rates) do
      for _, wd in ipairs(penalty_weights) do
        linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
        local initial_weights = linear_module.weight[1]:clone()
        opts.learning_rate = lr
        opts.weight_decay = wd
        for i = 1, tonumber(opts.iters) do
          train(dataset, opts, perceptron, criterion, i, all_symbols, threshold,
              initial_weights)
        end
        local fscore, fps, fns = calculate_fscore_from_weights(cv_logs[1],
            cv_logs[2],
            all_symbols,
            linear_module.weight[1],
            threshold)
        logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",
            fscore, lr, wd)
        if best_fscore < fscore then
          best_learning_rate = lr
          best_weight_decay = wd
          best_fscore = fscore
          best_weights = linear_module.weight[1]:clone()
          all_fps = fps
          all_fns = fns
        end
      end
    end
    -- End perceptron training

    local new_symbol_scores = stitch_new_scores(all_symbols, best_weights)
    if opts["output"] then
      write_scores(new_symbol_scores, opts["output"])
    end
    if opts["diff"] then
      print_score_diff(new_symbol_scores, original_symbol_scores)
    end

    -- Pre-rescore test stats
    logger.message("\n\nPre-rescore test stats\n")
    test_logs[1] = update_logs(test_logs[1], original_symbol_scores)
    print_stats(test_logs[1], test_logs[2], threshold)

    -- Post-rescore test stats
    test_logs[1] = update_logs(test_logs[1], new_symbol_scores)
    logger.message("\n\nPost-rescore test stats\n")
    print_stats(test_logs[1], test_logs[2], threshold)

    logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
        best_fscore, best_learning_rate, best_weight_decay)

    -- Show all FPs/FNs, useful for corpus checking and rule creation/modification
    print_misclassified("\nFalse-Positives:", all_fps)
    print_misclassified("\nFalse-Negatives:", all_fns)
  end
  -- Module descriptor: command name, help description taken from the
  -- argparse parser, and the entry point.
  return {
    name = 'rescore',
    description = parser._description,
    handler = handler,
  }
  479. --]]
  -- The rescore implementation above sits inside a block comment and is
  -- therefore disabled; the module itself just returns nil.
  return nil