local rescore_utility = require "rspamadm/rescore_utility"
local opts
+-- Symbols excluded from rescoring by default (extendable at runtime
+-- via the --ignore-symbol option handled in the command callback).
+local ignore_symbols = {
+  ['DATE_IN_PAST'] = true,
+  ['DATE_IN_FUTURE'] = true,
+}
local function make_dataset_from_logs(logs, all_symbols)
-- Returns a list of {input, output} for torch SGD train
local symbols_set = {}
for i=4,#log do
- symbols_set[log[i]] = true
+ if not ignore_symbols[log[i]] then
+ symbols_set[log[i]] = true
+ end
end
for index, symbol in pairs(all_symbols) do
threads = 1,
}
+-- Default hyperparameter search grids for perceptron training; both
+-- lists are replaced wholesale when the corresponding --learning-rate /
+-- --penalty-weight options are supplied.
+local learning_rates = {
+ 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+}
+local penalty_weights = {
+ 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+}
+
local function override_defaults(def, override)
for k,v in pairs(override) do
if def[k] then
return function (args, cfg)
opts = default_opts
- override_defaults(opts, getopt.getopt(args, ''))
+ override_defaults(opts, getopt.getopt(args, 'i:'))
local threshold = get_threshold()
local logs = rescore_utility.get_all_logs(cfg["logdir"])
- local all_symbols = rescore_utility.get_all_symbols(logs)
- local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config)
+
+ if opts['ignore-symbol'] then
+ local function add_ignore(s)
+ ignore_symbols[s] = true
+ end
+ if type(opts['ignore-symbol']) == 'table' then
+ for _,s in ipairs(opts['ignore-symbol']) do
+ add_ignore(s)
+ end
+ else
+ add_ignore(opts['ignore-symbol'])
+ end
+ end
+
+ if opts['learning-rate'] then
+ learning_rates = {}
+
+ local function add_rate(r)
+ if tonumber(r) then
+ table.insert(learning_rates, tonumber(r))
+ end
+ end
+ if type(opts['learning-rate']) == 'table' then
+ for _,s in ipairs(opts['learning-rate']) do
+ add_rate(s)
+ end
+ else
+ add_rate(opts['learning-rate'])
+ end
+ end
+
+ if opts['penalty-weight'] then
+ penalty_weights = {}
+
+ local function add_weight(r)
+ if tonumber(r) then
+ table.insert(penalty_weights, tonumber(r))
+ end
+ end
+ if type(opts['penalty-weight']) == 'table' then
+ for _,s in ipairs(opts['penalty-weight']) do
+ add_weight(s)
+ end
+ else
+ add_weight(opts['penalty-weight'])
+ end
+ end
+
+ if opts['i'] then opts['iters'] = opts['i'] end
+ logger.errx('%s', opts)
+
+ local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
+ local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
+ ignore_symbols)
shuffle(logs)
torch.setdefaulttensortype('torch.FloatTensor')
local dataset = make_dataset_from_logs(train_logs, all_symbols)
- local learning_rates = {
- 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
- }
- local penalty_weights = {
- 0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
- }
-- Start of perceptron training
-
local input_size = #all_symbols
torch.setnumthreads(opts['threads'])
local linear_module = nn.Linear(input_size, 1)
local best_weights = linear_module.weight[1]:clone()
local trainer = nn.StochasticGradient(perceptron, criterion)
- trainer.maxIteration = opts["iters"]
+ trainer.maxIteration = tonumber(opts["iters"])
trainer.verbose = opts['verbose']
trainer.hookIteration = function(self, iteration, error)