123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- local dt = require "decisiontree._env"
- local _ = require "moses"
-
- local CartTrainer = torch.class("dt.CartTrainer", dt)
-
- -- Generic CART trainer
- function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes)
- assert(torch.isTypeOf(dataset, 'dt.DataSet'))
- self.dataset = dataset
- self.minLeafSize = assert(minLeafSize) -- min examples per leaf
- self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree
-
- -- by default, single thread
- self.parallelMode = 'singlethread'
- end
-
- function CartTrainer:train(rootTreeState, activeFeatures)
- assert(torch.isTypeOf(rootTreeState, 'dt.TreeState'))
- assert(torch.isTensor(activeFeatures))
- local root = dt.CartNode()
- root.id = 0
- root.score = rootTreeState:score(self.dataset)
-
- local nleaf = 1
-
- -- TODO : nodeparallel: parallelize here. The queue is a workqueue.
- local queue = {}
- table.insert(queue, 1, {cartNode=root, treeState=rootTreeState})
-
- while #queue > 0 and nleaf < self.maxLeafNodes do
- local treeGrowerArgs = table.remove(queue, #queue)
- local currentTreeState = treeGrowerArgs.treeState
-
- -- Note: if minLeafSize = 1 and maxLeafNode = inf, then each example will be its own leaf...
- if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then
- nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures)
- end
- end
-
- -- CartTree with random branching (when feature is missing)
- local branchleft = function() return math.random() < 0.5 end
- return dt.CartTree(root, branchleft), nleaf
- end
-
- function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures)
- local bestSplit
- if self.parallelMode == 'singlethread' then
- bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures)
- elseif self.parallelMode == 'featureparallel' then
- bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
- else
- error("Unrecognized parallel mode: " .. self.parallelMode)
- end
-
- if bestSplit then
- local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset)
- assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!")
- self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf)
-
- table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState})
- table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState})
-
- return nleaf + 1
- end
-
- return nleaf
- end
-
- function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures)
- local timer = torch.Timer()
- local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1)
-
- if bestSplit then
- assert(torch.type(bestSplit) == 'table')
- end
-
- if dt.PROFILE then
- print("findBestSplitForAllFeatures time="..timer:time().real)
- end
- return bestSplit
- end
-
- -- Updates the parentNode with the bestSplit information by creates left/right child Nodes.
- function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf)
- assert(torch.isTypeOf(parentNode, 'dt.CartNode'))
- assert(torch.type(bestSplit) == 'table')
- assert(torch.isTypeOf(leftState, 'dt.TreeState'))
- assert(torch.isTypeOf(rightState, 'dt.TreeState'))
- assert(torch.type(nleaf) == 'number')
-
- local leftChild = dt.CartNode()
- leftChild.score = leftState:score(self.dataset)
- leftChild.nodeId = 2 * nleaf - 1
-
- local rightChild = dt.CartNode()
- rightChild.score = rightState:score(self.dataset)
- rightChild.nodeId = 2 * nleaf
-
- parentNode.splitFeatureId = bestSplit.splitId
- parentNode.splitFeatureValue = bestSplit.splitValue
- parentNode.leftChild = leftChild
- parentNode.rightChild = rightChild
- parentNode.splitGain = bestSplit.splitGain
- end
-
- -- We minimally need 2 * N examples in the parent to satisfy >= N examples per child
- function CartTrainer:hasEnoughTrainingExamplesToSplit(count)
- return count >= 2 * self.minLeafSize
- end
-
- -- call before training to enable feature-parallelization
- function CartTrainer:featureParallel(workPool)
- assert(self.parallelMode == 'singlethread', self.parallelMode)
- self.parallelMode = 'featureparallel'
- self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool
- assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))
-
- -- this deletes all SparseTensor hash maps so that they aren't serialized
- self.dataset:deleteIndex()
-
- -- require the dt package
- self.workPool:update('require', {libname='decisiontree',varname='dt'})
- -- setup worker store (each worker will have its own copy)
- local store = {
- dataset=self.dataset,
- minLeafSize=self.minLeafSize
- }
- self.workPool:update('storeKeysValues', store)
- end
-
- -- feature parallel
- function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
- local timer = torch.Timer()
- local bestSplit
- if treeState.findBestSplitFP then
- bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread)
- end
-
- if not bestSplit then
- for i=1,self.workPool.nThread do
- -- upvalues
- local treeState = treeState
- local shardId = i
- local nShard = self.workPool.nThread
- local featureIds = activeFeatures
- -- closure
- local task = function(store)
- assert(store.dataset)
- assert(store.minLeafSize)
- if treeState.threadInitialize then
- treeState:threadInitialize()
- end
-
- local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard)
- return bestSplit
- end
-
- self.workPool:writeup('execute', task)
- end
-
- for i=1,self.workPool.nThread do
- local taskname, candidateSplit = self.workPool:read()
- assert(taskname == 'execute')
- if candidateSplit then
- if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then
- bestSplit = candidateSplit
- end
- end
- end
- end
-
- if bestSplit then
- assert(torch.type(bestSplit) == 'table')
- end
-
- if dt.PROFILE then
- print("findBestSplitForAllFeaturesFP time="..timer:time().real)
- end
- return bestSplit
- end
|