You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CartTree.lua 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. local _ = require "moses"
  2. local dt = require 'decisiontree._env'
  -- CART (classification-regression decision tree).
  -- Internal nodes route on a single feature: a present feature value that is
  -- below the node's splitFeatureValue goes to the left child, otherwise to the
  -- right child. When the splitting feature is missing from the input,
  -- self.branchleft() picks the side (truthy = left); by default it always
  -- returns true, so such examples go left.
  local CartTree = torch.class("dt.CartTree", "dt.DecisionTree", dt)

  -- Build a tree around an existing root node.
  -- @param root       dt.CartNode at the top of the tree (asserted)
  -- @param branchleft optional zero-argument function consulted when a split
  --                   feature is missing; a truthy result means "go left".
  --                   Defaults to a function that always returns true.
  function CartTree:__init(root, branchleft)
     assert(torch.isTypeOf(root, 'dt.CartNode'))
     if not branchleft then
        branchleft = function() return true end
     end
     self.root = root
     self.branchleft = branchleft
  end
  -- Score `input` with this tree.
  -- Fast path: when `optimized == true`, no stack is requested (`stack == nil`)
  -- and `input` is a contiguous 2D tensor, the work is handed to
  -- input.nn.CartTreeFastScore(input, root, input.new()) and its results are
  -- returned unchanged.
  -- Otherwise the tree is walked from the root via recursivescore, which
  -- returns (score, nodeId, stack).
  -- TODO optimize this
  function CartTree:score(input, stack, optimized)
     local useFastPath = optimized == true
        and stack == nil
        and torch.isTensor(input)
        and input.isContiguous
        and input:isContiguous()
        and input:nDimension() == 2
     if not useFastPath then
        return self:recursivescore(self.root, input, stack)
     end
     return input.nn.CartTreeFastScore(input, self.root, input.new())
  end
  -- Route `input` from `node` down to a leaf and return that leaf's result.
  -- Routing rules:
  --   * a node with no children is a leaf: return node.score, node.nodeId, stack
  --   * a node with exactly one child always descends into that child
  --   * otherwise, if input[node.splitFeatureId] is present (not nil/false;
  --     note that 0 is present in Lua), go left when it is
  --     < node.splitFeatureValue, else right. For binary features this means
  --     0 goes left when the split value is above 0 (TODO confirm split values).
  --   * if the feature is absent, go left when self.branchleft() is truthy
  -- @param node  dt.CartNode (asserted at every level)
  -- @param input indexable by feature id; presumably a 1D tensor or a (sparse)
  --              table keyed by feature id -- confirm against callers
  -- @param stack optional; when truthy, every visited node is appended to it in
  --              root-to-leaf order. A truthy non-table value is replaced by a
  --              fresh table. It is returned as the third result.
  -- @return score, nodeId, stack of the reached leaf
  function CartTree:recursivescore(node, input, stack)
     assert(torch.isTypeOf(node, 'dt.CartNode'))

     if stack then
        if torch.type(stack) ~= 'table' then
           stack = {}
        end
        table.insert(stack, node)
     end

     local left, right = node.leftChild, node.rightChild

     -- leaf: nothing left to route on
     if not left and not right then
        return node.score, node.nodeId, stack
     end

     -- single-child nodes need no split decision
     if not left then
        return self:recursivescore(right, input, stack)
     end
     if not right then
        return self:recursivescore(left, input, stack)
     end

     local featureVal = input[node.splitFeatureId]
     local goLeft
     if featureVal then
        goLeft = featureVal < node.splitFeatureValue
     else
        -- feature missing: defer to the configurable tie-breaker
        goLeft = self.branchleft()
     end
     -- `left` is known to be truthy here, so the and/or selection is safe
     return self:recursivescore(goLeft and left or right, input, stack)
  end
  -- Textual form of the whole tree, as produced by the root node's
  -- recursivetostring().
  function CartTree:__tostring__()
     local root = self.root
     return root:recursivetostring()
  end
  -- Render a root-to-leaf node stack (the third value returned by score /
  -- recursivescore when a stack is requested) as a multi-line string.
  -- Each internal node produces a line of the form
  --   input[<id>]=<value> < <splitValue> ? (<LR|L|R>) Left|Right
  -- and a leaf produces "score=<score>". The text ends with "<#stack> nodes".
  -- @param stack table of dt.CartNode ordered from root (stack[1]) to leaf
  -- @param input optional; when given, each split feature's value is printed
  --              ('nil' when it is absent)
  -- @return string
  -- Raises "stackToString error" when an internal node is not followed by one
  -- of its own children, including when the stack ends on an internal node.
  function CartTree:stackToString(stack, input)
     assert(torch.type(stack) == 'table')
     assert(torch.isTypeOf(stack[1], 'dt.CartNode'))
     -- accumulate pieces and join once: repeated `..` in a loop is quadratic
     local buf = {'Stack nodes from root to leaf\n'}
     for i, node in ipairs(stack) do
        local left, right = node.leftChild, node.rightChild
        if not (left or right) then
           buf[#buf + 1] = "score=" .. node.score .. '\n'
        else
           local istr = ''
           if input then
              istr = '=' .. (input[node.splitFeatureId] or 'nil')
           end
           buf[#buf + 1] = 'input[' .. node.splitFeatureId .. ']' .. istr .. ' < ' .. node.splitFeatureValue .. ' ? '
           buf[#buf + 1] = '(' .. ((left and right) and 'LR' or left and 'L' or right and 'R' or 'WAT?') .. ') '
           local nextNode = stack[i + 1]
           -- Guard against nil == nil: without this, a stack ending on a
           -- single-child internal node would be misreported as Left/Right.
           if nextNode == nil then
              error"stackToString error"
           elseif left == nextNode then
              buf[#buf + 1] = 'Left\n'
           elseif right == nextNode then
              buf[#buf + 1] = 'Right\n'
           else
              error"stackToString error"
           end
        end
     end
     buf[#buf + 1] = #stack .. " nodes"
     return table.concat(buf)
  end
  -- Copy this tree. The root node is copied with root:clone(); how deep that
  -- copy goes depends on dt.CartNode:clone. The branchleft function is shared
  -- by reference, not copied.
  function CartTree:clone()
     local rootCopy = self.root:clone()
     return CartTree(rootCopy, self.branchleft)
  end