aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lualib/rescore_utility.lua30
-rw-r--r--lualib/rspamadm/corpus_test.lua2
-rw-r--r--lualib/rspamadm/grep.lua1
-rw-r--r--lualib/rspamadm/rescore.lua193
-rw-r--r--src/rspamadm/CMakeLists.txt1
-rw-r--r--src/rspamadm/commands.c2
-rw-r--r--src/rspamadm/confighelp.c2
-rw-r--r--src/rspamadm/rescore.c191
8 files changed, 183 insertions, 239 deletions
diff --git a/lualib/rescore_utility.lua b/lualib/rescore_utility.lua
index 2a9372d4e..268e814d8 100644
--- a/lualib/rescore_utility.lua
+++ b/lualib/rescore_utility.lua
@@ -47,20 +47,30 @@ function utility.read_log_file(file)
return lines
end
-function utility.get_all_logs(dir_path)
+function utility.get_all_logs(dirs)
-- Reads all log files in the directory and returns a list of logs.
- if dir_path:sub(#dir_path, #dir_path) == "/" then
- dir_path = dir_path:sub(1, #dir_path -1)
+ if type(dirs) == 'string' then
+ dirs = {dirs}
end
- local files = rspamd_util.glob(dir_path .. "/*.log")
local all_logs = {}
- for _, file in pairs(files) do
- local logs = utility.read_log_file(file)
- for _, log_line in pairs(logs) do
- all_logs[#all_logs + 1] = log_line
+ for _,dir in ipairs(dirs) do
+ if dir:sub(-1, -1) == "/" then
+ dir = dir:sub(1, -2)
+ local files = rspamd_util.glob(dir .. "/*.log")
+ for _, file in pairs(files) do
+ local logs = utility.read_log_file(file)
+ for _, log_line in pairs(logs) do
+ table.insert(all_logs, log_line)
+ end
+ end
+ else
+ local logs = utility.read_log_file(dir)
+ for _, log_line in pairs(logs) do
+ table.insert(all_logs, log_line)
+ end
end
end
@@ -160,8 +170,8 @@ function utility.generate_statistics_from_logs(logs, threshold)
end
-- Find slowest message
- if (tonumber(log[#log-1]) > tonumber(file_stats.slowest)) then
- file_stats.slowest = tostring(tonumber(log[#log-1]))
+ if ((tonumber(log[#log-1]) or 0) > file_stats.slowest) then
+ file_stats.slowest = tonumber(log[#log-1])
file_stats.slowest_file = log[#log]
end
end
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
index 60aa2d7a9..07051a196 100644
--- a/lualib/rspamadm/corpus_test.lua
+++ b/lualib/rspamadm/corpus_test.lua
@@ -17,6 +17,7 @@ parser:option "-S --spam"
parser:option "-n --conns"
:description("Number of parallel connections")
:argname("<N>")
+ :convert(tonumber)
:default(10)
parser:option "-o --output"
:description("Output file")
@@ -25,6 +26,7 @@ parser:option "-o --output"
parser:option "-t --timeout"
:description("Timeout for client connections")
:argname("<sec>")
+ :convert(tonumber)
:default(60)
parser:option "-c --connect"
:description("Connect to specific host")
diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua
index 0af83c1cf..b149a0337 100644
--- a/lualib/rspamadm/grep.lua
+++ b/lualib/rspamadm/grep.lua
@@ -38,7 +38,6 @@ parser:argument "input":args "*"
:default("stdin")
parser:flag "-S --sensitive"
:description('Enable case-sensitivity in string search')
- :default("false")
parser:flag "-o --orphans"
:description('Print orphaned logs')
parser:flag "-P --partial"
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index e6df3b364..80b9630f4 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -1,12 +1,31 @@
+--[[
+Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if not rspamd_config:has_torch() then
+ return
+end
+
local torch = require "torch"
local nn = require "nn"
local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
-local getopt = require "getopt"
local optim = require "optim"
local rspamd_util = require "rspamd_util"
-
+local argparse = require "argparse"
local rescore_utility = require "rescore_utility"
local opts
@@ -15,6 +34,104 @@ local ignore_symbols = {
['DATE_IN_FUTURE'] = true,
}
+local parser = argparse()
+ :name "rspamadm rescore"
+ :description "Estimate optimal symbol weights from log files"
+ :help_description_margin(37)
+
+parser:option "-l --log"
+ :description "Log file or files (from rescore)"
+ :argname("<log>")
+ :args "*"
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-o --output"
+ :description "Output file"
+ :argname("<file>")
+ :default("new.scores")
+parser:flag "-d --diff"
+ :description "Show differences in scores"
+parser:flag "-v --verbose"
+ :description "Verbose output"
+parser:flag "-z --freq"
+ :description "Display hit frequencies"
+parser:option "-i --iters"
+ :description "Learn iterations"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-b --batch"
+ :description "Batch size"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(100)
+parser:option "-d --decay"
+ :description "Decay rate"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.001)
+parser:option "-m --momentum"
+ :description "Learn momentum"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.1)
+parser:option "-t --threads"
+ :description "Number of threads to use"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(1)
+parser:option "-o --optim"
+ :description "Optimisation algorithm"
+ :argname("<alg>")
+ :convert {
+ LBFGS = "LBFGS",
+ ADAM = "ADAM",
+ ADAGRAD = "ADAGRAD",
+ SGD = "SGD",
+ NAG = "NAG"
+ }
+ :default "ADAM"
+parser:option "--ignore-symbol"
+ :description "Ignore symbol from logs"
+ :argname("<sym>")
+ :args "*"
+parser:option "--penalty-weight"
+ :description "Add new penalty weight to test"
+ :argname("<n>")
+ :convert(tonumber)
+ :args "*"
+parser:option "--learning-rate"
+ :description "Add new learning rate to test"
+ :argname("<n>")
+ :convert(tonumber)
+ :args "*"
+parser:option "--spam_action"
+ :description "Spam action"
+ :argname("<act>")
+ :default("reject")
+parser:option "--learning_rate_decay"
+ :description "Learn rate decay (for some algs)"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--weight_decay"
+ :description "Weight decay (for some algs)"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--l1"
+ :description "L1 regularization penalty"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--l2"
+ :description "L2 regularization penalty"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+
local function make_dataset_from_logs(logs, all_symbols, spam_score)
-- Returns a list of {input, output} for torch SGD train
@@ -182,7 +299,8 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
logs = update_logs(logs, new_symbol_scores)
- local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _, all_fps, all_fns =
+ rescore_utility.generate_statistics_from_logs(logs, threshold)
return file_stats.fscore, all_fps, all_fns
end
@@ -226,7 +344,7 @@ local function train(dataset, opt, model, criterion, epoch,
local lbfgsState
local sgdState
- local batch_size = opt.batch_size
+ local batch_size = opt.batch
logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
@@ -300,7 +418,7 @@ local function train(dataset, opt, model, criterion, epoch,
end
-- optimize on current mini-batch
- if opt.optimization == 'LBFGS' then
+ if opt.optim == 'LBFGS' then
-- Perform LBFGS step:
lbfgsState = lbfgsState or {
@@ -315,7 +433,7 @@ local function train(dataset, opt, model, criterion, epoch,
logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
- elseif opt.optimization == 'ADAM' then
+ elseif opt.optim == 'ADAM' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -323,7 +441,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.adam(feval, parameters, sgdState)
- elseif opt.optimization == 'ADAGRAD' then
+ elseif opt.optim == 'ADAGRAD' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -331,7 +449,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.adagrad(feval, parameters, sgdState)
- elseif opt.optimization == 'SGD' then
+ elseif opt.optim == 'SGD' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -339,7 +457,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.sgd(feval, parameters, sgdState)
- elseif opt.optimization == 'NAG' then
+ elseif opt.optim == 'NAG' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -348,7 +466,8 @@ local function train(dataset, opt, model, criterion, epoch,
}
optim.nag(feval, parameters, sgdState)
else
- error('unknown optimization method')
+ logger.errx('unknown optimization method: %s', opt.optim)
+ os.exit(1)
end
end
@@ -363,19 +482,6 @@ local function train(dataset, opt, model, criterion, epoch,
confusion:zero()
end
-
-local default_opts = {
- verbose = true,
- iters = 10,
- threads = 1,
- batch_size = 1000,
- optimization = 'ADAM',
- learning_rate_decay = 0.001,
- momentum = 0.1,
- l1 = 0.0,
- l2 = 0.0,
-}
-
local learning_rates = {
0.01
}
@@ -393,11 +499,27 @@ local function get_threshold()
or actions['reject']), actions['reject']
end
-return function (args, cfg)
- opts = default_opts
- opts = lua_util.override_defaults(opts, getopt.getopt(args, 'i:'))
+local function handler(args)
+ opts = parser:parse(args)
+ if not opts['log'] then
+ parser:error('no log specified')
+ end
+
+ local _r,err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
+ if not _r then
+ logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
local threshold,reject_score = get_threshold()
- local logs = rescore_utility.get_all_logs(cfg["logdir"])
+ local logs = rescore_utility.get_all_logs(opts['log'])
if opts['ignore-symbol'] then
local function add_ignore(s)
@@ -446,14 +568,12 @@ return function (args, cfg)
end
end
- if opts['i'] then opts['iters'] = opts['i'] end
-
local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
ignore_symbols)
-- Display hit frequencies
- if opts['z'] then
+ if opts['freq'] then
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
local t = {}
for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
@@ -580,11 +700,11 @@ return function (args, cfg)
new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
- if cfg["output"] then
- write_scores(new_symbol_scores, cfg["output"])
+ if opts["output"] then
+ write_scores(new_symbol_scores, opts["output"])
end
- if cfg["diff"] then
+ if opts["diff"] then
print_score_diff(new_symbol_scores, original_symbol_scores)
end
@@ -616,3 +736,10 @@ return function (args, cfg)
end
end
end
+
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'rescore'
+} \ No newline at end of file
diff --git a/src/rspamadm/CMakeLists.txt b/src/rspamadm/CMakeLists.txt
index 616c40322..3559186f4 100644
--- a/src/rspamadm/CMakeLists.txt
+++ b/src/rspamadm/CMakeLists.txt
@@ -13,7 +13,6 @@ SET(RSPAMADMSRC rspamadm.c
signtool.c
lua_repl.c
dkim_keygen.c
- rescore.c
${CMAKE_BINARY_DIR}/src/workers.c
${CMAKE_BINARY_DIR}/src/modules.c
${CMAKE_SOURCE_DIR}/src/controller.c
diff --git a/src/rspamadm/commands.c b/src/rspamadm/commands.c
index 306938697..224cc48d3 100644
--- a/src/rspamadm/commands.c
+++ b/src/rspamadm/commands.c
@@ -31,7 +31,6 @@ extern struct rspamadm_command signtool_command;
extern struct rspamadm_command lua_command;
extern struct rspamadm_command dkim_keygen_command;
extern struct rspamadm_command configwizard_command;
-extern struct rspamadm_command rescore_command;
const struct rspamadm_command *commands[] = {
&help_command,
@@ -48,7 +47,6 @@ const struct rspamadm_command *commands[] = {
&lua_command,
&dkim_keygen_command,
&configwizard_command,
- &rescore_command,
NULL
};
diff --git a/src/rspamadm/confighelp.c b/src/rspamadm/confighelp.c
index 7ce72dea9..08789a468 100644
--- a/src/rspamadm/confighelp.c
+++ b/src/rspamadm/confighelp.c
@@ -231,7 +231,7 @@ rspamadm_confighelp (gint argc, gchar **argv, const struct rspamadm_command *cmd
cfg->compiled_modules = modules;
cfg->compiled_workers = workers;
- rspamd_rcl_config_init (cfg);
+ rspamd_rcl_config_init (cfg, NULL);
lua_pushboolean (cfg->lua_state, true);
lua_setglobal (cfg->lua_state, "confighelp");
rspamd_rcl_add_lua_plugins_path (cfg, plugins_path, NULL);
diff --git a/src/rspamadm/rescore.c b/src/rspamadm/rescore.c
deleted file mode 100644
index 346c49c68..000000000
--- a/src/rspamadm/rescore.c
+++ /dev/null
@@ -1,191 +0,0 @@
-/*-
- * Copyright 2017 Pragadeesh C
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "config.h"
-#include "rspamadm.h"
-#include "lua/lua_common.h"
-
-static gchar *logdir = NULL;
-static gchar *output = "new.scores";
-static gboolean score_diff = false; /* Print score diff flag */
-static gchar *config = NULL;
-extern struct rspamd_main *rspamd_main;
-/* Defined in modules.c */
-extern module_t *modules[];
-extern worker_t *workers[];
-
-static void rspamadm_rescore (gint argc, gchar **argv,
- const struct rspamadm_command *cmd);
-
-static const char *rspamadm_rescore_help (gboolean full_help,
- const struct rspamadm_command *cmd);
-
-struct rspamadm_command rescore_command = {
- .name = "rescore",
- .flags = 0,
- .help = rspamadm_rescore_help,
- .run = rspamadm_rescore
-};
-
-static GOptionEntry entries[] = {
- {"logdir", 'l', 0, G_OPTION_ARG_FILENAME, &logdir,
- "Logs directory", NULL},
- {"output", 'o', 0, G_OPTION_ARG_FILENAME, &output,
- "Scores output location", NULL},
- {"diff", 'd', 0, G_OPTION_ARG_NONE, &score_diff,
- "Print score diff", NULL},
- {"config", 'c', 0, G_OPTION_ARG_STRING, &config,
- "Config file to use", NULL},
- {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}
-};
-
-static void
-config_logger (rspamd_mempool_t *pool, gpointer ud)
-{
- struct rspamd_main *rm = ud;
-
- rm->cfg->log_type = RSPAMD_LOG_CONSOLE;
- rm->cfg->log_level = G_LOG_LEVEL_MESSAGE;
-
- rspamd_set_logger (rm->cfg, g_quark_try_string ("main"), &rm->logger,
- rm->server_pool);
-
- if (rspamd_log_open_priv (rm->logger, rm->workers_uid, rm->workers_gid) ==
- -1) {
- fprintf (stderr, "Fatal error, cannot open logfile, exiting\n");
- exit (EXIT_FAILURE);
- }
-}
-
-static const char *
-rspamadm_rescore_help (gboolean full_help, const struct rspamadm_command *cmd)
-{
- const char *help_str;
-
- if (full_help) {
- help_str = "Estimate optimal symbol weights from log files\n\n"
- "Usage: rspamadm rescore -l <log_directory>\n"
- "Where options are:\n\n"
- "-l: path to logs directory\n"
- "-o: scores output file location\n"
- "-d: print scores diff\n"
- "-i: max iterations for perceptron\n";
- } else {
- help_str = "Estimate optimal symbol weights from log files";
- }
-
- return help_str;
-}
-
-static void
-rspamadm_rescore (gint argc, gchar **argv, const struct rspamadm_command *cmd)
-{
- GOptionContext *context;
- GError *error = NULL;
- struct rspamd_config *cfg = rspamd_main->cfg, **pcfg;
- gboolean ret = TRUE;
- worker_t **pworker;
- const gchar *confdir;
-
-#ifndef WITH_TORCH
- rspamd_fprintf (stderr, "Torch is not enabled. "
- "Use -DENABLE_TORCH=ON option while running cmake.\n");
- exit (EXIT_FAILURE);
-#endif
-
- context = g_option_context_new (
- "rescore - estimate optimal symbol weights from log files");
-
- g_option_context_set_summary (context,
- "Summary:\n Rspamd administration utility version "
- RVERSION
- "\n Release id: "
- RID);
-
- g_option_context_add_main_entries (context, entries, NULL);
- g_option_context_set_ignore_unknown_options (context, TRUE);
-
- if (!g_option_context_parse (context, &argc, &argv, &error)) {
- rspamd_fprintf (stderr, "option parsing failed: %s\n", error->message);
- g_error_free (error);
- exit (EXIT_FAILURE);
- }
-
- if (logdir == NULL) {
- rspamd_fprintf (stderr, "Please specify log directory.\n");
- exit (EXIT_FAILURE);
- }
-
- if (config == NULL) {
- if ((confdir = g_hash_table_lookup (ucl_vars, "CONFDIR")) == NULL) {
- confdir = RSPAMD_CONFDIR;
- }
-
- config = g_strdup_printf ("%s%c%s", confdir, G_DIR_SEPARATOR,
- "rspamd.conf");
- }
-
- pworker = &workers[0];
- while (*pworker) {
- /* Init string quarks */
- (void) g_quark_from_static_string ((*pworker)->name);
- pworker++;
- }
-
- cfg->cache = rspamd_symbols_cache_new (cfg);
- cfg->compiled_modules = modules;
- cfg->compiled_workers = workers;
- cfg->cfg_name = config;
-
- if (!rspamd_config_read (cfg, cfg->cfg_name, config_logger, rspamd_main, ucl_vars)) {
- ret = FALSE;
- }
- else {
- /* Do post-load actions */
- rspamd_lua_post_load_config (cfg);
-
- if (!rspamd_init_filters (cfg, FALSE)) {
- ret = FALSE;
- }
-
- if (ret) {
- ret = rspamd_config_post_load (cfg, RSPAMD_CONFIG_INIT_SYMCACHE);
- rspamd_symbols_cache_validate (cfg->cache,
- cfg,
- FALSE);
- }
- }
-
- if (ret) {
- ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (cfg->cfg_name),
- "config_path", 0, false);
- ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (logdir),
- "logdir", 0, false);
- ucl_object_insert_key (cfg->rcl_obj, ucl_object_fromstring (output),
- "output", 0, false);
- ucl_object_insert_key (cfg->rcl_obj, ucl_object_frombool (score_diff),
- "diff", 0, false);
- pcfg = lua_newuserdata (L, sizeof (struct rspamd_config *));
- rspamd_lua_setclass (L, "rspamd{config}", -1);
- *pcfg = cfg;
- lua_setglobal (L, "rspamd_config");
- rspamadm_execute_lua_ucl_subr (L,
- argc,
- argv,
- cfg->rcl_obj,
- "rescore");
- }
-}