summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm/rescore.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm/rescore.lua')
-rw-r--r--lualib/rspamadm/rescore.lua193
1 files changed, 160 insertions, 33 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index e6df3b364..80b9630f4 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -1,12 +1,31 @@
+--[[
+Copyright (c) 2018, Vsevolod Stakhov <vsevolod@highsecure.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if not rspamd_config:has_torch() then
+ return
+end
+
local torch = require "torch"
local nn = require "nn"
local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
-local getopt = require "getopt"
local optim = require "optim"
local rspamd_util = require "rspamd_util"
-
+local argparse = require "argparse"
local rescore_utility = require "rescore_utility"
local opts
@@ -15,6 +34,104 @@ local ignore_symbols = {
['DATE_IN_FUTURE'] = true,
}
+local parser = argparse()
+ :name "rspamadm rescore"
+ :description "Estimate optimal symbol weights from log files"
+ :help_description_margin(37)
+
+parser:option "-l --log"
+ :description "Log file or files (from rescore)"
+ :argname("<log>")
+ :args "*"
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-o --output"
+ :description "Output file"
+ :argname("<file>")
+ :default("new.scores")
+parser:flag "-d --diff"
+ :description "Show differences in scores"
+parser:flag "-v --verbose"
+ :description "Verbose output"
+parser:flag "-z --freq"
+ :description "Display hit frequencies"
+parser:option "-i --iters"
+ :description "Learn iterations"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-b --batch"
+ :description "Batch size"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(100)
+parser:option "-d --decay"
+ :description "Decay rate"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.001)
+parser:option "-m --momentum"
+ :description "Learn momentum"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.1)
+parser:option "-t --threads"
+ :description "Number of threads to use"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(1)
+parser:option "-o --optim"
+ :description "Optimisation algorithm"
+ :argname("<alg>")
+ :convert {
+ LBFGS = "LBFGS",
+ ADAM = "ADAM",
+ ADAGRAD = "ADAGRAD",
+ SGD = "SGD",
+ NAG = "NAG"
+ }
+ :default "ADAM"
+parser:option "--ignore-symbol"
+ :description "Ignore symbol from logs"
+ :argname("<sym>")
+ :args "*"
+parser:option "--penalty-weight"
+ :description "Add new penalty weight to test"
+ :argname("<n>")
+ :convert(tonumber)
+ :args "*"
+parser:option "--learning-rate"
+ :description "Add new learning rate to test"
+ :argname("<n>")
+ :convert(tonumber)
+ :args "*"
+parser:option "--spam_action"
+ :description "Spam action"
+ :argname("<act>")
+ :default("reject")
+parser:option "--learning_rate_decay"
+ :description "Learn rate decay (for some algs)"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--weight_decay"
+ :description "Weight decay (for some algs)"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--l1"
+ :description "L1 regularization penalty"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+parser:option "--l2"
+ :description "L2 regularization penalty"
+ :argname("<n>")
+ :convert(tonumber)
+ :default(0.0)
+
local function make_dataset_from_logs(logs, all_symbols, spam_score)
-- Returns a list of {input, output} for torch SGD train
@@ -182,7 +299,8 @@ local function calculate_fscore_from_weights(logs, all_symbols, weights, thresho
logs = update_logs(logs, new_symbol_scores)
- local file_stats, _, all_fps, all_fns = rescore_utility.generate_statistics_from_logs(logs, threshold)
+ local file_stats, _, all_fps, all_fns =
+ rescore_utility.generate_statistics_from_logs(logs, threshold)
return file_stats.fscore, all_fps, all_fns
end
@@ -226,7 +344,7 @@ local function train(dataset, opt, model, criterion, epoch,
local lbfgsState
local sgdState
- local batch_size = opt.batch_size
+ local batch_size = opt.batch
logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
@@ -300,7 +418,7 @@ local function train(dataset, opt, model, criterion, epoch,
end
-- optimize on current mini-batch
- if opt.optimization == 'LBFGS' then
+ if opt.optim == 'LBFGS' then
-- Perform LBFGS step:
lbfgsState = lbfgsState or {
@@ -315,7 +433,7 @@ local function train(dataset, opt, model, criterion, epoch,
logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
- elseif opt.optimization == 'ADAM' then
+ elseif opt.optim == 'ADAM' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -323,7 +441,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.adam(feval, parameters, sgdState)
- elseif opt.optimization == 'ADAGRAD' then
+ elseif opt.optim == 'ADAGRAD' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -331,7 +449,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.adagrad(feval, parameters, sgdState)
- elseif opt.optimization == 'SGD' then
+ elseif opt.optim == 'SGD' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -339,7 +457,7 @@ local function train(dataset, opt, model, criterion, epoch,
weightDecay = tonumber(opts.weight_decay),
}
optim.sgd(feval, parameters, sgdState)
- elseif opt.optimization == 'NAG' then
+ elseif opt.optim == 'NAG' then
sgdState = sgdState or {
learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
momentum = tonumber(opts.momentum), -- opt.momentum,
@@ -348,7 +466,8 @@ local function train(dataset, opt, model, criterion, epoch,
}
optim.nag(feval, parameters, sgdState)
else
- error('unknown optimization method')
+ logger.errx('unknown optimization method: %s', opt.optim)
+ os.exit(1)
end
end
@@ -363,19 +482,6 @@ local function train(dataset, opt, model, criterion, epoch,
confusion:zero()
end
-
-local default_opts = {
- verbose = true,
- iters = 10,
- threads = 1,
- batch_size = 1000,
- optimization = 'ADAM',
- learning_rate_decay = 0.001,
- momentum = 0.1,
- l1 = 0.0,
- l2 = 0.0,
-}
-
local learning_rates = {
0.01
}
@@ -393,11 +499,27 @@ local function get_threshold()
or actions['reject']), actions['reject']
end
-return function (args, cfg)
- opts = default_opts
- opts = lua_util.override_defaults(opts, getopt.getopt(args, 'i:'))
+local function handler(args)
+ opts = parser:parse(args)
+ if not opts['log'] then
+ parser:error('no log specified')
+ end
+
+ local _r,err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r,err = rspamd_config:parse_rcl({'logging', 'worker'})
+ if not _r then
+ logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
local threshold,reject_score = get_threshold()
- local logs = rescore_utility.get_all_logs(cfg["logdir"])
+ local logs = rescore_utility.get_all_logs(opts['log'])
if opts['ignore-symbol'] then
local function add_ignore(s)
@@ -446,14 +568,12 @@ return function (args, cfg)
end
end
- if opts['i'] then opts['iters'] = opts['i'] end
-
local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
ignore_symbols)
-- Display hit frequencies
- if opts['z'] then
+ if opts['freq'] then
local _, all_symbols_stats = rescore_utility.generate_statistics_from_logs(logs, threshold)
local t = {}
for _, symbol_stats in pairs(all_symbols_stats) do table.insert(t, symbol_stats) end
@@ -580,11 +700,11 @@ return function (args, cfg)
new_symbol_scores = stitch_new_scores(all_symbols, new_symbol_scores)
- if cfg["output"] then
- write_scores(new_symbol_scores, cfg["output"])
+ if opts["output"] then
+ write_scores(new_symbol_scores, opts["output"])
end
- if cfg["diff"] then
+ if opts["diff"] then
print_score_diff(new_symbol_scores, original_symbol_scores)
end
@@ -616,3 +736,10 @@ return function (args, cfg)
end
end
end
+
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'rescore'
+} \ No newline at end of file