123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528 |
- --[[
- Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru>
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ]]--
-
- local argparse = require "argparse"
- local lua_clickhouse = require "lua_clickhouse"
- local lua_util = require "lua_util"
- local rspamd_http = require "rspamd_http"
- local rspamd_upstream_list = require "rspamd_upstream_list"
- local rspamd_logger = require "rspamd_logger"
- local ucl = require "ucl"
-
- local E = {}
-
- -- Define command line options
- local parser = argparse()
- :name 'rspamadm clickhouse'
- :description 'Retrieve information from Clickhouse'
- :help_description_margin(30)
- :command_target('command')
- :require_command(true)
-
- parser:option '-c --config'
- :description 'Path to config file'
- :argname('config_file')
- :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
- parser:option '-d --database'
- :description 'Name of Clickhouse database to use'
- :argname('database')
- :default('default')
- parser:flag '--no-ssl-verify'
- :description 'Disable SSL verification'
- :argname('no_ssl_verify')
- parser:mutex(
- parser:option '-p --password'
- :description 'Password to use for Clickhouse'
- :argname('password'),
- parser:flag '-a --ask-password'
- :description 'Ask password from the terminal'
- :argname('ask_password')
- )
- parser:option '-s --server'
- :description 'Address[:port] to connect to Clickhouse with'
- :argname('server')
- parser:option '-u --user'
- :description 'Username to use for Clickhouse'
- :argname('user')
- parser:option '--use-gzip'
- :description 'Use Gzip with Clickhouse'
- :argname('use_gzip')
- :default(true)
- parser:flag '--use-https'
- :description 'Use HTTPS with Clickhouse'
- :argname('use_https')
-
- local neural_profile = parser:command 'neural_profile'
- :description 'Generate symbols profile using data from Clickhouse'
- neural_profile:option '-w --where'
- :description 'WHERE clause for Clickhouse query'
- :argname('where')
- neural_profile:flag '-j --json'
- :description 'Write output as JSON'
- :argname('json')
- neural_profile:option '--days'
- :description 'Number of days to collect stats for'
- :argname('days')
- :default('7')
- neural_profile:option '--limit -l'
- :description 'Maximum rows to fetch per day'
- :argname('limit')
- neural_profile:option '--settings-id'
- :description 'Settings ID to query'
- :argname('settings_id')
- :default('')
-
- local neural_train = parser:command 'neural_train'
- :description 'Train neural using data from Clickhouse'
- neural_train:option '--days'
- :description 'Number of days to query data for'
- :argname('days')
- :default('7')
- neural_train:option '--column-name-digest'
- :description 'Name of neural profile digest column in Clickhouse'
- :argname('column_name_digest')
- :default('NeuralDigest')
- neural_train:option '--column-name-vector'
- :description 'Name of neural training vector column in Clickhouse'
- :argname('column_name_vector')
- :default('NeuralMpack')
- neural_train:option '--limit -l'
- :description 'Maximum rows to fetch per day'
- :argname('limit')
- neural_train:option '--profile -p'
- :description 'Profile to use for training'
- :argname('profile')
- :default('default')
- neural_train:option '--rule -r'
- :description 'Rule to train'
- :argname('rule')
- :default('default')
- neural_train:option '--spam -s'
- :description 'WHERE clause to use for spam'
- :argname('spam')
- :default("Action == 'reject'")
- neural_train:option '--ham -h'
- :description 'WHERE clause to use for ham'
- :argname('ham')
- :default('Score < 0')
- neural_train:option '--url -u'
- :description 'URL to use for training'
- :argname('url')
- :default('http://127.0.0.1:11334/plugins/neural/learn')
-
- local http_params = {
- config = rspamd_config,
- ev_base = rspamadm_ev_base,
- session = rspamadm_session,
- resolver = rspamadm_dns_resolver,
- }
-
- local function load_config(config_file)
- local _r,err = rspamd_config:load_ucl(config_file)
-
- if not _r then
- rspamd_logger.errx('cannot load %s: %s', config_file, err)
- os.exit(1)
- end
-
- _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
- if not _r then
- rspamd_logger.errx('cannot process %s: %s', config_file, err)
- os.exit(1)
- end
-
- if not rspamd_config:init_modules() then
- rspamd_logger.errx('cannot init modules when parsing %s', config_file)
- os.exit(1)
- end
-
- rspamd_config:init_subsystem('symcache')
- end
-
- local function days_list(days)
- -- Create list of days to query starting with yesterday
- local query_days = {}
- local previous_date = os.time() - 86400
- local num_days = tonumber(days)
- for _ = 1, num_days do
- table.insert(query_days, os.date('%Y-%m-%d', previous_date))
- previous_date = previous_date - 86400
- end
- return query_days
- end
-
- local function get_excluded_symbols(known_symbols, correlations, seen_total)
- -- Walk results once to collect all symbols & count occurrences
-
- local remove = {}
- local known_symbols_list = {}
- local composites = rspamd_config:get_all_opt('composites')
- local all_symbols = rspamd_config:get_symbols()
- local skip_flags = {
- nostat = true,
- skip = true,
- idempotent = true,
- composite = true,
- }
- for k, v in pairs(known_symbols) do
- local lower_count, higher_count
- if v.seen_spam > v.seen_ham then
- lower_count = v.seen_ham
- higher_count = v.seen_spam
- else
- lower_count = v.seen_spam
- higher_count = v.seen_ham
- end
-
- if composites[k] then
- remove[k] = 'composite symbol'
- elseif lower_count / higher_count >= 0.95 then
- remove[k] = 'weak ham/spam correlation'
- elseif v.seen / seen_total >= 0.9 then
- remove[k] = 'omnipresent symbol'
- elseif not all_symbols[k] then
- remove[k] = 'nonexistent symbol'
- else
- for fl,_ in pairs(all_symbols[k].flags or {}) do
- if skip_flags[fl] then
- remove[k] = fl .. ' symbol'
- break
- end
- end
- end
- known_symbols_list[v.id] = {
- seen = v.seen,
- name = k,
- }
- end
-
- -- Walk correlation matrix and check total counts
- for sym_id, row in pairs(correlations) do
- for inner_sym_id, count in pairs(row) do
- local known = known_symbols_list[sym_id]
- local inner = known_symbols_list[inner_sym_id]
- if known and count == known.seen and not remove[inner.name] and not remove[known.name] then
- remove[known.name] = string.format("overlapped by %s",
- known_symbols_list[inner_sym_id].name)
- end
- end
- end
-
- return remove
- end
-
- local function handle_neural_profile(args)
-
- local known_symbols, correlations = {}, {}
- local symbols_count, seen_total = 0, 0
-
- local function process_row(r)
- local is_spam = true
- if r['Action'] == 'no action' or r['Action'] == 'greylist' then
- is_spam = false
- end
- seen_total = seen_total + 1
-
- local nsym = #r['Symbols.Names']
-
- for i = 1,nsym do
- local sym = r['Symbols.Names'][i]
- local t = known_symbols[sym]
- if not t then
- local spam_count, ham_count = 0, 0
- if is_spam then
- spam_count = spam_count + 1
- else
- ham_count = ham_count + 1
- end
- known_symbols[sym] = {
- id = symbols_count,
- seen = 1,
- seen_ham = ham_count,
- seen_spam = spam_count,
- }
- symbols_count = symbols_count + 1
- else
- known_symbols[sym].seen = known_symbols[sym].seen + 1
- if is_spam then
- known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1
- else
- known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1
- end
- end
- end
-
- -- Fill correlations
- for i = 1,nsym do
- for j = 1,nsym do
- if i ~= j then
- local sym = r['Symbols.Names'][i]
- local inner_sym_name = r['Symbols.Names'][j]
- local known_sym = known_symbols[sym]
- local inner_sym = known_symbols[inner_sym_name]
- if known_sym and inner_sym then
- if not correlations[known_sym.id] then
- correlations[known_sym.id] = {}
- end
- local n = correlations[known_sym.id][inner_sym.id] or 0
- n = n + 1
- correlations[known_sym.id][inner_sym.id] = n
- end
- end
- end
- end
- end
-
- local query_days = days_list(args.days)
- local conditions = {}
- table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
- local limit = ''
- local num_limit = tonumber(args.limit)
- if num_limit then
- limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
- end
- if args.where then
- table.insert(conditions, args.where)
- end
-
- local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s'
- for _, query_day in ipairs(query_days) do
- -- Date should be the last condition
- table.insert(conditions, string.format("Date = '%s'", query_day))
- local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit)
- local upstream = args.upstream:get_upstream_round_robin()
- local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
- if err ~= nil then
- io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
- os.exit(1)
- end
- conditions[#conditions] = nil -- remove Date condition
- end
-
- local remove = get_excluded_symbols(known_symbols, correlations, seen_total)
- if not args.json then
- for k in pairs(known_symbols) do
- if not remove[k] then
- io.stdout:write(string.format('%s\n', k))
- end
- end
- os.exit(0)
- end
-
- local json_output = {
- all_symbols = {},
- removed_symbols = {},
- used_symbols = {},
- }
- for k in pairs(known_symbols) do
- table.insert(json_output.all_symbols, k)
- local why_removed = remove[k]
- if why_removed then
- json_output.removed_symbols[k] = why_removed
- else
- table.insert(json_output.used_symbols, k)
- end
- end
- io.stdout:write(ucl.to_format(json_output, 'json'))
- end
-
- local function post_neural_training(url, rule, spam_rows, ham_rows)
- -- Prepare JSON payload
- local payload = ucl.to_format(
- {
- ham_vec = ham_rows,
- rule = rule,
- spam_vec = spam_rows,
- }, 'json')
-
- -- POST the payload
- local err, response = rspamd_http.request({
- body = payload,
- config = rspamd_config,
- ev_base = rspamadm_ev_base,
- log_obj = rspamd_config,
- resolver = rspamadm_dns_resolver,
- session = rspamadm_session,
- url = url,
- })
-
- if err then
- io.stderr:write(string.format('HTTP error: %s\n', err))
- os.exit(1)
- end
- if response.code ~= 200 then
- io.stderr:write(string.format('bad HTTP code: %d\n', response.code))
- os.exit(1)
- end
- io.stdout:write(string.format('%s\n', response.content))
- end
-
- local function handle_neural_train(args)
-
- local this_where -- which class of messages are we collecting data for
- local ham_rows, spam_rows = {}, {}
- local want_spam, want_ham = true, true -- keep collecting while true
-
- -- Try find profile in config
- local neural_opts = rspamd_config:get_all_opt('neural')
- local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
- if not symbols_profile then
- io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
- os.exit(1)
- end
- -- Try find max_trains
- local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000
-
- -- Callback used to process rows from Clickhouse
- local function process_row(r)
- local destination -- which table to collect this information in
- if this_where == args.ham then
- destination = ham_rows
- if #destination >= max_trains then
- want_ham = false
- return
- end
- else
- destination = spam_rows
- if #destination >= max_trains then
- want_spam = false
- return
- end
- end
- local ucl_parser = ucl.parser()
- local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
- if not ok then
- io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err))
- os.exit(1)
- end
- table.insert(destination, ucl_parser:get_object())
- end
-
- -- Generate symbols digest
- table.sort(symbols_profile)
- local symbols_digest = lua_util.table_digest(symbols_profile)
- -- Create list of days to query data for
- local query_days = days_list(args.days)
- -- Set value for limit
- local limit = ''
- local num_limit = tonumber(args.limit)
- if num_limit then
- limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
- end
- -- Prepare query elements
- local conditions = {string.format("%s = '%s'", args.column_name_digest, symbols_digest)}
- local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'
-
- -- Run queries
- for _, the_where in ipairs({args.ham, args.spam}) do
- -- Inform callback which group of vectors we're collecting
- this_where = the_where
- table.insert(conditions, the_where) -- should be 2nd from last condition
- -- Loop over days and try collect data
- for _, query_day in ipairs(query_days) do
- -- Break the loop if we have enough data already
- if this_where == args.ham then
- if not want_ham then
- break
- end
- else
- if not want_spam then
- break
- end
- end
- -- Date should be the last condition
- table.insert(conditions, string.format("Date = '%s'", query_day))
- local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
- local upstream = args.upstream:get_upstream_round_robin()
- local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
- if err ~= nil then
- io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
- os.exit(1)
- end
- conditions[#conditions] = nil -- remove Date condition
- end
- conditions[#conditions] = nil -- remove spam/ham condition
- end
-
- -- Make sure we collected enough data for training
- if #ham_rows < max_trains then
- io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
- os.exit(1)
- end
- if #spam_rows < max_trains then
- io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
- os.exit(1)
- end
-
- return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
- end
-
- local command_handlers = {
- neural_profile = handle_neural_profile,
- neural_train = handle_neural_train,
- }
-
- local function handler(args)
- local cmd_opts = parser:parse(args)
-
- load_config(cmd_opts.config_file)
- local cfg_opts = rspamd_config:get_all_opt('clickhouse')
-
- if cmd_opts.ask_password then
- local rspamd_util = require "rspamd_util"
-
- io.write('Password: ')
- cmd_opts.password = rspamd_util.readpassphrase()
- end
-
- local function override_settings(params)
- for _, which in ipairs(params) do
- if cmd_opts[which] == nil then
- cmd_opts[which] = cfg_opts[which]
- end
- end
- end
-
- override_settings({
- 'database', 'no_ssl_verify', 'password', 'server',
- 'use_gzip', 'use_https', 'user',
- })
-
- local servers = cmd_opts['server'] or cmd_opts['servers']
- if not servers then
- parser:error("server(s) unspecified & couldn't be fetched from config")
- end
-
- cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123)
-
- if not cmd_opts.upstream then
- io.stderr:write(string.format("can't parse clickhouse address: %s\n", servers))
- os.exit(1)
- end
-
- local f = command_handlers[cmd_opts.command]
- if not f then
- parser:error(string.format("command isn't implemented: %s",
- cmd_opts.command))
- end
- f(cmd_opts)
- end
-
- return {
- handler = handler,
- description = parser._description,
- name = 'clickhouse'
- }
|