aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2018-03-03 13:36:07 +0000
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2018-03-03 13:36:07 +0000
commit 03a5515ef88bbb0bcc26580ad94424c922af6ed7 (patch)
tree cbb2eae5a77fdc2b872f7adb57f1e20509a12edd
parent 22f1d4a067ccfbcdc6bcc83e6a6692da5e774c4e (diff)
download rspamd-03a5515ef88bbb0bcc26580ad94424c922af6ed7.tar.gz
download rspamd-03a5515ef88bbb0bcc26580ad94424c922af6ed7.zip
[Feature] Add more features to rescore utility
- Allow to ignore specific symbols
- Allow to specify learning rates and weight penalty manually
-rw-r--r-- lualib/rspamadm/rescore.lua 82
-rw-r--r-- lualib/rspamadm/rescore_utility.lua 12
2 files changed, 78 insertions, 16 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index d76dc3861..6dcecc44b 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -8,6 +8,10 @@ local getopt = require "rspamadm/getopt"
local rescore_utility = require "rspamadm/rescore_utility"
local opts
+local ignore_symbols = {
+ ['DATE_IN_PAST'] =true,
+ ['DATE_IN_FUTURE'] = true,
+}
local function make_dataset_from_logs(logs, all_symbols)
-- Returns a list of {input, output} for torch SGD train
@@ -28,7 +32,9 @@ local function make_dataset_from_logs(logs, all_symbols)
local symbols_set = {}
for i=4,#log do
- symbols_set[log[i]] = true
+ if not ignore_symbols[log[i]] then
+ symbols_set[log[i]] = true
+ end
end
for index, symbol in pairs(all_symbols) do
@@ -209,6 +215,13 @@ local default_opts = {
threads = 1,
}
+local learning_rates = {
+ 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+}
+local penalty_weights = {
+ 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+}
+
local function override_defaults(def, override)
for k,v in pairs(override) do
if def[k] then
@@ -235,11 +248,63 @@ end
return function (args, cfg)
opts = default_opts
- override_defaults(opts, getopt.getopt(args, ''))
+ override_defaults(opts, getopt.getopt(args, 'i:'))
local threshold = get_threshold()
local logs = rescore_utility.get_all_logs(cfg["logdir"])
- local all_symbols = rescore_utility.get_all_symbols(logs)
- local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config)
+
+ if opts['ignore-symbol'] then
+ local function add_ignore(s)
+ ignore_symbols[s] = true
+ end
+ if type(opts['ignore-symbol']) == 'table' then
+ for _,s in ipairs(opts['ignore-symbol']) do
+ add_ignore(s)
+ end
+ else
+ add_ignore(opts['ignore-symbol'])
+ end
+ end
+
+ if opts['learning-rate'] then
+ learning_rates = {}
+
+ local function add_rate(r)
+ if tonumber(r) then
+ table.insert(learning_rates, tonumber(r))
+ end
+ end
+ if type(opts['learning-rate']) == 'table' then
+ for _,s in ipairs(opts['learning-rate']) do
+ add_rate(s)
+ end
+ else
+ add_rate(opts['learning-rate'])
+ end
+ end
+
+ if opts['penalty-weight'] then
+ penalty_weights = {}
+
+ local function add_weight(r)
+ if tonumber(r) then
+ table.insert(penalty_weights, tonumber(r))
+ end
+ end
+ if type(opts['penalty-weight']) == 'table' then
+ for _,s in ipairs(opts['penalty-weight']) do
+ add_weight(s)
+ end
+ else
+ add_weight(opts['penalty-weight'])
+ end
+ end
+
+ if opts['i'] then opts['iters'] = opts['i'] end
+ logger.errx('%s', opts)
+
+ local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
+ local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
+ ignore_symbols)
shuffle(logs)
torch.setdefaulttensortype('torch.FloatTensor')
@@ -249,15 +314,8 @@ return function (args, cfg)
local dataset = make_dataset_from_logs(train_logs, all_symbols)
- local learning_rates = {
- 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
- }
- local penalty_weights = {
- 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
- }
-- Start of perceptron training
-
local input_size = #all_symbols
torch.setnumthreads(opts['threads'])
local linear_module = nn.Linear(input_size, 1)
@@ -276,7 +334,7 @@ return function (args, cfg)
local best_weights = linear_module.weight[1]:clone()
local trainer = nn.StochasticGradient(perceptron, criterion)
- trainer.maxIteration = opts["iters"]
+ trainer.maxIteration = tonumber(opts["iters"])
trainer.verbose = opts['verbose']
trainer.hookIteration = function(self, iteration, error)
diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua
index db79bbf7b..7f3f40078 100644
--- a/lualib/rspamadm/rescore_utility.lua
+++ b/lualib/rspamadm/rescore_utility.lua
@@ -4,7 +4,7 @@ local fun = require "fun"
local utility = {}
-function utility.get_all_symbols(logs)
+function utility.get_all_symbols(logs, ignore_symbols)
-- Returns a list of all symbols
local symbols_set = {}
@@ -22,7 +22,9 @@ function utility.get_all_symbols(logs)
local all_symbols = {}
for symbol, _ in pairs(symbols_set) do
- all_symbols[#all_symbols + 1] = symbol
+ if not ignore_symbols[symbol] then
+ all_symbols[#all_symbols + 1] = symbol
+ end
end
table.sort(all_symbols)
@@ -65,12 +67,14 @@ function utility.get_all_logs(dir_path)
return all_logs
end
-function utility.get_all_symbol_scores(conf)
+function utility.get_all_symbol_scores(conf, ignore_symbols)
local counters = conf:get_symbols_counters()
return fun.tomap(fun.map(function(elt)
return elt['symbol'],elt['weight']
- end, counters))
+ end, fun.filter(function(elt)
+ return not ignore_symbols[elt['symbol']]
+ end, counters)))
end
function utility.generate_statistics_from_logs(logs, threshold)