summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--contrib/torch/decisiontree/DataSet.lua28
-rw-r--r--contrib/torch/decisiontree/GradientBoostTrainer.lua2
-rw-r--r--contrib/torch/decisiontree/init.lua12
3 files changed, 9 insertions, 33 deletions
diff --git a/contrib/torch/decisiontree/DataSet.lua b/contrib/torch/decisiontree/DataSet.lua
index 505ec86d2..15058a7c6 100644
--- a/contrib/torch/decisiontree/DataSet.lua
+++ b/contrib/torch/decisiontree/DataSet.lua
@@ -1,5 +1,4 @@
local dt = require "decisiontree._env"
-local ipc = require 'libipc'
local DataSet = torch.class("dt.DataSet", dt)
@@ -77,31 +76,10 @@ function DataSet:sortFeatureValues(inputs)
assert(torch.isTensor(inputs))
featureIds:range(1,inputs:size(2))
- local wq = ipc.workqueue()
for i=1,inputs:size(2) do
- wq:write({i, inputs:select(2, i)})
- end
- for i=1,self.nThreads do
- wq:write(nil)
- end
-
- ipc.map(self.nThreads, function(wq)
- while true do
- local data = wq:read()
- if data == nil then break end
- local featureId = data[1]
- local values = data[2]
- local sortFeatureValues, sortExampleIds = values:sort(1, false)
- sortFeatureValues = nil
- wq:write({featureId, sortExampleIds})
- collectgarbage()
- end
- end, wq)
-
- for _=1,inputs:size(2) do
- local data = wq:read()
- local featureId = data[1]
- local sortedFeatureExampleIds = data[2]
+ local featureId = i
+ local values = inputs:select(2, i)
+ local _, sortedFeatureExampleIds = values:sort(1, false)
dataset[featureId] = sortedFeatureExampleIds
end
end
diff --git a/contrib/torch/decisiontree/GradientBoostTrainer.lua b/contrib/torch/decisiontree/GradientBoostTrainer.lua
index 54d1ba8ac..51299b109 100644
--- a/contrib/torch/decisiontree/GradientBoostTrainer.lua
+++ b/contrib/torch/decisiontree/GradientBoostTrainer.lua
@@ -177,7 +177,7 @@ function GradientBoostTrainer:train(trainSet, featureIds, validSet, verbose)
timer:reset()
local stop, validLoss, bestDecisionForest = self:validate(trainSet, validSet, decisionForest, bestDecisionForest)
if dt.PROFILE then print("validate tree time: "..timer:time().real) end
- if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", trainLoss, validLoss)) end
+ if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", 0, validLoss)) end
if stop then
if verbose then print(string.format("GBDT early stopped on tree %d", treeId)) end
break
diff --git a/contrib/torch/decisiontree/init.lua b/contrib/torch/decisiontree/init.lua
index fdaf1b56a..26f790b60 100644
--- a/contrib/torch/decisiontree/init.lua
+++ b/contrib/torch/decisiontree/init.lua
@@ -1,10 +1,8 @@
require 'paths'
-require 'xlua'
+--require 'xlua'
require 'string'
require 'os'
-require 'sys'
-require 'image'
-require 'lfs'
+--require 'sys'
require 'nn'
-- these actually return local variables but we will re-require them
@@ -31,7 +29,7 @@ require 'decisiontree.math'
require 'decisiontree.utils'
-- for multi-threading
-require 'decisiontree.WorkPool'
+--require 'decisiontree.WorkPool'
-- abstract classes
require 'decisiontree.DecisionTree'
@@ -62,8 +60,8 @@ require 'decisiontree.GradientBoostTrainer'
require 'decisiontree.GradientBoostState' -- TreeState subclass
-- unit tests and benchmarks
-require 'decisiontree.test'
-require 'decisiontree.benchmark'
+--require 'decisiontree.test'
+--require 'decisiontree.benchmark'
-- nn.Module
require 'decisiontree.DFD'