From 03a5515ef88bbb0bcc26580ad94424c922af6ed7 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 3 Mar 2018 13:36:07 +0000 Subject: [PATCH] [Feature] Add more features to rescore utility - Allow to ignore specific symbols - Allow to specify learning rates and weight penalty manually --- lualib/rspamadm/rescore.lua | 82 ++++++++++++++++++++++++----- lualib/rspamadm/rescore_utility.lua | 12 +++-- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua index d76dc3861..6dcecc44b 100644 --- a/lualib/rspamadm/rescore.lua +++ b/lualib/rspamadm/rescore.lua @@ -8,6 +8,10 @@ local getopt = require "rspamadm/getopt" local rescore_utility = require "rspamadm/rescore_utility" local opts +local ignore_symbols = { + ['DATE_IN_PAST'] =true, + ['DATE_IN_FUTURE'] = true, +} local function make_dataset_from_logs(logs, all_symbols) -- Returns a list of {input, output} for torch SGD train @@ -28,7 +32,9 @@ local function make_dataset_from_logs(logs, all_symbols) local symbols_set = {} for i=4,#log do - symbols_set[log[i]] = true + if not ignore_symbols[log[i]] then + symbols_set[log[i]] = true + end end for index, symbol in pairs(all_symbols) do @@ -209,6 +215,13 @@ local default_opts = { threads = 1, } +local learning_rates = { + 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10 +} +local penalty_weights = { + 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100 +} + local function override_defaults(def, override) for k,v in pairs(override) do if def[k] then @@ -235,11 +248,63 @@ end return function (args, cfg) opts = default_opts - override_defaults(opts, getopt.getopt(args, '')) + override_defaults(opts, getopt.getopt(args, 'i:')) local threshold = get_threshold() local logs = rescore_utility.get_all_logs(cfg["logdir"]) - local all_symbols = rescore_utility.get_all_symbols(logs) - local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config) + + if opts['ignore-symbol'] then + local function add_ignore(s) + ignore_symbols[s] = true + end + if type(opts['ignore-symbol']) == 'table' then + for _,s in ipairs(opts['ignore-symbol']) do + add_ignore(s) + end + else + add_ignore(opts['ignore-symbol']) + end + end + + if opts['learning-rate'] then + learning_rates = {} + + local function add_rate(r) + if tonumber(r) then + table.insert(learning_rates, tonumber(r)) + end + end + if type(opts['learning-rate']) == 'table' then + for _,s in ipairs(opts['learning-rate']) do + add_rate(s) + end + else + add_rate(opts['learning-rate']) + end + end + + if opts['penalty-weight'] then + penalty_weights = {} + + local function add_weight(r) + if tonumber(r) then + table.insert(penalty_weights, tonumber(r)) + end + end + if type(opts['penalty-weight']) == 'table' then + for _,s in ipairs(opts['penalty-weight']) do + add_weight(s) + end + else + add_weight(opts['penalty-weight']) + end + end + + if opts['i'] then opts['iters'] = opts['i'] end + logger.errx('%s', opts) + + local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols) + local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config, + ignore_symbols) shuffle(logs) torch.setdefaulttensortype('torch.FloatTensor') @@ -249,15 +314,8 @@ return function (args, cfg) local dataset = make_dataset_from_logs(train_logs, all_symbols) - local learning_rates = { - 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10 - } - local penalty_weights = { - 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100 - } -- Start of perceptron training - local input_size = #all_symbols torch.setnumthreads(opts['threads']) local linear_module = nn.Linear(input_size, 1) @@ -276,7 +334,7 @@ return function (args, cfg) local best_weights = linear_module.weight[1]:clone() local trainer = nn.StochasticGradient(perceptron, criterion) - trainer.maxIteration = opts["iters"] + trainer.maxIteration = tonumber(opts["iters"]) trainer.verbose = opts['verbose'] trainer.hookIteration = function(self, iteration, error) diff --git a/lualib/rspamadm/rescore_utility.lua b/lualib/rspamadm/rescore_utility.lua index db79bbf7b..7f3f40078 100644 --- a/lualib/rspamadm/rescore_utility.lua +++ b/lualib/rspamadm/rescore_utility.lua @@ -4,7 +4,7 @@ local fun = require "fun" local utility = {} -function utility.get_all_symbols(logs) +function utility.get_all_symbols(logs, ignore_symbols) -- Returns a list of all symbols local symbols_set = {} @@ -22,7 +22,9 @@ function utility.get_all_symbols(logs) local all_symbols = {} for symbol, _ in pairs(symbols_set) do - all_symbols[#all_symbols + 1] = symbol + if not ignore_symbols[symbol] then + all_symbols[#all_symbols + 1] = symbol + end end table.sort(all_symbols) @@ -65,12 +67,14 @@ function utility.get_all_logs(dir_path) return all_logs end -function utility.get_all_symbol_scores(conf) +function utility.get_all_symbol_scores(conf, ignore_symbols) local counters = conf:get_symbols_counters() return fun.tomap(fun.map(function(elt) return elt['symbol'],elt['weight'] - end, counters)) + end, fun.filter(function(elt) + return not ignore_symbols[elt['symbol']] + end, counters))) end function utility.generate_statistics_from_logs(logs, threshold) -- 2.39.5