You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

CartTrainer.lua 6.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. local dt = require "decisiontree._env"
  2. local _ = require "moses"
  3. local CartTrainer = torch.class("dt.CartTrainer", dt)
  4. -- Generic CART trainer
  5. function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes)
  6. assert(torch.isTypeOf(dataset, 'dt.DataSet'))
  7. self.dataset = dataset
  8. self.minLeafSize = assert(minLeafSize) -- min examples per leaf
  9. self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree
  10. -- by default, single thread
  11. self.parallelMode = 'singlethread'
  12. end
  13. function CartTrainer:train(rootTreeState, activeFeatures)
  14. assert(torch.isTypeOf(rootTreeState, 'dt.TreeState'))
  15. assert(torch.isTensor(activeFeatures))
  16. local root = dt.CartNode()
  17. root.id = 0
  18. root.score = rootTreeState:score(self.dataset)
  19. local nleaf = 1
  20. -- TODO : nodeparallel: parallelize here. The queue is a workqueue.
  21. local queue = {}
  22. table.insert(queue, 1, {cartNode=root, treeState=rootTreeState})
  23. while #queue > 0 and nleaf < self.maxLeafNodes do
  24. local treeGrowerArgs = table.remove(queue, #queue)
  25. local currentTreeState = treeGrowerArgs.treeState
  26. -- Note: if minLeafSize = 1 and maxLeafNode = inf, then each example will be its own leaf...
  27. if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then
  28. nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures)
  29. end
  30. end
  31. -- CartTree with random branching (when feature is missing)
  32. local branchleft = function() return math.random() < 0.5 end
  33. return dt.CartTree(root, branchleft), nleaf
  34. end
  35. function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures)
  36. local bestSplit
  37. if self.parallelMode == 'singlethread' then
  38. bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures)
  39. elseif self.parallelMode == 'featureparallel' then
  40. bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
  41. else
  42. error("Unrecognized parallel mode: " .. self.parallelMode)
  43. end
  44. if bestSplit then
  45. local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset)
  46. assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!")
  47. self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf)
  48. table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState})
  49. table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState})
  50. return nleaf + 1
  51. end
  52. return nleaf
  53. end
  54. function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures)
  55. local timer = torch.Timer()
  56. local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1)
  57. if bestSplit then
  58. assert(torch.type(bestSplit) == 'table')
  59. end
  60. if dt.PROFILE then
  61. print("findBestSplitForAllFeatures time="..timer:time().real)
  62. end
  63. return bestSplit
  64. end
  65. -- Updates the parentNode with the bestSplit information by creates left/right child Nodes.
  66. function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf)
  67. assert(torch.isTypeOf(parentNode, 'dt.CartNode'))
  68. assert(torch.type(bestSplit) == 'table')
  69. assert(torch.isTypeOf(leftState, 'dt.TreeState'))
  70. assert(torch.isTypeOf(rightState, 'dt.TreeState'))
  71. assert(torch.type(nleaf) == 'number')
  72. local leftChild = dt.CartNode()
  73. leftChild.score = leftState:score(self.dataset)
  74. leftChild.nodeId = 2 * nleaf - 1
  75. local rightChild = dt.CartNode()
  76. rightChild.score = rightState:score(self.dataset)
  77. rightChild.nodeId = 2 * nleaf
  78. parentNode.splitFeatureId = bestSplit.splitId
  79. parentNode.splitFeatureValue = bestSplit.splitValue
  80. parentNode.leftChild = leftChild
  81. parentNode.rightChild = rightChild
  82. parentNode.splitGain = bestSplit.splitGain
  83. end
  84. -- We minimally need 2 * N examples in the parent to satisfy >= N examples per child
  85. function CartTrainer:hasEnoughTrainingExamplesToSplit(count)
  86. return count >= 2 * self.minLeafSize
  87. end
  88. -- call before training to enable feature-parallelization
  89. function CartTrainer:featureParallel(workPool)
  90. assert(self.parallelMode == 'singlethread', self.parallelMode)
  91. self.parallelMode = 'featureparallel'
  92. self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool
  93. assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))
  94. -- this deletes all SparseTensor hash maps so that they aren't serialized
  95. self.dataset:deleteIndex()
  96. -- require the dt package
  97. self.workPool:update('require', {libname='decisiontree',varname='dt'})
  98. -- setup worker store (each worker will have its own copy)
  99. local store = {
  100. dataset=self.dataset,
  101. minLeafSize=self.minLeafSize
  102. }
  103. self.workPool:update('storeKeysValues', store)
  104. end
  105. -- feature parallel
  106. function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
  107. local timer = torch.Timer()
  108. local bestSplit
  109. if treeState.findBestSplitFP then
  110. bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread)
  111. end
  112. if not bestSplit then
  113. for i=1,self.workPool.nThread do
  114. -- upvalues
  115. local treeState = treeState
  116. local shardId = i
  117. local nShard = self.workPool.nThread
  118. local featureIds = activeFeatures
  119. -- closure
  120. local task = function(store)
  121. assert(store.dataset)
  122. assert(store.minLeafSize)
  123. if treeState.threadInitialize then
  124. treeState:threadInitialize()
  125. end
  126. local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard)
  127. return bestSplit
  128. end
  129. self.workPool:writeup('execute', task)
  130. end
  131. for i=1,self.workPool.nThread do
  132. local taskname, candidateSplit = self.workPool:read()
  133. assert(taskname == 'execute')
  134. if candidateSplit then
  135. if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then
  136. bestSplit = candidateSplit
  137. end
  138. end
  139. end
  140. end
  141. if bestSplit then
  142. assert(torch.type(bestSplit) == 'table')
  143. end
  144. if dt.PROFILE then
  145. print("findBestSplitForAllFeaturesFP time="..timer:time().real)
  146. end
  147. return bestSplit
  148. end