]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Rework rescore tool to the new architecture
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 30 May 2018 12:39:31 +0000 (13:39 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 30 May 2018 12:39:31 +0000 (13:39 +0100)
lualib/rescore_utility.lua
lualib/rspamadm/corpus_test.lua
lualib/rspamadm/grep.lua
lualib/rspamadm/rescore.lua
src/rspamadm/CMakeLists.txt
src/rspamadm/commands.c
src/rspamadm/confighelp.c
src/rspamadm/rescore.c [deleted file]

index 2a9372d4e3f12f010f1d345c874b7bafc2c1de3b..268e814d858d03a109efd863fad9ee7eae0a4f6d 100644 (file)
@@ -47,20 +47,30 @@ function utility.read_log_file(file)
   return lines
 end
 
-function utility.get_all_logs(dir_path)
+function utility.get_all_logs(dirs)
   -- Reads all log files in the directory and returns a list of logs.
 
-  if dir_path:sub(#dir_path, #dir_path) == "/" then
-    dir_path = dir_path:sub(1, #dir_path -1)
+  if type(dirs) == 'string' then
+    dirs = {dirs}
   end
 
-  local files = rspamd_util.glob(dir_path .. "/*.log")
   local all_logs = {}
 
-  for _, file in pairs(files) do
-    local logs = utility.read_log_file(file)
-    for _, log_line in pairs(logs) do
-      all_logs[#all_logs + 1] = log_line
+  for _,dir in ipairs(dirs) do
+    if dir:sub(-1, -1) == "/" then
+      dir = dir:sub(1, -2)
+      local files = rspamd_util.glob(dir .. "/*.log")
+      for _, file in pairs(files) do
+        local logs = utility.read_log_file(file)
+        for _, log_line in pairs(logs) do
+          table.insert(all_logs, log_line)
+        end
+      end
+    else
+      local logs = utility.read_log_file(dir)
+      for _, log_line in pairs(logs) do
+        table.insert(all_logs, log_line)
+      end
     end
   end
 
@@ -160,8 +170,8 @@ function utility.generate_statistics_from_logs(logs, threshold)
       end
 
       -- Find slowest message
-      if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then
-          file_stats.slowest = tostring(tonumber(log[#log-1]))
+      if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then
+          file_stats.slowest = tonumber(log[#log-1])
           file_stats.slowest_file = log[#log]
       end
     end
index 60aa2d7a9ceb243e927c8f0aae923bd6176ca757..07051a1960313f2f4d84e8a77e7164d64cf6bcf6 100644 (file)
@@ -17,6 +17,7 @@ parser:option "-S --spam"
 parser:option "-n --conns"
       :description("Number of parallel connections")
       :argname("<N>")
+      :convert(tonumber)
       :default(10)
 parser:option "-o --output"
       :description("Output file")
@@ -25,6 +26,7 @@ parser:option "-o --output"
 parser:option "-t --timeout"
       :description("Timeout for client connections")
       :argname("<sec>")
+      :convert(tonumber)
       :default(60)
 parser:option "-c --connect"
       :description("Connect to specific host")
index 0af83c1cfc2ed2e5375efafcfab5ae6ba4c01ff5..b149a0337ec647280bdb98d4d9dfb01dbb1956d1 100644 (file)
@@ -38,7 +38,6 @@ parser:argument "input":args "*"
       :default("stdin")
 parser:flag "-S --sensitive"
       :description('Enable case-sensitivity in string search')
-      :default("false")
 parser:flag "-o --orphans"
       :description('Print orphaned logs')
 parser:flag "-P --partial"
index e6df3b36408044e86ff2d9693053ab3123a3fa3b..80b9630f457b0e8344d16a72559e9c2d708de786 100644 (file)
@@ -1,12 +1,31 @@
+--[[
+Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if not rspamd_config:has_torch() then
+  return
+end
+
 local torch = require "torch"
 local nn = require "nn"
 local lua_util = require "lua_util"
 local ucl = require "ucl"
 local logger = require "rspamd_logger"
-local getopt = require "getopt"
 local optim = require "optim"
 local rspamd_util = require "rspamd_util"
-
+local argparse = require "argparse"
 local rescore_utility = require "rescore_utility"
 
 local opts
@@ -15,6 +34,104 @@ local ignore_symbols = {
   ['DATE_IN_FUTURE'] = true,
 }
 
+local parser = argparse()
+    :name "rspamadm rescore"
+    :description "Estimate optimal symbol weights from log files"
+    :help_description_margin(37)
+
+parser:option "-l --log"
+      :description "Log file or files (from rescore)"
+      :argname("<log>")
+      :args "*"
+parser:option "-c --config"
+      :description "Path to config file"
+      :argname("<file>")
+      :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-o --output"
+      :description "Output file"
+      :argname("<file>")
+      :default("new.scores")
+parser:flag "-d --diff"
+      :description "Show differences in scores"
+parser:flag "-v --verbose"
+      :description "Verbose output"
+parser:flag "-z --freq"
+      :description "Display hit frequencies"
+parser:option "-i --iters"
+      :description "Learn iterations"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(10)
+parser:option "-b --batch"
+      :description "Batch size"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(100)
+parser:option "-d --decay"
+      :description "Decay rate"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.001)
+parser:option "-m --momentum"
+      :description "Learn momentum"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.1)
+parser:option "-t --threads"
+      :description "Number of threads to use"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(1)
+parser:option "-o --optim"
+      :description "Optimisation algorithm"
+      :argname("<alg>")
+      :convert {
+        LBFGS = "LBFGS",
+        ADAM = "ADAM",
+        ADAGRAD = "ADAGRAD",
+        SGD = "SGD",
+        NAG = "NAG"
+      }
+      :default "ADAM"
+parser:option "--ignore-symbol"
+      :description "Ignore symbol from logs"
+      :argname("<sym>")
+      :args "*"
+parser:option "--penalty-weight"
+      :description "Add new penalty weight to test"
+      :argname("<n>")
+      :convert(tonumber)
+      :args "*"
+parser:option "--learning-rate"
+      :description "Add new learning rate to test"
+      :argname("<n>")
+      :convert(tonumber)
+      :args "*"
+parser:option "--spam_action"
+      :description "Spam action"
+      :argname("<act>")
+      :default("reject")
+parser:option "--learning_rate_decay"
+      :description "Learn rate decay (for some algs)"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.0)
+parser:option "--weight_decay"
+      :description "Weight decay (for some algs)"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.0)
+parser:option "--l1"
+      :description "L1 regularization penalty"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.0)
+parser:option "--l2"
+      :description "L2 regularization penalty"
+      :argname("<n>")
+      :convert(tonumber)
+      :default(0.0)
+
 local function make_dataset_from_logs(logs, all_symbols, spam_score)
   -- Returns a list of {input, output} for torch SGD train
 
@@ -182,7 +299,8 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
 
   logs = update_logs(logs, new_symbol_scores)
 
-  local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold)
+  local file_stats, _, all_fps, all_fns =
+      rescore_utility.generate_statistics_from_logs(logs, threshold)
 
   return file_stats.fscore, all_fps, all_fns
 end
@@ -226,7 +344,7 @@ local function train(dataset, opt, model, criterion, epoch,
   local lbfgsState
   local sgdState
 
-  local batch_size = opt.batch_size
+  local batch_size = opt.batch
 
   logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
 
@@ -300,7 +418,7 @@ local function train(dataset, opt, model, criterion, epoch,
     end
 
     -- optimize on current mini-batch
-    if opt.optimization == 'LBFGS' then
+    if opt.optim == 'LBFGS' then
 
       -- Perform LBFGS step:
       lbfgsState = lbfgsState or {
@@ -315,7 +433,7 @@ local function train(dataset, opt, model, criterion, epoch,
       logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
       logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
 
-    elseif opt.optimization == 'ADAM' then
+    elseif opt.optim == 'ADAM' then
       sgdState = sgdState or {
         learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
         momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -323,7 +441,7 @@ local function train(dataset, opt, model, criterion, epoch,
         weightDecay = tonumber(opts.weight_decay),
       }
       optim.adam(feval, parameters, sgdState)
-    elseif opt.optimization == 'ADAGRAD' then
+    elseif opt.optim == 'ADAGRAD' then
       sgdState = sgdState or {
         learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
         momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -331,7 +449,7 @@ local function train(dataset, opt, model, criterion, epoch,
         weightDecay = tonumber(opts.weight_decay),
       }
       optim.adagrad(feval, parameters, sgdState)
-    elseif opt.optimization == 'SGD' then
+    elseif opt.optim == 'SGD' then
       sgdState = sgdState or {
         learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
         momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -339,7 +457,7 @@ local function train(dataset, opt, model, criterion, epoch,
         weightDecay = tonumber(opts.weight_decay),
       }
       optim.sgd(feval, parameters, sgdState)
-    elseif opt.optimization == 'NAG' then
+    elseif opt.optim == 'NAG' then
       sgdState = sgdState or {
         learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
         momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -348,7 +466,8 @@ local function train(dataset, opt, model, criterion, epoch,
       }
       optim.nag(feval, parameters, sgdState)
     else
-      error('unknown optimization method')
+      logger.errx('unknown optimization method: %s', opt.optim)
+      os.exit(1)
     end
   end
 
@@ -363,19 +482,6 @@ local function train(dataset, opt, model, criterion, epoch,
   confusion:zero()
 end
 
-
-local default_opts = {
-  verbose = true,
-  iters = 10,
-  threads = 1,
-  batch_size = 1000,
-  optimization = 'ADAM',
-  learning_rate_decay = 0.001,
-  momentum = 0.1,
-  l1 = 0.0,
-  l2 = 0.0,
-}
-
 local learning_rates = {
   0.01
 }
@@ -393,11 +499,27 @@ local function get_threshold()
       or actions['reject']), actions['reject']
 end
 
-return function (args, cfg)
-  opts = default_opts
-  opts = lua_util.override_defaults(opts, getopt.getopt(args, 'i:'))
+local function handler(args)
+  opts = parser:parse(args)
+  if not opts['log'] then
+    parser:error('no log specified')
+  end
+
+  local _r,err = rspamd_config:load_ucl(opts['config'])
+
+  if not _r then
+    logger.errx('cannot parse %s: %s', opts['config'], err)
+    os.exit(1)
+  end
+
+  _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
+  if not _r then
+    logger.errx('cannot process %s: %s', opts['config'], err)
+    os.exit(1)
+  end
+
   local threshold,reject_score = get_threshold()
-  local logs = rescore_utility.get_all_logs(cfg["logdir"])
+  local logs = rescore_utility.get_all_logs(opts['log'])
 
   if opts['ignore-symbol'] then
     local function add_ignore(s)
@@ -446,14 +568,12 @@ return function (args, cfg)
     end
   end
 
-  if opts['i'] then opts['iters'] = opts['i'] end
-
   local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
   local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
       ignore_symbols)
 
   -- Display hit frequencies
-  if opts['z'] then
+  if opts['freq'] then
       local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
       local t = {}
       for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
@@ -580,11 +700,11 @@ return function (args, cfg)
 
   new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
 
-  if cfg["output"] then
-    write_scores(new_symbol_scores, cfg["output"])
+  if opts["output"] then
+    write_scores(new_symbol_scores, opts["output"])
   end
 
-  if cfg["diff"] then
+  if opts["diff"] then
     print_score_diff(new_symbol_scores, original_symbol_scores)
   end
 
@@ -616,3 +736,10 @@ return function (args, cfg)
       end
   end
 end
+
+
+return {
+  handler = handler,
+  description = parser._description,
+  name = 'rescore'
+}
\ No newline at end of file
index 616c40322df7a11abc0c8084d1490cc48993b934..3559186f431d5bcc7e28978351b3863f04bb9710 100644 (file)
@@ -13,7 +13,6 @@ SET(RSPAMADMSRC rspamadm.c
         signtool.c
         lua_repl.c
         dkim_keygen.c
-        rescore.c
         ${CMAKE_BINARY_DIR}/src/workers.c
         ${CMAKE_BINARY_DIR}/src/modules.c
         ${CMAKE_SOURCE_DIR}/src/controller.c
index 306938697d754fa75ae203ae561e074b48b0b4c0..224cc48d34e3358c392fb8ca92c9c64416da6523 100644 (file)
@@ -31,7 +31,6 @@ extern struct rspamadm_command signtool_command;
 extern struct rspamadm_command lua_command;
 extern struct rspamadm_command dkim_keygen_command;
 extern struct rspamadm_command configwizard_command;
-extern struct rspamadm_command rescore_command;
 
 const struct rspamadm_command *commands[] = {
        &help_command,
@@ -48,7 +47,6 @@ const struct rspamadm_command *commands[] = {
        &lua_command,
        &dkim_keygen_command,
        &configwizard_command,
-       &rescore_command,
        NULL
 };
 
index 7ce72dea98803de10b60578fa0da3dc07d472edb..08789a468759806a3899b2c80e847b8406d49f05 100644 (file)
@@ -231,7 +231,7 @@ rspamadm_confighelp (gint argc, gchar **argv, const struct rspamadm_command *cmd
        cfg->compiled_modules = modules;
        cfg->compiled_workers = workers;
 
-       rspamd_rcl_config_init (cfg);
+       rspamd_rcl_config_init (cfg, NULL);
        lua_pushboolean (cfg->lua_state, true);
        lua_setglobal (cfg->lua_state, "confighelp");
        rspamd_rcl_add_lua_plugins_path (cfg, plugins_path, NULL);
diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c
deleted file mode 100644 (file)
index 346c49c..0000000
+++ /dev/null
@@ -1,191 +0,0 @@
-/*-
- * Copyright 2017 Pragadeesh C
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "config.h"
-#include "rspamadm.h"
-#include "lua/lua_common.h"
-
-static gchar *logdir = NULL;
-static gchar *output = "new.scores";
-static gboolean score_diff = false;  /* Print score diff flag */
-static gchar *config = NULL;
-extern struct rspamd_main *rspamd_main;
-/* Defined in modules.c */
-extern module_t *modules[];
-extern worker_t *workers[];
-
-static void rspamadm_rescore (gint argc, gchar **argv,
-                                                         const struct rspamadm_command *cmd);
-
-static const char *rspamadm_rescore_help (gboolean full_help,
-                                                                                 const struct rspamadm_command *cmd);
-
-struct rspamadm_command rescore_command = {
-               .name = "rescore",
-               .flags = 0,
-               .help = rspamadm_rescore_help,
-               .run = rspamadm_rescore
-};
-
-static GOptionEntry entries[] = {
-               {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir,
-                               "Logs directory",                               NULL},
-               {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output,
-                               "Scores output location",                       NULL},
-               {"diff",   'd', 0, G_OPTION_ARG_NONE,     &score_diff,
-                               "Print score diff",                             NULL},
-               {"config", 'c', 0, G_OPTION_ARG_STRING, &config,
-                               "Config file to use",     NULL},
-               {NULL,     0,   0, G_OPTION_ARG_NONE, NULL, NULL,       NULL}
-};
-
-static void
-config_logger (rspamd_mempool_t *pool, gpointer ud)
-{
-       struct rspamd_main *rm = ud;
-
-       rm->cfg->log_type = RSPAMD_LOG_CONSOLE;
-       rm->cfg->log_level = G_LOG_LEVEL_MESSAGE;
-
-       rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger,
-                       rm->server_pool);
-
-       if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) ==
-                       -1) {
-               fprintf (stderr, "Fatal error, cannot open logfile, exiting\n");
-               exit (EXIT_FAILURE);
-       }
-}
-
-static const char *
-rspamadm_rescore_help (gboolean full_help, const struct rspamadm_command *cmd)
-{
-       const char *help_str;
-
-       if (full_help) {
-               help_str = "Estimate optimal symbol weights from log files\n\n"
-                               "Usage: rspamadm rescore -l <log_directory>\n"
-                               "Where options are:\n\n"
-                               "-l: path to logs directory\n"
-                               "-o: scores output file location\n"
-                               "-d: print scores diff\n"
-                               "-i: max iterations for perceptron\n";
-       } else {
-               help_str = "Estimate optimal symbol weights from log files";
-       }
-
-       return help_str;
-}
-
-static void
-rspamadm_rescore (gint argc, gchar **argv, const struct rspamadm_command *cmd)
-{
-       GOptionContext *context;
-       GError *error = NULL;
-       struct rspamd_config *cfg = rspamd_main->cfg, **pcfg;
-       gboolean ret = TRUE;
-       worker_t **pworker;
-       const gchar *confdir;
-
-#ifndef WITH_TORCH
-       rspamd_fprintf (stderr, "Torch is not enabled. "
-                               "Use -DENABLE_TORCH=ON option while running cmake.\n");
-       exit (EXIT_FAILURE);
-#endif
-
-       context = g_option_context_new (
-                       "rescore - estimate optimal symbol weights from log files");
-
-       g_option_context_set_summary (context,
-                       "Summary:\n Rspamd administration utility version "
-                                       RVERSION
-                                       "\n Release id: "
-                                       RID);
-
-       g_option_context_add_main_entries (context, entries, NULL);
-       g_option_context_set_ignore_unknown_options (context, TRUE);
-
-       if (!g_option_context_parse (context, &argc, &argv, &error)) {
-               rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
-               g_error_free (error);
-               exit (EXIT_FAILURE);
-       }
-
-       if (logdir == NULL) {
-               rspamd_fprintf (stderr, "Please specify log directory.\n");
-               exit (EXIT_FAILURE);
-       }
-
-       if (config == NULL) {
-               if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) {
-                       confdir = RSPAMD_CONFDIR;
-               }
-
-               config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR,
-                               "rspamd.conf");
-       }
-
-       pworker = &workers[0];
-       while (*pworker) {
-               /* Init string quarks */
-               (void) g_quark_from_static_string ((*pworker)->name);
-               pworker++;
-       }
-
-       cfg->cache = rspamd_symbols_cache_new (cfg);
-       cfg->compiled_modules = modules;
-       cfg->compiled_workers = workers;
-       cfg->cfg_name = config;
-
-       if (!rspamd_config_read (cfg, cfg->cfg_name, config_logger, rspamd_main, ucl_vars)) {
-               ret = FALSE;
-       }
-       else {
-               /* Do post-load actions */
-               rspamd_lua_post_load_config (cfg);
-
-               if (!rspamd_init_filters (cfg, FALSE)) {
-                       ret = FALSE;
-               }
-
-               if (ret) {
-                       ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE);
-                       rspamd_symbols_cache_validate (cfg->cache,
-                                       cfg,
-                                       FALSE);
-               }
-       }
-
-       if (ret) {
-               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name),
-                               "config_path", 0, false);
-               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir),
-                               "logdir", 0, false);
-               ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output),
-                               "output", 0, false);
-               ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff),
-                               "diff", 0, false);
-               pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *));
-               rspamd_lua_setclass (L, "rspamd{config}", -1);
-               *pcfg = cfg;
-               lua_setglobal (L, "rspamd_config");
-               rspamadm_execute_lua_ucl_subr (L,
-                               argc,
-                               argv,
-                               cfg->rcl_obj,
-                               "rescore");
-       }
-}