You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

clickhouse.lua 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. --[[
  2. Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
local argparse = require "argparse"
local lua_clickhouse = require "lua_clickhouse"
local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_upstream_list = require "rspamd_upstream_list"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"

-- Shared empty table used as a safe fallback in chained indexing: ((t or E).k or E)
local E = {}

-- Define command line options
local parser = argparse()
  :name 'rspamadm clickhouse'
  :description 'Retrieve information from Clickhouse'
  :help_description_margin(30)
  :command_target('command')
  :require_command(true)

parser:option '-c --config'
  :description 'Path to config file'
  :argname('config_file')
  :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
parser:option '-d --database'
  :description 'Name of Clickhouse database to use'
  :argname('database')
  :default('default')
parser:flag '--no-ssl-verify'
  :description 'Disable SSL verification'
  :argname('no_ssl_verify')
-- Password can be supplied directly or prompted for interactively, not both
parser:mutex(
  parser:option '-p --password'
    :description 'Password to use for Clickhouse'
    :argname('password'),
  parser:flag '-a --ask-password'
    :description 'Ask password from the terminal'
    :argname('ask_password')
)
parser:option '-s --server'
  :description 'Address[:port] to connect to Clickhouse with'
  :argname('server')
parser:option '-u --user'
  :description 'Username to use for Clickhouse'
  :argname('user')
-- NOTE(review): declared as an option (expects a value) but defaulted to
-- boolean true; a flag may have been intended -- confirm before changing
parser:option '--use-gzip'
  :description 'Use Gzip with Clickhouse'
  :argname('use_gzip')
  :default(true)
parser:flag '--use-https'
  :description 'Use HTTPS with Clickhouse'
  :argname('use_https')

-- Sub-command: generate a symbols profile from Clickhouse statistics
local neural_profile = parser:command 'neural_profile'
  :description 'Generate symbols profile using data from Clickhouse'
neural_profile:option '-w --where'
  :description 'WHERE clause for Clickhouse query'
  :argname('where')
neural_profile:flag '-j --json'
  :description 'Write output as JSON'
  :argname('json')
neural_profile:option '--days'
  :description 'Number of days to collect stats for'
  :argname('days')
  :default('7')
neural_profile:option '--limit -l'
  :description 'Maximum rows to fetch per day'
  :argname('limit')
neural_profile:option '--settings-id'
  :description 'Settings ID to query'
  :argname('settings_id')
  :default('')

-- Sub-command: collect training vectors and POST them to the neural plugin
local neural_train = parser:command 'neural_train'
  :description 'Train neural using data from Clickhouse'
neural_train:option '--days'
  :description 'Number of days to query data for'
  :argname('days')
  :default('7')
neural_train:option '--column-name-digest'
  :description 'Name of neural profile digest column in Clickhouse'
  :argname('column_name_digest')
  :default('NeuralDigest')
neural_train:option '--column-name-vector'
  :description 'Name of neural training vector column in Clickhouse'
  :argname('column_name_vector')
  :default('NeuralMpack')
neural_train:option '--limit -l'
  :description 'Maximum rows to fetch per day'
  :argname('limit')
neural_train:option '--profile -p'
  :description 'Profile to use for training'
  :argname('profile')
  :default('default')
neural_train:option '--rule -r'
  :description 'Rule to train'
  :argname('rule')
  :default('default')
neural_train:option '--spam -s'
  :description 'WHERE clause to use for spam'
  :argname('spam')
  :default("Action == 'reject'")
neural_train:option '--ham -h'
  :description 'WHERE clause to use for ham'
  :argname('ham')
  :default('Score < 0')
neural_train:option '--url -u'
  :description 'URL to use for training'
  :argname('url')
  :default('http://127.0.0.1:11334/plugins/neural/learn')

-- Common parameters handed to HTTP/Clickhouse helpers; the rspamadm_*
-- globals are injected by the rspamadm host environment
local http_params = {
  config = rspamd_config,
  ev_base = rspamadm_ev_base,
  session = rspamadm_session,
  resolver = rspamadm_dns_resolver,
}
  122. local function load_config(config_file)
  123. local _r,err = rspamd_config:load_ucl(config_file)
  124. if not _r then
  125. rspamd_logger.errx('cannot load %s: %s', config_file, err)
  126. os.exit(1)
  127. end
  128. _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
  129. if not _r then
  130. rspamd_logger.errx('cannot process %s: %s', config_file, err)
  131. os.exit(1)
  132. end
  133. if not rspamd_config:init_modules() then
  134. rspamd_logger.errx('cannot init modules when parsing %s', config_file)
  135. os.exit(1)
  136. end
  137. rspamd_config:init_subsystem('symcache')
  138. end
  139. local function days_list(days)
  140. -- Create list of days to query starting with yesterday
  141. local query_days = {}
  142. local previous_date = os.time() - 86400
  143. local num_days = tonumber(days)
  144. for _ = 1, num_days do
  145. table.insert(query_days, os.date('%Y-%m-%d', previous_date))
  146. previous_date = previous_date - 86400
  147. end
  148. return query_days
  149. end
  150. local function get_excluded_symbols(known_symbols, correlations, seen_total)
  151. -- Walk results once to collect all symbols & count occurrences
  152. local remove = {}
  153. local known_symbols_list = {}
  154. local composites = rspamd_config:get_all_opt('composites')
  155. local all_symbols = rspamd_config:get_symbols()
  156. local skip_flags = {
  157. nostat = true,
  158. skip = true,
  159. idempotent = true,
  160. composite = true,
  161. }
  162. for k, v in pairs(known_symbols) do
  163. local lower_count, higher_count
  164. if v.seen_spam > v.seen_ham then
  165. lower_count = v.seen_ham
  166. higher_count = v.seen_spam
  167. else
  168. lower_count = v.seen_spam
  169. higher_count = v.seen_ham
  170. end
  171. if composites[k] then
  172. remove[k] = 'composite symbol'
  173. elseif lower_count / higher_count >= 0.95 then
  174. remove[k] = 'weak ham/spam correlation'
  175. elseif v.seen / seen_total >= 0.9 then
  176. remove[k] = 'omnipresent symbol'
  177. elseif not all_symbols[k] then
  178. remove[k] = 'nonexistent symbol'
  179. else
  180. for fl,_ in pairs(all_symbols[k].flags or {}) do
  181. if skip_flags[fl] then
  182. remove[k] = fl .. ' symbol'
  183. break
  184. end
  185. end
  186. end
  187. known_symbols_list[v.id] = {
  188. seen = v.seen,
  189. name = k,
  190. }
  191. end
  192. -- Walk correlation matrix and check total counts
  193. for sym_id, row in pairs(correlations) do
  194. for inner_sym_id, count in pairs(row) do
  195. local known = known_symbols_list[sym_id]
  196. local inner = known_symbols_list[inner_sym_id]
  197. if known and count == known.seen and not remove[inner.name] and not remove[known.name] then
  198. remove[known.name] = string.format("overlapped by %s",
  199. known_symbols_list[inner_sym_id].name)
  200. end
  201. end
  202. end
  203. return remove
  204. end
-- `neural_profile` sub-command: query Clickhouse for per-message symbol
-- lists, build occurrence/correlation statistics, and print the symbols
-- that survive get_excluded_symbols() (plain list, or JSON with -j).
local function handle_neural_profile(args)
  local known_symbols, correlations = {}, {}
  local symbols_count, seen_total = 0, 0

  -- Row callback: classify the row as ham/spam from its Action, then update
  -- per-symbol counters and the pairwise co-occurrence matrix (upvalues).
  local function process_row(r)
    local is_spam = true
    if r['Action'] == 'no action' or r['Action'] == 'greylist' then
      is_spam = false
    end
    seen_total = seen_total + 1

    local nsym = #r['Symbols.Names']
    for i = 1,nsym do
      local sym = r['Symbols.Names'][i]
      local t = known_symbols[sym]
      if not t then
        -- First time we see this symbol: assign it the next numeric id
        local spam_count, ham_count = 0, 0
        if is_spam then
          spam_count = spam_count + 1
        else
          ham_count = ham_count + 1
        end
        known_symbols[sym] = {
          id = symbols_count,
          seen = 1,
          seen_ham = ham_count,
          seen_spam = spam_count,
        }
        symbols_count = symbols_count + 1
      else
        known_symbols[sym].seen = known_symbols[sym].seen + 1
        if is_spam then
          known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1
        else
          known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1
        end
      end
    end

    -- Fill correlations: count how often each ordered pair of symbols
    -- fired together in the same message
    for i = 1,nsym do
      for j = 1,nsym do
        if i ~= j then
          local sym = r['Symbols.Names'][i]
          local inner_sym_name = r['Symbols.Names'][j]
          local known_sym = known_symbols[sym]
          local inner_sym = known_symbols[inner_sym_name]
          if known_sym and inner_sym then
            if not correlations[known_sym.id] then
              correlations[known_sym.id] = {}
            end
            local n = correlations[known_sym.id][inner_sym.id] or 0
            n = n + 1
            correlations[known_sym.id][inner_sym.id] = n
          end
        end
      end
    end
  end

  local query_days = days_list(args.days)
  local conditions = {}
  -- NOTE(review): settings_id/where are interpolated into SQL unescaped;
  -- acceptable for an operator-run CLI tool, but not for untrusted input
  table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
  local limit = ''
  local num_limit = tonumber(args.limit)
  if num_limit then
    limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
  end
  if args.where then
    table.insert(conditions, args.where)
  end

  local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s'
  for _, query_day in ipairs(query_days) do
    -- Date should be the last condition (pushed here, popped after the query)
    table.insert(conditions, string.format("Date = '%s'", query_day))
    local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit)
    local upstream = args.upstream:get_upstream_round_robin()
    local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
    if err ~= nil then
      io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
      os.exit(1)
    end
    conditions[#conditions] = nil -- remove Date condition
  end

  local remove = get_excluded_symbols(known_symbols, correlations, seen_total)

  -- Plain-text output: one surviving symbol per line, then exit
  if not args.json then
    for k in pairs(known_symbols) do
      if not remove[k] then
        io.stdout:write(string.format('%s\n', k))
      end
    end
    os.exit(0)
  end

  -- JSON output: report kept and removed symbols (with removal reasons)
  local json_output = {
    all_symbols = {},
    removed_symbols = {},
    used_symbols = {},
  }
  for k in pairs(known_symbols) do
    table.insert(json_output.all_symbols, k)
    local why_removed = remove[k]
    if why_removed then
      json_output.removed_symbols[k] = why_removed
    else
      table.insert(json_output.used_symbols, k)
    end
  end
  io.stdout:write(ucl.to_format(json_output, 'json'))
end
  310. local function post_neural_training(url, rule, spam_rows, ham_rows)
  311. -- Prepare JSON payload
  312. local payload = ucl.to_format(
  313. {
  314. ham_vec = ham_rows,
  315. rule = rule,
  316. spam_vec = spam_rows,
  317. }, 'json')
  318. -- POST the payload
  319. local err, response = rspamd_http.request({
  320. body = payload,
  321. config = rspamd_config,
  322. ev_base = rspamadm_ev_base,
  323. log_obj = rspamd_config,
  324. resolver = rspamadm_dns_resolver,
  325. session = rspamadm_session,
  326. url = url,
  327. })
  328. if err then
  329. io.stderr:write(string.format('HTTP error: %s\n', err))
  330. os.exit(1)
  331. end
  332. if response.code ~= 200 then
  333. io.stderr:write(string.format('bad HTTP code: %d\n', response.code))
  334. os.exit(1)
  335. end
  336. io.stdout:write(string.format('%s\n', response.content))
  337. end
-- `neural_train` sub-command: fetch msgpack-encoded training vectors for a
-- configured neural profile from Clickhouse (separately for ham and spam),
-- then submit them to the neural plugin's learn endpoint.
local function handle_neural_train(args)
  local this_where -- which class of messages are we collecting data for
  local ham_rows, spam_rows = {}, {}
  local want_spam, want_ham = true, true -- keep collecting while true

  -- Try find profile in config
  local neural_opts = rspamd_config:get_all_opt('neural')
  local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
  if not symbols_profile then
    io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
    os.exit(1)
  end

  -- Try find max_trains; safe to index rules[args.rule] directly here,
  -- since the profile lookup above already proved the rule exists
  local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000

  -- Callback used to process rows from Clickhouse; mutates the
  -- want_ham/want_spam upvalues to stop collection once a class is full
  local function process_row(r)
    local destination -- which table to collect this information in
    if this_where == args.ham then
      destination = ham_rows
      if #destination >= max_trains then
        want_ham = false
        return
      end
    else
      destination = spam_rows
      if #destination >= max_trains then
        want_spam = false
        return
      end
    end
    -- Vector column holds a msgpack-encoded array; decode via ucl
    local ucl_parser = ucl.parser()
    local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
    if not ok then
      io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err))
      os.exit(1)
    end
    table.insert(destination, ucl_parser:get_object())
  end

  -- Generate symbols digest (sorted so the digest is order-independent)
  table.sort(symbols_profile)
  local symbols_digest = lua_util.table_digest(symbols_profile)

  -- Create list of days to query data for
  local query_days = days_list(args.days)

  -- Set value for limit
  local limit = ''
  local num_limit = tonumber(args.limit)
  if num_limit then
    limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
  end

  -- Prepare query elements; the digest condition stays at the bottom of the
  -- stack, class and date conditions are pushed/popped around each query
  local conditions = {string.format("%s = '%s'", args.column_name_digest, symbols_digest)}
  local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'

  -- Run queries: once with the ham WHERE clause, once with the spam one
  for _, the_where in ipairs({args.ham, args.spam}) do
    -- Inform callback which group of vectors we're collecting
    this_where = the_where
    table.insert(conditions, the_where) -- should be 2nd from last condition
    -- Loop over days and try collect data
    for _, query_day in ipairs(query_days) do
      -- Break the loop if we have enough data already
      if this_where == args.ham then
        if not want_ham then
          break
        end
      else
        if not want_spam then
          break
        end
      end
      -- Date should be the last condition
      table.insert(conditions, string.format("Date = '%s'", query_day))
      local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
      local upstream = args.upstream:get_upstream_round_robin()
      local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
      if err ~= nil then
        io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
        os.exit(1)
      end
      conditions[#conditions] = nil -- remove Date condition
    end
    conditions[#conditions] = nil -- remove spam/ham condition
  end

  -- Make sure we collected enough data for training
  if #ham_rows < max_trains then
    io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
    os.exit(1)
  end
  if #spam_rows < max_trains then
    io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
    os.exit(1)
  end

  return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
end
  430. local command_handlers = {
  431. neural_profile = handle_neural_profile,
  432. neural_train = handle_neural_train,
  433. }
  434. local function handler(args)
  435. local cmd_opts = parser:parse(args)
  436. load_config(cmd_opts.config_file)
  437. local cfg_opts = rspamd_config:get_all_opt('clickhouse')
  438. if cmd_opts.ask_password then
  439. local rspamd_util = require "rspamd_util"
  440. io.write('Password: ')
  441. cmd_opts.password = rspamd_util.readpassphrase()
  442. end
  443. local function override_settings(params)
  444. for _, which in ipairs(params) do
  445. if cmd_opts[which] == nil then
  446. cmd_opts[which] = cfg_opts[which]
  447. end
  448. end
  449. end
  450. override_settings({
  451. 'database', 'no_ssl_verify', 'password', 'server',
  452. 'use_gzip', 'use_https', 'user',
  453. })
  454. local servers = cmd_opts['server'] or cmd_opts['servers']
  455. if not servers then
  456. parser:error("server(s) unspecified & couldn't be fetched from config")
  457. end
  458. cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123)
  459. if not cmd_opts.upstream then
  460. io.stderr:write(string.format("can't parse clickhouse address: %s\n", servers))
  461. os.exit(1)
  462. end
  463. local f = command_handlers[cmd_opts.command]
  464. if not f then
  465. parser:error(string.format("command isn't implemented: %s",
  466. cmd_opts.command))
  467. end
  468. f(cmd_opts)
  469. end
  470. return {
  471. handler = handler,
  472. description = parser._description,
  473. name = 'clickhouse'
  474. }