]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix random forests module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 11 Nov 2017 16:32:54 +0000 (16:32 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 11 Nov 2017 16:32:54 +0000 (16:32 +0000)
contrib/torch/decisiontree/DataSet.lua
contrib/torch/decisiontree/GradientBoostTrainer.lua
contrib/torch/decisiontree/init.lua

index 505ec86d2cdaccdb933ca263e6cd9c4c2b8e70d0..15058a7c6a62e8697aab0ca6827abab09a6ffe79 100644 (file)
@@ -1,5 +1,4 @@
 local dt = require "decisiontree._env"
-local ipc = require 'libipc'
 
 local DataSet = torch.class("dt.DataSet", dt)
 
@@ -77,31 +76,10 @@ function DataSet:sortFeatureValues(inputs)
       assert(torch.isTensor(inputs))
       featureIds:range(1,inputs:size(2))
 
-      local wq = ipc.workqueue()
       for i=1,inputs:size(2) do
-         wq:write({i, inputs:select(2, i)})
-      end
-      for i=1,self.nThreads do
-         wq:write(nil)
-      end
-
-      ipc.map(self.nThreads, function(wq)
-         while true do
-            local data = wq:read()
-            if data == nil then break end
-            local featureId = data[1]
-            local values = data[2]
-            local sortFeatureValues, sortExampleIds = values:sort(1, false)
-            sortFeatureValues = nil
-            wq:write({featureId, sortExampleIds})
-            collectgarbage()
-         end
-      end, wq)
-
-      for _=1,inputs:size(2) do
-         local data = wq:read()
-         local featureId = data[1]
-         local sortedFeatureExampleIds = data[2]
+         local featureId = i
+         local values = inputs:select(2, i)
+         local _, sortedFeatureExampleIds = values:sort(1, false)
          dataset[featureId] = sortedFeatureExampleIds
       end
    end
index 54d1ba8ac8ee6b67f0a83686e1ddd6f9574058dc..51299b109e06736b43925c83cb69542fab06f5fc 100644 (file)
@@ -177,7 +177,7 @@ function GradientBoostTrainer:train(trainSet, featureIds, validSet, verbose)
          timer:reset()
          local stop, validLoss, bestDecisionForest = self:validate(trainSet, validSet, decisionForest, bestDecisionForest)
          if dt.PROFILE then print("validate tree time: "..timer:time().real) end
-         if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", trainLoss, validLoss)) end
+         if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", 0, validLoss)) end
          if stop then
             if verbose then print(string.format("GBDT early stopped on tree %d", treeId)) end
             break
index fdaf1b56a50ada398f4c6304d120a7b9227df3c8..26f790b605b4c5101712bd31798026f7d3efb36b 100644 (file)
@@ -1,10 +1,8 @@
 require 'paths'
-require 'xlua'
+--require 'xlua'
 require 'string'
 require 'os'
-require 'sys'
-require 'image'
-require 'lfs'
+--require 'sys'
 require 'nn'
 
 -- these actually return local variables but we will re-require them
@@ -31,7 +29,7 @@ require 'decisiontree.math'
 require 'decisiontree.utils'
 
 -- for multi-threading
-require 'decisiontree.WorkPool'
+--require 'decisiontree.WorkPool'
 
 -- abstract classes
 require 'decisiontree.DecisionTree'
@@ -62,8 +60,8 @@ require 'decisiontree.GradientBoostTrainer'
 require 'decisiontree.GradientBoostState' -- TreeState subclass
 
 -- unit tests and benchmarks
-require 'decisiontree.test'
-require 'decisiontree.benchmark'
+--require 'decisiontree.test'
+--require 'decisiontree.benchmark'
 
 -- nn.Module
 require 'decisiontree.DFD'