aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/CartTree.lua
blob: c74dfda9ecbb9c653160ef50c75baf0cc2180c2f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
local _ = require "moses"
local dt = require 'decisiontree._env'

-- CART (classification and regression tree) decision tree.
-- When the splitting feature is missing from the input, the example is branched
-- to the left by default; a custom `branchleft` function given to the
-- constructor can override that choice.
local CartTree = torch.class("dt.CartTree", "dt.DecisionTree", dt)

-- Builds a tree around an existing root node.
-- root       : the root of the tree; must be a dt.CartNode (asserted).
-- branchleft : optional function called without arguments whenever the split
--              feature is missing from the input; a truthy result sends the
--              example to the left child, a falsy one to the right child.
--              When omitted, a function that always returns true is used.
function CartTree:__init(root, branchleft)
   assert(torch.isTypeOf(root, 'dt.CartNode'))
   if not branchleft then
      -- default policy: a missing feature always goes left
      branchleft = function() return true end
   end
   self.root = root
   self.branchleft = branchleft
end

-- Scores `input` by walking the tree from its root.
-- input     : the example(s) to score.
-- stack     : forwarded to recursivescore; see that method for how it is used.
-- optimized : when exactly `true`, and no stack was requested, and `input` is a
--             contiguous 2-dimensional tensor, the call is handed over to the
--             `CartTreeFastScore` routine found under `input.nn`.
-- In every other case the result of recursivescore(self.root, input, stack)
-- is returned unchanged.
-- TODO optimize this
function CartTree:score(input, stack, optimized)
   -- the checks are evaluated left to right and stop at the first failure, so
   -- the tensor methods are only called once `input` is known to be a tensor
   local fastpath = optimized == true
      and stack == nil
      and torch.isTensor(input)
      and input.isContiguous
      and input:isContiguous()
      and input:nDimension() == 2

   if not fastpath then
      return self:recursivescore(self.root, input, stack)
   end
   return input.nn.CartTreeFastScore(input, self.root, input.new())
end

-- Descends from `node` to a leaf and returns that leaf's score and node id.
-- Split rule when both children exist and the feature is present:
--   input[node.splitFeatureId] < node.splitFeatureValue -> left child,
--   otherwise -> right child.
--   (For binary features this amounts to "0 goes left, anything else goes
--   right", assuming the split value lies in (0, 1] -- TODO confirm.)
-- When the feature is absent (input[node.splitFeatureId] is nil/false), the
-- direction is decided by self.branchleft(): truthy -> left, falsy -> right.
-- A node that has only one child is passed through to that child without
-- looking at the input.
-- node  : the dt.CartNode to start from (asserted).
-- input : the example, indexed by feature id.
-- stack : optional. When truthy, every visited node is appended to it and it
--         is returned as the third value; a truthy non-table value (e.g.
--         `true`) is replaced by a fresh table first.
-- Returns: score, nodeId, stack
function CartTree:recursivescore(node, input, stack)
   assert(torch.isTypeOf(node, 'dt.CartNode'))

   if stack then
      if torch.type(stack) ~= 'table' then
         stack = {}
      end
      table.insert(stack, node)
   end

   local left, right = node.leftChild, node.rightChild

   -- leaf: nothing left to traverse
   if not (left or right) then
      return node.score, node.nodeId, stack
   end

   -- only one child: follow it unconditionally
   if not left then
      return self:recursivescore(right, input, stack)
   end
   if not right then
      return self:recursivescore(left, input, stack)
   end

   local featureVal = input[node.splitFeatureId]
   local target
   if featureVal then
      -- feature present: compare against the split value
      if featureVal < node.splitFeatureValue then
         target = left
      else
         target = right
      end
   elseif self.branchleft() then
      -- feature missing: let the branchleft policy decide
      target = left
   else
      target = right
   end
   return self:recursivescore(target, input, stack)
end

-- String form of the tree: whatever the root node's recursivetostring() returns.
function CartTree:__tostring__()
   local root = self.root
   return root:recursivetostring()
end

-- Renders a root-to-leaf node stack (as returned by score) as readable text.
-- stack : table of dt.CartNode; the first entry is checked with an assert.
-- input : optional example; when given, the value of each split feature
--         (or 'nil' when absent) is shown next to the feature id.
-- Each inner node produces one line with the split test, which children exist
-- (LR / L / R) and the direction taken towards the next stack entry; each leaf
-- produces a "score=" line. Raises "stackToString error" when the entry that
-- follows an inner node is neither of its children.
function CartTree:stackToString(stack, input)
   assert(torch.type(stack) == 'table')
   assert(torch.isTypeOf(stack[1], 'dt.CartNode'))

   -- pieces are collected here and joined once at the end
   local parts = {'Stack nodes from root to leaf\n'}

   for i, node in ipairs(stack) do
      if node.leftChild or node.rightChild then
         local shown = ''
         if input then
            shown = '=' .. (input[node.splitFeatureId] or 'nil')
         end

         local kids
         if node.leftChild and node.rightChild then
            kids = 'LR'
         elseif node.leftChild then
            kids = 'L'
         elseif node.rightChild then
            kids = 'R'
         else
            kids = 'WAT?'
         end

         parts[#parts + 1] = 'input[' .. node.splitFeatureId .. ']' .. shown .. ' < ' .. node.splitFeatureValue .. ' ? '
         parts[#parts + 1] = '(' .. kids .. ') '

         -- the direction is inferred from which child comes next in the stack
         local following = stack[i+1]
         if node.leftChild == following then
            parts[#parts + 1] = 'Left\n'
         elseif node.rightChild == following then
            parts[#parts + 1] = 'Right\n'
         else
            error"stackToString error"
         end
      else
         parts[#parts + 1] = "score="..node.score .. '\n'
      end
   end

   parts[#parts + 1] = #stack .. " nodes"
   return table.concat(parts)
end

-- Returns a new CartTree built from a clone of the root node.
-- The branchleft function is passed along as-is (same function object).
function CartTree:clone()
   local rootCopy = self.root:clone()
   return CartTree(rootCopy, self.branchleft)
end