You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

DFD.lua 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  -- nn.DFD: Decision Forest Discretizer.
  -- Turns a dense input into a sparse output in which every node of the
  -- forest acts as its own feature:
  --   * a node that is traversed gets a feature value of 1;
  --   * a node that is not traversed has a feature value of 0.
  local DFD, parent = torch.class('nn.DFD', 'nn.Module')
  -- NOTE: DFD:type is overridden in this file so that the fields stored as
  -- torch.LongTensor are put back after the parent's type conversion.
  -- Constructor.
  --   df           : either a plain table of fields (copied onto the module by
  --                  reconstructFromInfo, e.g. what getReconstructionInfo returns)
  --                  or a 'dt.DecisionForest' that is flattened into tensors.
  --   onlyLastNode : stored on the module and forwarded to the compute kernel
  --                  by updateOutput.
  function DFD:__init(df, onlyLastNode)
     parent.__init(self)

     local fromInfoTable = (torch.type(df) == 'table')
     if not fromInfoTable then
        assert(torch.type(df) == 'dt.DecisionForest')

        -- node ids of the tree roots
        self.rootIds = torch.LongTensor()
        -- per node: id of the left / right child
        self.leftChild = torch.LongTensor()
        self.rightChild = torch.LongTensor()
        -- per node: which feature is split on, and at which value
        self.splitFeatureId = torch.LongTensor()
        self.splitFeatureValue = torch.Tensor()

        -- populate the tensors above from the forest, then reset buffers
        self:convertForest2Tensors(df)
        self:clearState()
     else
        self:reconstructFromInfo(df)
     end

     self.onlyLastNode = onlyLastNode
     -- one root id per tree
     self.nTrees = self.rootIds:size(1)
  end
  -- Flattens a DecisionForest (df.trees, each with a dt.CartNode root) into the
  -- tensor fields of this module: rootIds, leftChild, rightChild,
  -- splitFeatureId and splitFeatureValue. Also sets self.depth and self.nNode,
  -- and tags every visited node with a `_nodeId` field.
  function DFD:convertForest2Tensors(df)
     self.rootIds:resize(#df.trees)

     -- running counter; a node's id is also its feature id
     local lastId = 0

     -- Pass 1: number the nodes (parent first, then left subtree, then right
     -- subtree) and return the depth of the deepest node reachable from `node`.
     local function assignIds(node, parentLevel)
        local level = (parentLevel or 0) + 1
        local deepest = level

        lastId = lastId + 1
        node._nodeId = lastId

        if node.leftChild then
           deepest = math.max(deepest, assignIds(node.leftChild, level))
        end
        if node.rightChild then
           deepest = math.max(deepest, assignIds(node.rightChild, level))
        end
        return deepest
     end

     -- accumulate the maximum depth of each tree
     self.depth = 0
     for _, tree in ipairs(df.trees) do
        assert(torch.isTypeOf(tree.root, 'dt.CartNode'))
        self.depth = self.depth + assignIds(tree.root)
     end
     -- the root level of each tree is not counted
     self.depth = self.depth - self.rootIds:size(1)

     -- number of nodes across the whole forest
     self.nNode = lastId

     -- -1 marks "not written yet"; pass 2 must overwrite every child entry
     self.leftChild:resize(self.nNode):fill(-1)
     self.rightChild:resize(self.nNode):fill(-1)
     self.splitFeatureId:resize(self.nNode):fill(-1)
     self.splitFeatureValue:resize(self.nNode):fill(-1)

     -- Pass 2: copy each CartNode's attributes into the tensors.
     -- A missing child is recorded as 0.
     local function fillTensors(node)
        local id = assert(node._nodeId)
        -- every node must be visited exactly once
        assert(self.splitFeatureId[id] == -1)

        if not node.leftChild then
           self.leftChild[id] = 0
        else
           self.leftChild[id] = assert(node.leftChild._nodeId)
           fillTensors(node.leftChild)
        end

        if not node.rightChild then
           self.rightChild[id] = 0
        else
           self.rightChild[id] = assert(node.rightChild._nodeId)
           fillTensors(node.rightChild)
        end

        -- the (feature id, value) pair this node splits the data on
        self.splitFeatureId[id] = assert(node.splitFeatureId)
        self.splitFeatureValue[id] = assert(node.splitFeatureValue)
     end

     for treeIdx, tree in ipairs(df.trees) do
        self.rootIds[treeIdx] = assert(tree.root._nodeId)
        fillTensors(tree.root)
     end

     -- no -1 placeholder may be left in the child tables
     assert(self.leftChild:min() >= 0)
     assert(self.rightChild:min() >= 0)
  end
  -- Forward pass.
  --   input : 2D tensor, batchsize x inputsize (made contiguous before use).
  -- Returns (and stores in self.output) whatever input.nn.DFD_computeOutput
  -- produces from the key/value buffers and the forest tensors.
  function DFD:updateOutput(input)
     assert(torch.isTensor(input))
     assert(input:dim() == 2)
     input = input:contiguous()

     local batchsize = input:size(1)

     -- Width of the per-sample key/value buffers.
     -- The tree count is stored by __init as self.nTrees. Reading the
     -- never-assigned self.nTree made `a and nil or b` always evaluate to
     -- self.depth, so the onlyLastNode branch was silently ignored.
     -- NOTE(review): presumably the kernel emits one key per tree when
     -- onlyLastNode is set -- confirm against DFD_computeOutput.
     local size
     if self.onlyLastNode then
        size = self.nTrees or self.rootIds:size(1)
     else
        size = self.depth
     end

     -- each sample's output keys are resized to the maximum size they can take
     self.outputkeys = self.outputkeys or torch.LongTensor()
     self.outputkeys:resize(batchsize, size)

     -- values are all 1, in a tensor of the same type as the input
     self.outputvalues = self.outputvalues or input.new()
     self.outputvalues:resize(batchsize, size):fill(1)

     self.output = input.nn.DFD_computeOutput(
        self.outputkeys, self.outputvalues,
        self.rootIds, self.leftChild, self.rightChild,
        self.splitFeatureId, self.splitFeatureValue,
        input, self.onlyLastNode)
     return self.output
  end
  -- Getter/setter for the module's tensor type.
  --   With no `type` argument: returns parent.type(self).
  --   With a `type`: lets the parent convert the module, then puts back every
  --   reconstruction-info field that was a torch.LongTensor, and returns self.
  function DFD:type(type, tensorCache)
     if not type then
        return parent.type(self)
     end

     -- remember the LongTensor fields before the parent converts them
     local longFields = {}
     for name, value in pairs(self:getReconstructionInfo()) do
        if torch.type(value) == 'torch.LongTensor' then
           longFields[name] = value
        end
     end

     parent.type(self, type, tensorCache)
     -- restore the saved fields (this also re-runs the structural asserts)
     self:reconstructFromInfo(longFields)
     return self
  end
  -- Backward pass is not provided by this module; calling it always raises.
  function DFD:updateGradInput()
     error("Not Implemented")
  end
  -- Resets the output and drops the intermediate buffers so that they are
  -- re-created on the next forward pass.
  -- Returns self, matching nn.Module:clearState, so calls can be chained
  -- (previously nothing was returned).
  function DFD:clearState()
     self.output = {{},{}}
     self.taskbuffer = {}
     -- re-allocated lazily by updateOutput
     self.outputkeys = nil
     self.outputvalues = nil
     self._range = nil
     self._indices = nil
     self._mask = nil
     return self
  end
  -- Copies every key/value pair of DFDinfo onto the module, then checks that
  -- the forest tensors are structurally consistent.
  function DFD:reconstructFromInfo(DFDinfo)
     for field, value in pairs(DFDinfo) do
        self[field] = value
     end

     local left, right = self.leftChild, self.rightChild
     -- child tables: 1D, one entry per node, no negative ids
     assert(left:nDimension() == 1)
     assert(right:nDimension() == 1)
     assert(left:size(1) == self.nNode)
     assert(right:size(1) == self.nNode)
     assert(left:min() >= 0)
     assert(right:min() >= 0)

     local featId, featValue = self.splitFeatureId, self.splitFeatureValue
     -- split tables: 1D and of equal length
     assert(featId:nDimension() == 1)
     assert(featValue:nDimension() == 1)
     assert(featId:size(1) == featValue:size(1))
  end
  -- Returns a new table holding the fields that describe the forest:
  -- nNode, rootIds, leftChild, rightChild, splitFeatureId, splitFeatureValue
  -- and depth. The tensors are referenced, not copied.
  function DFD:getReconstructionInfo()
     local fieldNames = {
        'nNode', 'rootIds', 'leftChild', 'rightChild',
        'splitFeatureId', 'splitFeatureValue', 'depth',
     }
     local DFDinfo = {}
     for _, name in ipairs(fieldNames) do
        DFDinfo[name] = self[name]
     end
     return DFDinfo
  end