aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/RandomForestTrainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/decisiontree/RandomForestTrainer.lua')
-rw-r--r--contrib/lua-torch/decisiontree/RandomForestTrainer.lua159
1 files changed, 159 insertions, 0 deletions
diff --git a/contrib/lua-torch/decisiontree/RandomForestTrainer.lua b/contrib/lua-torch/decisiontree/RandomForestTrainer.lua
new file mode 100644
index 000000000..41040b25b
--- /dev/null
+++ b/contrib/lua-torch/decisiontree/RandomForestTrainer.lua
@@ -0,0 +1,159 @@
+local dt = require "decisiontree._env"
+
+-- Declare the class under the name "dt.RandomForestTrainer"; `dt` (the table
+-- returned by require "decisiontree._env") is passed as the second argument
+-- to torch.class. The returned table is what the methods below attach to.
+local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt)
+
+-- Constructor. `opt` is a table that must provide these numeric fields:
+--   nTree              : number of trees to train (> 0)
+--   maxLeafNodes       : max number of leaf nodes per tree (> 0)
+--   minLeafSize        : min number of examples per leaf (> 0)
+--   featureBaggingSize : number of features sampled per tree; when
+--                        non-positive, trainTree falls back to sqrt(#feature)
+--   activeRatio        : multiplied by the dataset size to get the number of
+--                        bootstrap examples per tree (> 0)
+-- Each field is validated with assert before it is stored on self.
+function RandomForestTrainer:__init(opt)
+   -- returns `value` after asserting it is a number (and, optionally, > 0)
+   local function checkedNumber(value, mustBePositive)
+      assert(torch.type(value) == 'number')
+      if mustBePositive then
+         assert(value > 0)
+      end
+      return value
+   end
+
+   self.nTree = checkedNumber(opt.nTree, true)
+   self.maxLeafNodes = checkedNumber(opt.maxLeafNodes, true)
+   self.minLeafSize = checkedNumber(opt.minLeafSize, true)
+
+   -- no sign check here: a non-positive value selects the sqrt default
+   self.featureBaggingSize = checkedNumber(opt.featureBaggingSize, false)
+
+   self.activeRatio = checkedNumber(opt.activeRatio, true)
+
+   -- training runs in a single thread unless treeParallel() is called
+   self.parallelMode = 'singlethread'
+end
+
+-- Train a DecisionForest.
+--   trainSet   : a dt.DataSet
+--   featureIds : torch.LongTensor of candidate feature ids
+--   verbose    : when truthy, progress messages are printed
+-- Returns dt.DecisionForest(trees, weight, bias), where every tree has weight
+-- 1/nTree and bias is the weighted sum of the trees' root scores.
+function RandomForestTrainer:train(trainSet, featureIds, verbose)
+   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
+   assert(torch.type(featureIds) == 'torch.LongTensor')
+
+   if verbose then
+      print(string.format("Begin training Decision Forest with %d trees", self.nTree))
+   end
+
+   -- a random forest weighs all of its trees uniformly
+   local weight = torch.Tensor(self.nTree):fill(1 / self.nTree)
+
+   -- dispatch on the configured parallelization mode
+   local mode = self.parallelMode
+   local trees
+   if mode == 'treeparallel' then
+      -- the index is dropped first to prevent a serialization bottleneck
+      trainSet:deleteIndex()
+      trees = self:trainTreesTP(trainSet, featureIds, verbose)
+   elseif mode == 'singlethread' then
+      trees = self:trainTrees(trainSet, featureIds, verbose)
+   else
+      error("Unrecognized parallel mode: " .. mode)
+   end
+
+   if verbose then
+      print(string.format("Successfully trained %d trees", #trees))
+   end
+
+   -- the forest bias accumulates each root score scaled by its tree weight
+   local bias = 0
+   for i = 1, #trees do
+      bias = bias + trees[i].root.score * weight[i]
+   end
+
+   return dt.DecisionForest(trees, weight, bias)
+end
+
+-- Single-threaded training: builds self.nTree trees one after another and
+-- returns them as a table, in training order.
+function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose)
+   -- one CartTrainer is shared by all trees
+   local sharedTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes)
+
+   local forest = {}
+   for treeId = 1, self.nTree do
+      -- trainTree is static, hence the '.' call with explicit arguments
+      forest[#forest + 1] = self.trainTree(
+         sharedTrainer, featureIds, self.featureBaggingSize, self.activeRatio, treeId, verbose
+      )
+   end
+   return forest
+end
+
+-- Static function that trains and returns a single cartTree.
+--   cartTrainer : a dt.CartTrainer; its `dataset` field supplies the examples
+--   featureIds  : torch.LongTensor of candidate feature ids
+--   baggingSize : number of features to sample; when not positive, the
+--                 rounded square root of the number of features is used
+--   activeRatio : multiplied by the dataset size (and rounded) to get the
+--                 number of bootstrap examples
+--   treeId      : tree number, only used in the verbose messages
+--   verbose     : when truthy, progress messages are printed
+-- Both samples are drawn with LongTensor:random over 1..n, so the same
+-- feature or example may be picked more than once.
+function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose)
+   assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer'))
+   assert(torch.type(featureIds) == 'torch.LongTensor')
+
+   local nFeature = featureIds:size(1)
+
+   -- resolve the effective number of features sampled for this tree
+   local nBagged = baggingSize
+   if not (nBagged > 0) then
+      nBagged = torch.round(math.sqrt(nFeature))
+   end
+
+   if verbose then
+      print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, nBagged, nFeature))
+   end
+
+   local dataset = cartTrainer.dataset
+
+   -- feature sample: random positions into featureIds, then gather the ids
+   local featurePicks = torch.LongTensor(nBagged):random(1, nFeature)
+   local activeFeatures = featureIds:index(1, featurePicks)
+
+   -- example sample: random positions into the dataset's example ids
+   local nExample = dataset:size()
+   local sampleSize = torch.round(nExample * activeRatio)
+   if verbose then
+      print(string.format("Creating bootstrap sample created of size %d", sampleSize))
+   end
+
+   local examplePicks = torch.LongTensor(sampleSize):random(1, nExample)
+   local bootStrapExampleIds = dataset:getExampleIds():index(1, examplePicks)
+
+   -- grow the tree from a GiniState over the sampled examples
+   local cartTree = cartTrainer:train(dt.GiniState(bootStrapExampleIds), activeFeatures)
+
+   if verbose then
+      print(string.format("Complete processing tree number %d", treeId))
+   end
+
+   return cartTree
+end
+
+-- Call before training to enable tree-level parallelization.
+--   workPool : either a dt.WorkPool, or a number that is passed to
+--              dt.WorkPool(...) to create one
+-- May only be called while the trainer is still in 'singlethread' mode.
+-- The pool is validated and configured before any field of the trainer is
+-- modified, so a failed call leaves the trainer in 'singlethread' mode and
+-- the call can be retried with a correct argument.
+function RandomForestTrainer:treeParallel(workPool)
+   assert(self.parallelMode == 'singlethread', self.parallelMode)
+
+   local pool = workPool
+   if torch.type(pool) == 'number' then
+      pool = dt.WorkPool(pool)
+   end
+   assert(torch.isTypeOf(pool, 'dt.WorkPool'),
+      "treeParallel expects a dt.WorkPool or a number")
+
+   -- require the dt package (sent to the pool as a 'require' update)
+   pool:update('require', {libname='decisiontree',varname='dt'})
+
+   -- commit the new state only once everything above has succeeded
+   self.workPool = pool
+   self.parallelMode = 'treeparallel'
+end
+
+-- Tree-parallel ("TP") training: one task per tree is handed to
+-- self.workPool and the resulting trees are collected into a table.
+-- Trees are appended in the order the pool returns them from read().
+-- Expects treeParallel() to have been called so that self.workPool is set.
+function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose)
+   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
+   assert(torch.type(featureIds) == 'torch.LongTensor')
+
+   local pool = self.workPool
+   local nTree = self.nTree
+
+   -- plain locals so the closures below capture values rather than self
+   local minLeafSize = self.minLeafSize
+   local maxLeafNodes = self.maxLeafNodes
+   local baggingSize = self.featureBaggingSize
+   local activeRatio = self.activeRatio
+
+   -- populate the store with a cartTrainer and the feature ids
+   -- (each worker will have its own cartTrainer)
+   local function setupStore(store)
+      local dt = require 'decisiontree'
+
+      store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
+      store.featureIds = featureIds
+   end
+   pool:updateup('execute', setupStore)
+
+   -- queue one training task per tree
+   for treeId = 1, nTree do
+      -- task closure that will be executed in a worker thread
+      local function trainTreeTask(store)
+         local dt = require 'decisiontree'
+         return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose)
+      end
+      pool:writeup('execute', trainTreeTask)
+   end
+
+   -- read back exactly as many results as tasks were queued
+   local trees = {}
+   for _ = 1, nTree do
+      local taskname, tree = pool:read()
+      assert(taskname=='execute')
+      assert(torch.isTypeOf(tree, 'dt.CartTree'))
+      trees[#trees + 1] = tree
+   end
+   return trees
+end
+
+-- Returns a string naming this trainer's configuration: activeRatio
+-- (formatted with %4.2f), maxLeafNodes, minLeafSize and nTree.
+function RandomForestTrainer:getName()
+   local template = "randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d"
+   return template:format(self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree)
+end
+