]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Add more features to rescore utility
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 3 Mar 2018 13:36:07 +0000 (13:36 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 3 Mar 2018 13:36:07 +0000 (13:36 +0000)
- Allow to ignore specific symbols
- Allow to specify learning rates and weight penalty manually

lualib/rspamadm/rescore.lua
lualib/rspamadm/rescore_utility.lua

index d76dc3861bf2014a15cf8bba7ebae8d9bfdd4919..6dcecc44b2e2eda1d38f523cdcff353353644dbe 100644 (file)
@@ -8,6 +8,10 @@ local getopt = require "rspamadm/getopt"
 local rescore_utility = require "rspamadm/rescore_utility"
 
 local opts
+local ignore_symbols = {
+  ['DATE_IN_PAST'] =true,
+  ['DATE_IN_FUTURE'] = true,
+}
 
 local function make_dataset_from_logs(logs, all_symbols)
   -- Returns a list of {input, output} for torch SGD train
@@ -28,7 +32,9 @@ local function make_dataset_from_logs(logs, all_symbols)
     local symbols_set = {}
 
     for i=4,#log do
-      symbols_set[log[i]] = true
+      if not ignore_symbols[log[i]] then
+        symbols_set[log[i]] = true
+      end
     end
 
     for index, symbol in pairs(all_symbols) do
@@ -209,6 +215,13 @@ local default_opts = {
   threads = 1,
 }
 
+local learning_rates = {
+  0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+}
+local penalty_weights = {
+  0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+}
+
 local function override_defaults(def, override)
   for k,v in pairs(override) do
     if def[k] then
@@ -235,11 +248,63 @@ end
 
 return function (args, cfg)
   opts = default_opts
-  override_defaults(opts, getopt.getopt(args, ''))
+  override_defaults(opts, getopt.getopt(args, 'i:'))
   local threshold = get_threshold()
   local logs = rescore_utility.get_all_logs(cfg["logdir"])
-  local all_symbols = rescore_utility.get_all_symbols(logs)
-  local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config)
+
+  if opts['ignore-symbol'] then
+    local function add_ignore(s)
+      ignore_symbols[s] = true
+    end
+    if type(opts['ignore-symbol']) == 'table' then
+      for _,s in ipairs(opts['ignore-symbol']) do
+        add_ignore(s)
+      end
+    else
+      add_ignore(opts['ignore-symbol'])
+    end
+  end
+
+  if opts['learning-rate'] then
+    learning_rates = {}
+
+    local function add_rate(r)
+      if tonumber(r) then
+        table.insert(learning_rates, tonumber(r))
+      end
+    end
+    if type(opts['learning-rate']) == 'table' then
+      for _,s in ipairs(opts['learning-rate']) do
+        add_rate(s)
+      end
+    else
+      add_rate(opts['learning-rate'])
+    end
+  end
+
+  if opts['penalty-weight'] then
+    penalty_weights = {}
+
+    local function add_weight(r)
+      if tonumber(r) then
+        table.insert(penalty_weights, tonumber(r))
+      end
+    end
+    if type(opts['penalty-weight']) == 'table' then
+      for _,s in ipairs(opts['penalty-weight']) do
+        add_weight(s)
+      end
+    else
+      add_weight(opts['penalty-weight'])
+    end
+  end
+
+  if opts['i'] then opts['iters'] = opts['i'] end
+  logger.errx('%s', opts)
+
+  local all_symbols = rescore_utility.get_all_symbols(logs, ignore_symbols)
+  local original_symbol_scores = rescore_utility.get_all_symbol_scores(rspamd_config,
+      ignore_symbols)
 
   shuffle(logs)
   torch.setdefaulttensortype('torch.FloatTensor')
@@ -249,15 +314,8 @@ return function (args, cfg)
 
   local dataset = make_dataset_from_logs(train_logs, all_symbols)
 
-  local learning_rates = {
-    0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
-  }
-  local penalty_weights = {
-    0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
-  }
 
   -- Start of perceptron training
-
   local input_size = #all_symbols
   torch.setnumthreads(opts['threads'])
   local linear_module = nn.Linear(input_size, 1)
@@ -276,7 +334,7 @@ return function (args, cfg)
   local best_weights = linear_module.weight[1]:clone()
 
   local trainer = nn.StochasticGradient(perceptron, criterion)
-  trainer.maxIteration = opts["iters"]
+  trainer.maxIteration = tonumber(opts["iters"])
   trainer.verbose = opts['verbose']
   trainer.hookIteration = function(self, iteration, error)
 
index db79bbf7b79d9e0d622bfe2d469658686d030e06..7f3f4007847350e01e83c3cad2509bc432713e56 100644 (file)
@@ -4,7 +4,7 @@ local fun = require "fun"
 
 local utility = {}
 
-function utility.get_all_symbols(logs)
+function utility.get_all_symbols(logs, ignore_symbols)
   -- Returns a list of all symbols
 
   local symbols_set = {}
@@ -22,7 +22,9 @@ function utility.get_all_symbols(logs)
   local all_symbols = {}
 
   for symbol, _ in pairs(symbols_set) do
-    all_symbols[#all_symbols + 1] = symbol
+    if not ignore_symbols[symbol] then
+      all_symbols[#all_symbols + 1] = symbol
+    end
   end
 
   table.sort(all_symbols)
@@ -65,12 +67,14 @@ function utility.get_all_logs(dir_path)
   return all_logs
 end
 
-function utility.get_all_symbol_scores(conf)
+function utility.get_all_symbol_scores(conf, ignore_symbols)
   local counters = conf:get_symbols_counters()
 
   return fun.tomap(fun.map(function(elt)
     return elt['symbol'],elt['weight']
-  end, counters))
+  end, fun.filter(function(elt)
+    return not ignore_symbols[elt['symbol']]
+  end, counters)))
 end
 
 function utility.generate_statistics_from_logs(logs, threshold)