]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] rspamadm clickhouse neural_train subcommand 3608/head
authorAndrew Lewis <nerf@judo.za.org>
Fri, 15 Jan 2021 09:46:40 +0000 (11:46 +0200)
committerAndrew Lewis <nerf@judo.za.org>
Fri, 15 Jan 2021 09:46:40 +0000 (11:46 +0200)
lualib/rspamadm/clickhouse.lua

index 2ca4eab18468bb18aa03c847c279f5821cbefd6f..711437c9449f928d373e112cbba30af9c141eb43 100644 (file)
@@ -16,9 +16,13 @@ limitations under the License.
 
 local argparse = require "argparse"
 local lua_clickhouse = require "lua_clickhouse"
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
 local rspamd_upstream_list = require "rspamd_upstream_list"
 local ucl = require "ucl"
 
+local E = {}
+
 -- Define command line options
 local parser = argparse()
     :name 'rspamadm clickhouse'
@@ -80,6 +84,44 @@ neural_profile:option '--settings-id'
       :argname('settings_id')
       :default('')
 
+local neural_train = parser:command 'neural_train'
+      :description 'Train neural using data from Clickhouse'
+neural_train:option '--days'
+      :description 'Number of days to query data for'
+      :argname('days')
+      :default('7')
+neural_train:option '--column-name-digest'
+      :description 'Name of neural profile digest column in Clickhouse'
+      :argname('column_name_digest')
+      :default('NeuralDigest')
+neural_train:option '--column-name-vector'
+      :description 'Name of neural training vector column in Clickhouse'
+      :argname('column_name_vector')
+      :default('NeuralMpack')
+neural_train:option '--limit -l'
+      :description 'Maximum rows to fetch per day'
+      :argname('limit')
+neural_train:option '--profile -p'
+      :description 'Profile to use for training'
+      :argname('profile')
+      :default('default')
+neural_train:option '--rule -r'
+      :description 'Rule to train'
+      :argname('rule')
+      :default('default')
+neural_train:option '--spam -s'
+      :description 'WHERE clause to use for spam'
+      :argname('spam')
+      :default("Action == 'reject'")
+neural_train:option '--ham -h'
+      :description 'WHERE clause to use for ham'
+      :argname('ham')
+      :default('Score < 0')
+neural_train:option '--url -u'
+      :description 'URL to use for training'
+      :argname('url')
+      :default('http://127.0.0.1:11334/plugins/neural/learn')
+
 local http_params = {
   config = rspamd_config,
   ev_base = rspamadm_ev_base,
@@ -97,6 +139,18 @@ local function load_config(config_file)
   end
 end
 
+local function days_list(days)
+  -- Create list of days to query starting with yesterday
+  local query_days = {}
+  local previous_date = os.time() - 86400
+  local num_days = tonumber(days)
+  for _ = 1, num_days do
+    table.insert(query_days, os.date('%Y-%m-%d', previous_date))
+    previous_date = previous_date - 86400
+  end
+  return query_days
+end
+
 local function get_excluded_symbols(known_symbols, correlations, seen_total)
   -- Walk results once to collect all symbols & count ocurrences
 
@@ -202,15 +256,7 @@ local function handle_neural_profile(args)
     end
   end
 
-  -- Create list of days to query starting with yesterday
-  local query_days = {}
-  local previous_date = os.time() - 86400
-  local num_days = tonumber(args.days)
-  for _ = 1, num_days do
-    table.insert(query_days, os.date('%Y-%m-%d', previous_date))
-    previous_date = previous_date - 86400
-  end
-
+  local query_days = days_list(args.days)
   local conditions = {}
   table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
   local limit = ''
@@ -263,8 +309,139 @@ local function handle_neural_profile(args)
   io.stdout:write(ucl.to_format(json_output, 'json'))
 end
 
+local function post_neural_training(url, rule, spam_rows, ham_rows)
+  -- Prepare JSON payload
+  local payload = ucl.to_format(
+    {
+      ham_vec = ham_rows,
+      rule = rule,
+      spam_vec = spam_rows,
+    }, 'json')
+
+  -- POST the payload
+  local err, response = rspamd_http.request({
+    body = payload,
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    log_obj = rspamd_config,
+    resolver = rspamadm_dns_resolver,
+    session = rspamadm_session,
+    url = url,
+  })
+
+  if err then
+    io.stderr:write(string.format('HTTP error: %s\n', err))
+    os.exit(1)
+  end
+  if response.code ~= 200 then
+    io.stderr:write(string.format('bad HTTP code: %d\n', response.code))
+    os.exit(1)
+  end
+  io.stdout:write(string.format('%s\n', response.content))
+end
+
+local function handle_neural_train(args)
+
+  local this_where -- which class of messages are we collecting data for
+  local ham_rows, spam_rows = {}, {}
+  local want_spam, want_ham = true, true -- keep collecting while true
+  local ucl_parser = ucl.parser()
+
+  -- Try find profile in config
+  local neural_opts = rspamd_config:get_all_opt('neural')
+  local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
+  if not symbols_profile then
+    io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
+    os.exit(1)
+  end
+  -- Try find max_trains
+  local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000
+
+  -- Callback used to process rows from Clickhouse
+  local function process_row(r)
+    local destination -- which table to collect this information in
+    if this_where == args.ham then
+      destination = ham_rows
+      if #destination >= max_trains then
+        want_ham = false
+        return
+      end
+    else
+      destination = spam_rows
+      if #destination >= max_trains then
+        want_spam = false
+        return
+      end
+    end
+    local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
+    if not ok then
+      io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err))
+      os.exit(1)
+    end
+    table.insert(destination, ucl_parser:get_object())
+  end
+
+  -- Generate symbols digest
+  local symbols_digest = lua_util.table_digest(symbols_profile)
+  -- Create list of days to query data for
+  local query_days = days_list(args.days)
+  -- Set value for limit
+  local limit = ''
+  local num_limit = tonumber(args.limit)
+  if num_limit then
+    limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
+  end
+  -- Prepare query elements
+  local conditions = {string.format("%s = '%s'", args.column_name_digest, symbols_digest)}
+  local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'
+
+  -- Run queries
+  for _, the_where in ipairs({args.ham, args.spam}) do
+    -- Inform callback which group of vectors we're collecting
+    this_where = the_where
+    table.insert(conditions, the_where) -- should be 2nd from last condition
+    -- Loop over days and try collect data
+    for _, query_day in ipairs(query_days) do
+      -- Break the loop if we have enough data already
+      if this_where == args.ham then
+        if not want_ham then
+          break
+       end
+      else
+        if not want_spam then
+          break
+        end
+      end
+      -- Date should be the last condition
+      table.insert(conditions, string.format("Date = '%s'", query_day))
+      local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
+      local upstream = args.upstream:get_upstream_round_robin()
+      local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
+      if err ~= nil then
+        io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
+        os.exit(1)
+      end
+      conditions[#conditions] = nil -- remove Date condition
+    end
+    conditions[#conditions] = nil -- remove spam/ham condition
+  end
+
+  -- Make sure we collected enough data for training
+  if #ham_rows < max_trains then
+    io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
+    os.exit(1)
+  end
+  if #spam_rows < max_trains then
+    io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
+    os.exit(1)
+  end
+
+  return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
+end
+
 local command_handlers = {
   neural_profile = handle_neural_profile,
+  neural_train = handle_neural_train,
 }
 
 local function handler(args)