]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rewrite model and learning logic for rescore
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 6 Mar 2018 17:03:49 +0000 (17:03 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 6 Mar 2018 17:03:49 +0000 (17:03 +0000)
- Add more optimization methods
- Implement l1/l2 regulation
- Improve usability

lualib/rspamadm/rescore.lua

index 5ef1c4267bca4024566fd10c58f1095f642ca811..448c362570b64f92222c439c1cc3c8d794054e2f 100644 (file)
@@ -4,6 +4,8 @@ local lua_util = require "lua_util"
 local ucl = require "ucl"
 local logger = require "rspamd_logger"
 local getopt = require "rspamadm/getopt"
+local optim = require "optim"
+local rspamd_util = require "rspamd_util"
 
 local rescore_utility = require "rspamadm/rescore_utility"
 
@@ -60,13 +62,10 @@ local function init_weights(all_symbols, original_symbol_scores)
 
   local weights = torch.Tensor(#all_symbols)
 
-  local mean = 0
-
   for i, symbol in pairs(all_symbols) do
     local score = original_symbol_scores[symbol]
     if not score then score = 0 end
     weights[i] = score
-    mean = mean + score
   end
 
   return weights
@@ -209,17 +208,175 @@ Overall accuracy: %.2f %%
 
 end
 
+-- training function
+local function train(dataset, opt, model, criterion, epoch,
+                     all_symbols)
+  -- epoch tracker
+  epoch = epoch or 1
+
+  -- local vars
+  local time = rspamd_util.get_ticks()
+  local confusion = optim.ConfusionMatrix({'ham', 'spam'})
+
+  -- do one epoch
+
+  local lbfgsState
+  local sgdState
+
+  local batch_size = opt.batch_size
+
+  logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
+
+  for t = 1,dataset:size(),batch_size do
+    -- create mini batch
+    local k = 1
+    local last = math.min(t + batch_size - 1, dataset:size())
+    local inputs = torch.Tensor(last - t + 1, #all_symbols)
+    local targets = torch.Tensor(last - t + 1)
+    for i = t,last do
+      -- load new sample
+      local sample = dataset[i]
+      local input = sample[1]:clone()
+      local target = sample[2]:clone()
+      --target = target:squeeze()
+      inputs[k] = input
+      targets[k] = target
+      k = k + 1
+    end
+
+    local parameters,gradParameters = model:getParameters()
+
+    -- create closure to evaluate f(X) and df/dX
+    local feval = function(x)
+      -- just in case:
+      collectgarbage()
+
+      -- get new parameters
+      if x ~= parameters then
+        parameters:copy(x)
+      end
+
+      -- reset gradients
+      gradParameters:zero()
+
+      -- evaluate function for complete mini batch
+      local outputs = model:forward(inputs)
+      local f = criterion:forward(outputs, targets)
+
+      -- estimate df/dW
+      local df_do = criterion:backward(outputs, targets)
+      model:backward(inputs, df_do)
+
+      -- penalties (L1 and L2):
+      local l1 = tonumber(opt.l1) or 0
+      local l2 = tonumber(opt.l1) or 0
+      if l1 ~= 0 or l2 ~= 0 then
+        -- locals:
+        local norm,sign= torch.norm,torch.sign
+
+        -- Loss:
+        f = f + l1 * norm(parameters,1)
+        f = f + l2 * norm(parameters,2)^2/2
+
+        -- Gradients:
+        gradParameters:add( sign(parameters):mul(l1) + parameters:clone():mul(l2) )
+      end
+
+      -- update confusion
+      for i = 1,(last - t + 1) do
+        local class_predicted = 0
+        if outputs[i][1] > 0.5 then class_predicted = 1 end
+        confusion:add(class_predicted + 1, targets[i] + 1)
+      end
+
+      -- return f and df/dX
+      return f,gradParameters
+    end
+
+    -- optimize on current mini-batch
+    if opt.optimization == 'LBFGS' then
+
+      -- Perform LBFGS step:
+      lbfgsState = lbfgsState or {
+        maxIter = opt.iters,
+        lineSearch = optim.lswolfe
+      }
+      optim.lbfgs(feval, parameters, lbfgsState)
+
+      -- disp report:
+      logger.messagex('LBFGS step')
+      logger.messagex(' - progress in batch: ' .. t .. '/' .. dataset:size())
+      logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
+      logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
+
+    elseif opt.optimization == 'ADAM' then
+      sgdState = sgdState or {
+        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
+        momentum = tonumber(opts.momentum), -- opt.momentum,
+        learningRateDecay = tonumber(opts.learning_rate_decay),
+        weightDecay = tonumber(opts.weight_decay),
+      }
+      optim.adam(feval, parameters, sgdState)
+    elseif opt.optimization == 'ADAGRAD' then
+      sgdState = sgdState or {
+        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
+        momentum = tonumber(opts.momentum), -- opt.momentum,
+        learningRateDecay = tonumber(opts.learning_rate_decay),
+        weightDecay = tonumber(opts.weight_decay),
+      }
+      optim.adagrad(feval, parameters, sgdState)
+    elseif opt.optimization == 'SGD' then
+      sgdState = sgdState or {
+        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
+        momentum = tonumber(opts.momentum), -- opt.momentum,
+        learningRateDecay = tonumber(opts.learning_rate_decay),
+        weightDecay = tonumber(opts.weight_decay),
+      }
+      optim.sgd(feval, parameters, sgdState)
+    elseif opt.optimization == 'NAG' then
+      sgdState = sgdState or {
+        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
+        momentum = tonumber(opts.momentum), -- opt.momentum,
+        learningRateDecay = tonumber(opts.learning_rate_decay),
+        weightDecay = tonumber(opts.weight_decay),
+      }
+      optim.nag(feval, parameters, sgdState)
+    else
+      error('unknown optimization method')
+    end
+  end
+
+  -- time taken
+  time = rspamd_util.get_ticks() - time
+  time = time / dataset:size()
+  logger.messagex("time to learn 1 sample = " .. (time*1000) .. 'ms')
+
+  -- logger.messagex confusion matrix
+  logger.messagex('confusion: %s', tostring(confusion))
+  logger.messagex('%s mean class accuracy (train set)', confusion.totalValid * 100)
+  confusion:zero()
+
+  epoch = epoch + 1
+end
+
+
 local default_opts = {
   verbose = true,
   iters = 10,
   threads = 1,
+  batch_size = 1000,
+  optimization = 'ADAM',
+  learning_rate_decay = 0.001,
+  momentum = 0.0,
+  l1 = 0.0,
+  l2 = 0.0,
 }
 
 local learning_rates = {
-  0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2.5, 5, 7.5, 10
+  0.01, 0.05, 0.1
 }
 local penalty_weights = {
-  0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5, 10, 15, 20, 25, 50, 75, 100
+  0, 0.001, 0.01, 0.1, 0.5
 }
 
 local function override_defaults(def, override)
@@ -317,13 +474,12 @@ return function (args, cfg)
   -- Start of perceptron training
   local input_size = #all_symbols
   torch.setnumthreads(opts['threads'])
+
   local linear_module = nn.Linear(input_size, 1)
+  local activation = nn.Tanh()
 
   local perceptron = nn.Sequential()
   perceptron:add(linear_module)
-
-  local activation = nn.Sigmoid()
-
   perceptron:add(activation)
 
   local criterion = nn.MSECriterion()
@@ -331,13 +487,17 @@ return function (args, cfg)
 
   local best_fscore = -math.huge
   local best_weights = linear_module.weight[1]:clone()
+  local best_learning_rate
+  local best_weight_decay
 
-  local trainer = nn.StochasticGradient(perceptron, criterion)
-  trainer.maxIteration = tonumber(opts["iters"])
-  trainer.verbose = opts['verbose']
-  trainer.hookIteration = function(self, iteration, error)
-
-    if iteration == trainer.maxIteration then
+  for _,lr in ipairs(learning_rates) do
+    for _,wd in ipairs(penalty_weights) do
+      linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
+      opts.learning_rate = lr
+      opts.weight_decay = wd
+      for i=1,tonumber(opts.iters) do
+        train(dataset, opts, perceptron, criterion, i, all_symbols)
+      end
 
       local fscore = calculate_fscore_from_weights(cv_logs,
           all_symbols,
@@ -345,29 +505,18 @@ return function (args, cfg)
           linear_module.bias[1],
           threshold)
 
-      logger.messagex("Cross-validation fscore: %s", fscore)
+      logger.messagex("Cross-validation fscore=%s, learning rate=%s, weight decay=%s",
+          fscore, lr, wd)
 
       if best_fscore < fscore then
+        best_learning_rate = lr
+        best_weight_decay = wd
         best_fscore = fscore
         best_weights = linear_module.weight[1]:clone()
       end
     end
   end
 
-  for _, learning_rate in ipairs(learning_rates) do
-    for _, weight in ipairs(penalty_weights) do
-
-      trainer.weightDecay = weight
-      logger.messagex("Learning with learning_rate: %s, l2_weight: %s",
-          learning_rate, weight)
-
-      linear_module.weight[1] = init_weights(all_symbols, original_symbol_scores)
-
-      trainer.learningRate = learning_rate
-      trainer:train(dataset)
-    end
-  end
-
   -- End perceptron training
 
   local new_symbol_scores = best_weights
@@ -392,4 +541,7 @@ return function (args, cfg)
   test_logs = update_logs(test_logs, new_symbol_scores)
   logger.message("\n\nPost-rescore test stats\n")
   print_stats(test_logs, threshold)
+
+  logger.messagex('Best fscore=%s, best learning rate=%s, best weight decay=%s',
+      best_fscore, best_learning_rate, best_weight_decay)
 end
\ No newline at end of file