You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TreeState.lua 6.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. local dt = require "decisiontree._env"
  2. local TreeState = torch.class("dt.TreeState", dt)
  3. -- Holds the state of a subtree during decision tree training.
  4. -- Also, manages the state of candidate splits
  5. function TreeState:__init(exampleIds)
  6. assert(torch.type(exampleIds) == 'torch.LongTensor')
  7. self.exampleIds = exampleIds
  8. self.nExampleInLeftBranch = 0
  9. self.nExampleInRightBranch = 0
  10. end
  11. -- computes and returns the score of the node based on its examples
  12. function TreeState:score(dataset)
  13. error"NotImplemented"
  14. end
  15. -- Initializes the split-state-updater. Initially all examples are in the left branch.
  16. -- exampleIdsWithFeature is list of examples to split (those having a particular feature)
  17. function TreeState:initialize(exampleIdsWithFeature, dataset)
  18. error"NotImplemented"
  19. end
  20. -- Update the split state. This call has the effect of shifting the example from the left to the right branch.
  21. function TreeState:update(exampleId, dataset)
  22. error"NotImplemented"
  23. end
  24. -- Computes the SplitInfo determined by the current split state
  25. -- @param splitFeatureId the feature id of the split feature
  26. -- @param splitFeatureValue the feature value of the split feature
  27. -- @return the SplitInfo determined by the current split state
  28. function TreeState:computeSplitInfo(splitFeatureId, splitFeatureValue)
  29. error"NotImplemented"
  30. end
  31. -- bottleneck
  32. function TreeState:findBestFeatureSplit(dataset, featureId, minLeafSize)
  33. local dt = require "decisiontree"
  34. assert(torch.isTypeOf(dataset, 'dt.DataSet'))
  35. assert(torch.type(featureId) == 'number')
  36. assert(torch.type(minLeafSize) == 'number')
  37. -- all dataset example having this feature, sorted by value
  38. local featureExampleIds = dataset:getSortedFeature(featureId)
  39. local buffer = dt.getBufferTable('TreeState')
  40. buffer.longtensor = buffer.longtensor or torch.LongTensor()
  41. local exampleIdsWithFeature = buffer.longtensor
  42. -- map and tensor of examples containing feature:
  43. local exampleMap = {}
  44. local getExampleFeatureValue
  45. local j = 0
  46. if torch.type(dataset.input) == 'table' then
  47. exampleIdsWithFeature:resize(self.exampleIds:size())
  48. self.exampleIds:apply(function(exampleId)
  49. local input = dataset.input[exampleId]
  50. input:buildIndex()-- only builds index first time
  51. if input[featureId] then
  52. j = j + 1
  53. exampleIdsWithFeature[j] = exampleId
  54. exampleMap[exampleId] = j
  55. end
  56. end)
  57. if j == 0 then
  58. return
  59. end
  60. exampleIdsWithFeature:resize(j)
  61. getExampleFeatureValue = function(exampleId) return dataset.input[exampleId][featureId] end
  62. else
  63. exampleIdsWithFeature = self.exampleIds
  64. self.exampleIds:apply(function(exampleId)
  65. j = j + 1
  66. exampleMap[exampleId] = j
  67. end)
  68. local featureValues = dataset.input:select(2,featureId)
  69. getExampleFeatureValue = function(exampleId) return featureValues[exampleId] end
  70. end
  71. self:initialize(exampleIdsWithFeature, dataset)
  72. -- bottleneck
  73. local bestSplit, previousSplitValue, _tictoc
  74. for i=featureExampleIds:size(1),1,-1 do -- loop over examples sorted (desc) by feature value
  75. local exampleId = featureExampleIds[i]
  76. local exampleIdx = exampleMap[exampleId]
  77. if exampleIdx then
  78. local splitValue = getExampleFeatureValue(exampleId)
  79. if previousSplitValue and math.abs(splitValue - previousSplitValue) > dt.EPSILON then
  80. local splitInfo = self:computeSplitInfo(featureId, previousSplitValue, _tictoc)
  81. if (splitInfo.leftChildSize >= minLeafSize) and (splitInfo.rightChildSize >= minLeafSize) then
  82. if (not bestSplit) or (splitInfo.splitGain < bestSplit.splitGain) then
  83. _tictoc = bestSplit or {} -- reuse table
  84. bestSplit = splitInfo
  85. end
  86. end
  87. end
  88. previousSplitValue = splitValue
  89. -- bottleneck
  90. self:update(exampleId, dataset, exampleIdx)
  91. end
  92. end
  93. return bestSplit
  94. end
  95. -- finds the best split of examples in treeState among featureIds
  96. function TreeState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard)
  97. assert(torch.isTypeOf(dataset, 'dt.DataSet'))
  98. assert(torch.type(featureIds) == 'torch.LongTensor')
  99. assert(torch.type(minLeafSize) == 'number')
  100. assert(torch.type(shardId) == 'number')
  101. assert(torch.type(nShard) == 'number')
  102. local bestSplit
  103. for i=1,featureIds:size(1) do
  104. local featureId = featureIds[i]
  105. if (nShard <= 1) or ( (featureId % nShard) + 1 == shardId ) then -- feature sharded
  106. local splitCandidate = self:findBestFeatureSplit(dataset, featureId, minLeafSize)
  107. if splitCandidate and ((not bestSplit) or (splitCandidate.splitGain < bestSplit.splitGain)) then
  108. bestSplit = splitCandidate
  109. end
  110. end
  111. end
  112. return bestSplit
  113. end
  114. -- Partitions self given a splitInfo table, producing a pair of exampleIds corresponding to the left and right subtrees.
  115. function TreeState:_branch(splitInfo, dataset)
  116. local leftIdx, rightIdx = 0, 0
  117. local nExample = self.exampleIds:size(1)
  118. local splitExampleIds = torch.LongTensor(nExample)
  119. for i=1,self.exampleIds:size(1) do
  120. local exampleId = self.exampleIds[i]
  121. local input = dataset.input[exampleId]
  122. local val = input[splitInfo.splitId]
  123. -- Note: when the feature is not present in the example, the example is droped from all sub-trees.
  124. -- Which means that for most sparse data, a tree cannot reach 100% accuracy...
  125. if val then
  126. if val < splitInfo.splitValue then
  127. leftIdx = leftIdx + 1
  128. splitExampleIds[leftIdx] = exampleId
  129. else
  130. rightIdx = rightIdx + 1
  131. splitExampleIds[nExample-rightIdx+1] = exampleId
  132. end
  133. end
  134. end
  135. local leftExampleIds = splitExampleIds:narrow(1,1,leftIdx)
  136. local rightExampleIds = splitExampleIds:narrow(1,nExample-rightIdx+1,rightIdx)
  137. assert(leftExampleIds:size(1) + rightExampleIds:size(1) <= self.exampleIds:size(1), "Left and right branches contain more data than the parent!")
  138. return leftExampleIds, rightExampleIds
  139. end
  140. -- calls _branch and encapsulates the left and right exampleIds into a TreeStates
  141. function TreeState:branch(splitInfo, dataset)
  142. local leftExampleIds, rightExampleIds = self:_branch(splitInfo, dataset)
  143. return self.new(leftExampleIds), self.new(rightExampleIds)
  144. end
  145. function TreeState:size()
  146. return self.exampleIds:size(1)
  147. end
  148. function TreeState:contains(exampleId)
  149. local found = false
  150. self.exampleIds:apply(function(x)
  151. if x == exampleId then
  152. found = true
  153. end
  154. end)
  155. return found
  156. end