From d927a3a73c751adc21694262f0c931b8a6a55372 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 30 May 2018 13:39:31 +0100 Subject: [PATCH] [Project] Rework rescore tool to the new architecture --- lualib/rescore_utility.lua | 30 +++-- lualib/rspamadm/corpus_test.lua | 2 + lualib/rspamadm/grep.lua | 1 - lualib/rspamadm/rescore.lua | 193 ++++++++++++++++++++++++++------ src/rspamadm/CMakeLists.txt | 1 - src/rspamadm/commands.c | 2 - src/rspamadm/confighelp.c | 2 +- src/rspamadm/rescore.c | 191 ------------------------------- 8 files changed, 183 insertions(+), 239 deletions(-) delete mode 100644 src/rspamadm/rescore.c diff --git a/lualib/rescore_utility.lua b/lualib/rescore_utility.lua index 2a9372d4e..268e814d8 100644 --- a/lualib/rescore_utility.lua +++ b/lualib/rescore_utility.lua @@ -47,20 +47,30 @@ function utility.read_log_file(file) return lines end -function utility.get_all_logs(dir_path) +function utility.get_all_logs(dirs) -- Reads all log files in the directory and returns a list of logs. - if dir_path:sub(#dir_path, #dir_path) == "/" then - dir_path = dir_path:sub(1, #dir_path -1) + if type(dirs) == 'string' then + dirs = {dirs} end - local files = rspamd_util.glob(dir_path .. "/*.log") local all_logs = {} - for _, file in pairs(files) do - local logs = utility.read_log_file(file) - for _, log_line in pairs(logs) do - all_logs[#all_logs + 1] = log_line + for _,dir in ipairs(dirs) do + if dir:sub(-1, -1) == "/" then + dir = dir:sub(1, -2) + local files = rspamd_util.glob(dir .. "/*.log") + for _, file in pairs(files) do + local logs = utility.read_log_file(file) + for _, log_line in pairs(logs) do + table.insert(all_logs, log_line) + end + end + else + local logs = utility.read_log_file(dir) + for _, log_line in pairs(logs) do + table.insert(all_logs, log_line) + end end end @@ -160,8 +170,8 @@ function utility.generate_statistics_from_logs(logs, threshold) end -- Find slowest message - if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then - file_stats.slowest = tostring(tonumber(log[#log-1])) + if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then + file_stats.slowest = tonumber(log[#log-1]) file_stats.slowest_file = log[#log] end end diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua index 60aa2d7a9..07051a196 100644 --- a/lualib/rspamadm/corpus_test.lua +++ b/lualib/rspamadm/corpus_test.lua @@ -17,6 +17,7 @@ parser:option "-S --spam" parser:option "-n --conns" :description("Number of parallel connections") :argname("") + :convert(tonumber) :default(10) parser:option "-o --output" :description("Output file") @@ -25,6 +26,7 @@ parser:option "-o --output" parser:option "-t --timeout" :description("Timeout for client connections") :argname("") + :convert(tonumber) :default(60) parser:option "-c --connect" :description("Connect to specific host") diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua index 0af83c1cf..b149a0337 100644 --- a/lualib/rspamadm/grep.lua +++ b/lualib/rspamadm/grep.lua @@ -38,7 +38,6 @@ parser:argument "input":args "*" :default("stdin") parser:flag "-S --sensitive" :description('Enable case-sensitivity in string search') - :default("false") parser:flag "-o --orphans" :description('Print orphaned logs') parser:flag "-P --partial" diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index e6df3b364..80b9630f4 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -1,12 +1,31 @@ +--[[ +Copyright (c) 2018, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if not rspamd_config:has_torch() then + return +end + local torch = require "torch" local nn = require "nn" local lua_util = require "lua_util" local ucl = require "ucl" local logger = require "rspamd_logger" -local getopt = require "getopt" local optim = require "optim" local rspamd_util = require "rspamd_util" - +local argparse = require "argparse" local rescore_utility = require "rescore_utility" local opts @@ -15,6 +34,104 @@ local ignore_symbols = { ['DATE_IN_FUTURE'] = true, } +local parser = argparse() + :name "rspamadm rescore" + :description "Estimate optimal symbol weights from log files" + :help_description_margin(37) + +parser:option "-l --log" + :description "Log file or files (from rescore)" + :argname("") + :args "*" +parser:option "-c --config" + :description "Path to config file" + :argname("") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:option "-o --output" + :description "Output file" + :argname("") + :default("new.scores") +parser:flag "-d --diff" + :description "Show differences in scores" +parser:flag "-v --verbose" + :description "Verbose output" +parser:flag "-z --freq" + :description "Display hit frequencies" +parser:option "-i --iters" + :description "Learn iterations" + :argname("") + :convert(tonumber) + :default(10) +parser:option "-b --batch" + :description "Batch size" + :argname("") + :convert(tonumber) + :default(100) +parser:option "-d --decay" + :description "Decay rate" + :argname("") + :convert(tonumber) + :default(0.001) +parser:option "-m --momentum" + :description "Learn momentum" + :argname("") + :convert(tonumber) + :default(0.1) +parser:option "-t --threads" + :description "Number of threads to use" + :argname("") + :convert(tonumber) + :default(1) +parser:option "-o --optim" + :description "Optimisation algorithm" + :argname("") + :convert { + LBFGS = "LBFGS", + ADAM = "ADAM", + ADAGRAD = "ADAGRAD", + SGD = "SGD", + NAG = "NAG" + } + :default "ADAM" +parser:option "--ignore-symbol" + :description "Ignore symbol from logs" + :argname("") + :args "*" +parser:option "--penalty-weight" + :description "Add new penalty weight to test" + :argname("") + :convert(tonumber) + :args "*" +parser:option "--learning-rate" + :description "Add new learning rate to test" + :argname("") + :convert(tonumber) + :args "*" +parser:option "--spam_action" + :description "Spam action" + :argname("") + :default("reject") +parser:option "--learning_rate_decay" + :description "Learn rate decay (for some algs)" + :argname("") + :convert(tonumber) + :default(0.0) +parser:option "--weight_decay" + :description "Weight decay (for some algs)" + :argname("") + :convert(tonumber) + :default(0.0) +parser:option "--l1" + :description "L1 regularization penalty" + :argname("") + :convert(tonumber) + :default(0.0) +parser:option "--l2" + :description "L2 regularization penalty" + :argname("") + :convert(tonumber) + :default(0.0) + local function make_dataset_from_logs(logs, all_symbols, spam_score) -- Returns a list of {input, output} for torch SGD train @@ -182,7 +299,8 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho logs = update_logs(logs, new_symbol_scores) - local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _, all_fps, all_fns = + rescore_utility.generate_statistics_from_logs(logs, threshold) return file_stats.fscore, all_fps, all_fns end @@ -226,7 +344,7 @@ local function train(dataset, opt, model, criterion, epoch, local lbfgsState local sgdState - local batch_size = opt.batch_size + local batch_size = opt.batch logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size) @@ -300,7 +418,7 @@ local function train(dataset, opt, model, criterion, epoch, end -- optimize on current mini-batch - if opt.optimization == 'LBFGS' then + if opt.optim == 'LBFGS' then -- Perform LBFGS step: lbfgsState = lbfgsState or { @@ -315,7 +433,7 @@ local function train(dataset, opt, model, criterion, epoch, logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter) logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval) - elseif opt.optimization == 'ADAM' then + elseif opt.optim == 'ADAM' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -323,7 +441,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.adam(feval, parameters, sgdState) - elseif opt.optimization == 'ADAGRAD' then + elseif opt.optim == 'ADAGRAD' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -331,7 +449,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.adagrad(feval, parameters, sgdState) - elseif opt.optimization == 'SGD' then + elseif opt.optim == 'SGD' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -339,7 +457,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.sgd(feval, parameters, sgdState) - elseif opt.optimization == 'NAG' then + elseif opt.optim == 'NAG' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -348,7 +466,8 @@ local function train(dataset, opt, model, criterion, epoch, } optim.nag(feval, parameters, sgdState) else - error('unknown optimization method') + logger.errx('unknown optimization method: %s', opt.optim) + os.exit(1) end end @@ -363,19 +482,6 @@ local function train(dataset, opt, model, criterion, epoch, confusion:zero() end - -local default_opts = { - verbose = true, - iters = 10, - threads = 1, - batch_size = 1000, - optimization = 'ADAM', - learning_rate_decay = 0.001, - momentum = 0.1, - l1 = 0.0, - l2 = 0.0, -} - local learning_rates = { 0.01 } @@ -393,11 +499,27 @@ local function get_threshold() or actions['reject']), actions['reject'] end -return function (args, cfg) - opts = default_opts - opts = lua_util.override_defaults(opts, getopt.getopt(args, 'i:')) +local function handler(args) + opts = parser:parse(args) + if not opts['log'] then + parser:error('no log specified') + end + + local _r,err = rspamd_config:load_ucl(opts['config']) + + if not _r then + logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r,err = rspamd_config:parse_rcl({'logging', 'worker'}) + if not _r then + logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + local threshold,reject_score = get_threshold() - local logs = rescore_utility.get_all_logs(cfg["logdir"]) + local logs = rescore_utility.get_all_logs(opts['log']) if opts['ignore-symbol'] then local function add_ignore(s) @@ -446,14 +568,12 @@ return function (args, cfg) end end - if opts['i'] then opts['iters'] = opts['i'] end - local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols) local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config, ignore_symbols) -- Display hit frequencies - if opts['z'] then + if opts['freq'] then local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold) local t = {} for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end @@ -580,11 +700,11 @@ return function (args, cfg) new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - if cfg["output"] then - write_scores(new_symbol_scores, cfg["output"]) + if opts["output"] then + write_scores(new_symbol_scores, opts["output"]) end - if cfg["diff"] then + if opts["diff"] then print_score_diff(new_symbol_scores, original_symbol_scores) end @@ -616,3 +736,10 @@ return function (args, cfg) end end end + + +return { + handler = handler, + description = parser._description, + name = 'rescore' +} \ No newline at end of file diff --git a/src/rspamadm/CMakeLists.txt b/src/rspamadm/CMakeLists.txt index 616c40322..3559186f4 100644 --- a/src/rspamadm/CMakeLists.txt +++ b/src/rspamadm/CMakeLists.txt @@ -13,7 +13,6 @@ SET(RSPAMADMSRC rspamadm.c signtool.c lua_repl.c dkim_keygen.c - rescore.c ${CMAKE_BINARY_DIR}/src/workers.c ${CMAKE_BINARY_DIR}/src/modules.c ${CMAKE_SOURCE_DIR}/src/controller.c diff --git a/src/rspamadm/commands.c b/src/rspamadm/commands.c index 306938697..224cc48d3 100644 --- a/src/rspamadm/commands.c +++ b/src/rspamadm/commands.c @@ -31,7 +31,6 @@ extern struct rspamadm_command signtool_command; extern struct rspamadm_command lua_command; extern struct rspamadm_command dkim_keygen_command; extern struct rspamadm_command configwizard_command; -extern struct rspamadm_command rescore_command; const struct rspamadm_command *commands[] = { &help_command, @@ -48,7 +47,6 @@ const struct rspamadm_command *commands[] = { &lua_command, &dkim_keygen_command, &configwizard_command, - &rescore_command, NULL }; diff --git a/src/rspamadm/confighelp.c b/src/rspamadm/confighelp.c index 7ce72dea9..08789a468 100644 --- a/src/rspamadm/confighelp.c +++ b/src/rspamadm/confighelp.c @@ -231,7 +231,7 @@ rspamadm_confighelp (gint argc, gchar **argv, const struct rspamadm_command *cmd cfg->compiled_modules = modules; cfg->compiled_workers = workers; - rspamd_rcl_config_init (cfg); + rspamd_rcl_config_init (cfg, NULL); lua_pushboolean (cfg->lua_state, true); lua_setglobal (cfg->lua_state, "confighelp"); rspamd_rcl_add_lua_plugins_path (cfg, plugins_path, NULL); diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c deleted file mode 100644 index 346c49c68..000000000 --- a/src/rspamadm/rescore.c +++ /dev/null @@ -1,191 +0,0 @@ -/*- - * Copyright 2017 Pragadeesh C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "config.h" -#include "rspamadm.h" -#include "lua/lua_common.h" - -static gchar *logdir = NULL; -static gchar *output = "new.scores"; -static gboolean score_diff = false; /* Print score diff flag */ -static gchar *config = NULL; -extern struct rspamd_main *rspamd_main; -/* Defined in modules.c */ -extern module_t *modules[]; -extern worker_t *workers[]; - -static void rspamadm_rescore (gint argc, gchar **argv, - const struct rspamadm_command *cmd); - -static const char *rspamadm_rescore_help (gboolean full_help, - const struct rspamadm_command *cmd); - -struct rspamadm_command rescore_command = { - .name = "rescore", - .flags = 0, - .help = rspamadm_rescore_help, - .run = rspamadm_rescore -}; - -static GOptionEntry entries[] = { - {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir, - "Logs directory", NULL}, - {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output, - "Scores output location", NULL}, - {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff, - "Print score diff", NULL}, - {"config", 'c', 0, G_OPTION_ARG_STRING, &config, - "Config file to use", NULL}, - {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL} -}; - -static void -config_logger (rspamd_mempool_t *pool, gpointer ud) -{ - struct rspamd_main *rm = ud; - - rm->cfg->log_type = RSPAMD_LOG_CONSOLE; - rm->cfg->log_level = G_LOG_LEVEL_MESSAGE; - - rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger, - rm->server_pool); - - if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) == - -1) { - fprintf (stderr, "Fatal error, cannot open logfile, exiting\n"); - exit (EXIT_FAILURE); - } -} - -static const char * -rspamadm_rescore_help (gboolean full_help, const struct rspamadm_command *cmd) -{ - const char *help_str; - - if (full_help) { - help_str = "Estimate optimal symbol weights from log files\n\n" - "Usage: rspamadm rescore -l \n" - "Where options are:\n\n" - "-l: path to logs directory\n" - "-o: scores output file location\n" - "-d: print scores diff\n" - "-i: max iterations for perceptron\n"; - } else { - help_str = "Estimate optimal symbol weights from log files"; - } - - return help_str; -} - -static void -rspamadm_rescore (gint argc, gchar **argv, const struct rspamadm_command *cmd) -{ - GOptionContext *context; - GError *error = NULL; - struct rspamd_config *cfg = rspamd_main->cfg, **pcfg; - gboolean ret = TRUE; - worker_t **pworker; - const gchar *confdir; - -#ifndef WITH_TORCH - rspamd_fprintf (stderr, "Torch is not enabled. " - "Use -DENABLE_TORCH=ON option while running cmake.\n"); - exit (EXIT_FAILURE); -#endif - - context = g_option_context_new ( - "rescore - estimate optimal symbol weights from log files"); - - g_option_context_set_summary (context, - "Summary:\n Rspamd administration utility version " - RVERSION - "\n Release id: " - RID); - - g_option_context_add_main_entries (context, entries, NULL); - g_option_context_set_ignore_unknown_options (context, TRUE); - - if (!g_option_context_parse (context, &argc, &argv, &error)) { - rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message); - g_error_free (error); - exit (EXIT_FAILURE); - } - - if (logdir == NULL) { - rspamd_fprintf (stderr, "Please specify log directory.\n"); - exit (EXIT_FAILURE); - } - - if (config == NULL) { - if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) { - confdir = RSPAMD_CONFDIR; - } - - config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR, - "rspamd.conf"); - } - - pworker = &workers[0]; - while (*pworker) { - /* Init string quarks */ - (void) g_quark_from_static_string ((*pworker)->name); - pworker++; - } - - cfg->cache = rspamd_symbols_cache_new (cfg); - cfg->compiled_modules = modules; - cfg->compiled_workers = workers; - cfg->cfg_name = config; - - if (!rspamd_config_read (cfg, cfg->cfg_name, config_logger, rspamd_main, ucl_vars)) { - ret = FALSE; - } - else { - /* Do post-load actions */ - rspamd_lua_post_load_config (cfg); - - if (!rspamd_init_filters (cfg, FALSE)) { - ret = FALSE; - } - - if (ret) { - ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE); - rspamd_symbols_cache_validate (cfg->cache, - cfg, - FALSE); - } - } - - if (ret) { - ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name), - "config_path", 0, false); - ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir), - "logdir", 0, false); - ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output), - "output", 0, false); - ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff), - "diff", 0, false); - pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *)); - rspamd_lua_setclass (L, "rspamd{config}", -1); - *pcfg = cfg; - lua_setglobal (L, "rspamd_config"); - rspamadm_execute_lua_ucl_subr (L, - argc, - argv, - cfg->rcl_obj, - "rescore"); - } -} -- 2.39.5