diff options
Diffstat (limited to 'lualib/rspamadm/rescore.lua')
-rw-r--r-- | lualib/rspamadm/rescore.lua | 193 |
1 files changed, 160 insertions, 33 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index e6df3b364..80b9630f4 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -1,12 +1,31 @@ +--[[ +Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if not rspamd_config:has_torch() then + return +end + local torch = require "torch" local nn = require "nn" local lua_util = require "lua_util" local ucl = require "ucl" local logger = require "rspamd_logger" -local getopt = require "getopt" local optim = require "optim" local rspamd_util = require "rspamd_util" - +local argparse = require "argparse" local rescore_utility = require "rescore_utility" local opts @@ -15,6 +34,104 @@ local ignore_symbols = { ['DATE_IN_FUTURE'] = true, } +local parser = argparse() + :name "rspamadm rescore" + :description "Estimate optimal symbol weights from log files" + :help_description_margin(37) + +parser:option "-l --log" + :description "Log file or files (from rescore)" + :argname("<log>") + :args "*" +parser:option "-c --config" + :description "Path to config file" + :argname("<file>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:option "-o --output" + :description "Output file" + :argname("<file>") + :default("new.scores") +parser:flag "-d --diff" + :description "Show differences in scores" +parser:flag "-v --verbose" + :description "Verbose output" +parser:flag "-z --freq" + :description "Display hit frequencies" +parser:option "-i --iters" + :description "Learn iterations" + :argname("<n>") + :convert(tonumber) + :default(10) +parser:option "-b --batch" + :description "Batch size" + :argname("<n>") + :convert(tonumber) + :default(100) +parser:option "-d --decay" + :description "Decay rate" + :argname("<n>") + :convert(tonumber) + :default(0.001) +parser:option "-m --momentum" + :description "Learn momentum" + :argname("<n>") + :convert(tonumber) + :default(0.1) +parser:option "-t --threads" + :description "Number of threads to use" + :argname("<n>") + :convert(tonumber) + :default(1) +parser:option "-o --optim" + :description "Optimisation algorithm" + :argname("<alg>") + :convert { + LBFGS = "LBFGS", + ADAM = "ADAM", + ADAGRAD = "ADAGRAD", + SGD = "SGD", + NAG = "NAG" + } + :default "ADAM" +parser:option "--ignore-symbol" + :description "Ignore symbol from logs" + :argname("<sym>") + :args "*" +parser:option "--penalty-weight" + :description "Add new penalty weight to test" + :argname("<n>") + :convert(tonumber) + :args "*" +parser:option "--learning-rate" + :description "Add new learning rate to test" + :argname("<n>") + :convert(tonumber) + :args "*" +parser:option "--spam_action" + :description "Spam action" + :argname("<act>") + :default("reject") +parser:option "--learning_rate_decay" + :description "Learn rate decay (for some algs)" + :argname("<n>") + :convert(tonumber) + :default(0.0) +parser:option "--weight_decay" + :description "Weight decay (for some algs)" + :argname("<n>") + :convert(tonumber) + :default(0.0) +parser:option "--l1" + :description "L1 regularization penalty" + :argname("<n>") + :convert(tonumber) + :default(0.0) +parser:option "--l2" + :description "L2 regularization penalty" + :argname("<n>") + :convert(tonumber) + :default(0.0) + local function make_dataset_from_logs(logs, all_symbols, spam_score) -- Returns a list of {input, output} for torch SGD train @@ -182,7 +299,8 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho logs = update_logs(logs, new_symbol_scores) - local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold) + local file_stats, _, all_fps, all_fns = + rescore_utility.generate_statistics_from_logs(logs, threshold) return file_stats.fscore, all_fps, all_fns end @@ -226,7 +344,7 @@ local function train(dataset, opt, model, criterion, epoch, local lbfgsState local sgdState - local batch_size = opt.batch_size + local batch_size = opt.batch logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size) @@ -300,7 +418,7 @@ local function train(dataset, opt, model, criterion, epoch, end -- optimize on current mini-batch - if opt.optimization == 'LBFGS' then + if opt.optim == 'LBFGS' then -- Perform LBFGS step: lbfgsState = lbfgsState or { @@ -315,7 +433,7 @@ local function train(dataset, opt, model, criterion, epoch, logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter) logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval) - elseif opt.optimization == 'ADAM' then + elseif opt.optim == 'ADAM' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -323,7 +441,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.adam(feval, parameters, sgdState) - elseif opt.optimization == 'ADAGRAD' then + elseif opt.optim == 'ADAGRAD' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -331,7 +449,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.adagrad(feval, parameters, sgdState) - elseif opt.optimization == 'SGD' then + elseif opt.optim == 'SGD' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -339,7 +457,7 @@ local function train(dataset, opt, model, criterion, epoch, weightDecay = tonumber(opts.weight_decay), } optim.sgd(feval, parameters, sgdState) - elseif opt.optimization == 'NAG' then + elseif opt.optim == 'NAG' then sgdState = sgdState or { learningRate = tonumber(opts.learning_rate),-- opt.learningRate, momentum = tonumber(opts.momentum), -- opt.momentum, @@ -348,7 +466,8 @@ local function train(dataset, opt, model, criterion, epoch, } optim.nag(feval, parameters, sgdState) else - error('unknown optimization method') + logger.errx('unknown optimization method: %s', opt.optim) + os.exit(1) end end @@ -363,19 +482,6 @@ local function train(dataset, opt, model, criterion, epoch, confusion:zero() end - -local default_opts = { - verbose = true, - iters = 10, - threads = 1, - batch_size = 1000, - optimization = 'ADAM', - learning_rate_decay = 0.001, - momentum = 0.1, - l1 = 0.0, - l2 = 0.0, -} - local learning_rates = { 0.01 } @@ -393,11 +499,27 @@ local function get_threshold() or actions['reject']), actions['reject'] end -return function (args, cfg) - opts = default_opts - opts = lua_util.override_defaults(opts, getopt.getopt(args, 'i:')) +local function handler(args) + opts = parser:parse(args) + if not opts['log'] then + parser:error('no log specified') + end + + local _r,err = rspamd_config:load_ucl(opts['config']) + + if not _r then + logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r,err = rspamd_config:parse_rcl({'logging', 'worker'}) + if not _r then + logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + local threshold,reject_score = get_threshold() - local logs = rescore_utility.get_all_logs(cfg["logdir"]) + local logs = rescore_utility.get_all_logs(opts['log']) if opts['ignore-symbol'] then local function add_ignore(s) @@ -446,14 +568,12 @@ return function (args, cfg) end end - if opts['i'] then opts['iters'] = opts['i'] end - local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols) local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config, ignore_symbols) -- Display hit frequencies - if opts['z'] then + if opts['freq'] then local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold) local t = {} for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end @@ -580,11 +700,11 @@ return function (args, cfg) new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores) - if cfg["output"] then - write_scores(new_symbol_scores, cfg["output"]) + if opts["output"] then + write_scores(new_symbol_scores, opts["output"]) end - if cfg["diff"] then + if opts["diff"] then print_score_diff(new_symbol_scores, original_symbol_scores) end @@ -616,3 +736,10 @@ return function (args, cfg) end end end + + +return { + handler = handler, + description = parser._description, + name = 'rescore' +}
\ No newline at end of file |