aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/CartNode.lua
blob: e6a85006cbc1b58f7ff5af1e4dc422edaaebaf2d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
-- Shared package environment table of the decisiontree module.
local dt = require 'decisiontree._env'
-- Declare the class under the type name "dt.CartNode"; the `dt` table is passed
-- as the second argument to torch.class (presumably so the class is also
-- reachable as dt.CartNode -- confirm against the torch.class documentation).
-- The methods below are attached to the returned class table.
local CartNode = torch.class("dt.CartNode", dt)

-- Constructor: stores the node's identity, child links and split description.
--   nodeId            : identifier of the node; 0 when omitted
--   leftChild         : child node printed as the 'True' branch, or nil
--   rightChild        : child node printed as the 'False' branch, or nil
--   splitFeatureId    : id of the feature tested at this node; -1 when omitted
--   splitFeatureValue : value the feature is compared against; 0 when omitted
--   score             : score stored on the node; 0 when omitted
--   splitGain         : stored as given (no default is applied)
function CartNode:__init(nodeId, leftChild, rightChild, splitFeatureId, splitFeatureValue, score, splitGain)
   -- Replace every omitted (nil or false) optional argument by its default.
   if not nodeId then nodeId = 0 end
   if not splitFeatureId then splitFeatureId = -1 end
   if not splitFeatureValue then splitFeatureValue = 0 end
   if not score then score = 0 end

   -- Identity and tree structure.
   self.nodeId = nodeId
   self.leftChild = leftChild
   self.rightChild = rightChild

   -- Split description and associated statistics.
   self.splitFeatureId = splitFeatureId
   self.splitFeatureValue = splitFeatureValue
   self.score = score
   self.splitGain = splitGain
end

-- String conversion hook: yields the textual rendering of the subtree rooted
-- at this node, using the default starting indentation of recursivetostring.
function CartNode:__tostring__()
   local rendered = self:recursivetostring()
   return rendered
end

-- Renders this node and all of its descendants as multi-line text.
--   indent : prefix written before each branch line of this node; defaults to
--            a single space. Children are rendered with two extra spaces.
-- Returns a string:
--   * leaf (no children)  -> the score followed by a newline
--   * internal node       -> 'input[<featureId>] <<value>?' on the first line,
--                            then a 'True->' line for the left child and a
--                            'False->' line for the right child (each only if
--                            that child exists).
function CartNode:recursivetostring(indent)
   local pad = indent or ' '
   local left, right = self.leftChild, self.rightChild

   -- A node with neither child is a leaf: its text is just the score.
   if not left and not right then
      return self.score .. '\n'
   end

   -- Internal node: the split criterion comes first ...
   local parts = { 'input[' .. self.splitFeatureId .. '] <' .. self.splitFeatureValue .. '?\n' }

   -- ... followed by one entry per existing branch, left before right.
   if left then
      parts[#parts + 1] = pad .. 'True->' .. left:recursivetostring(pad .. '  ')
   end
   if right then
      parts[#parts + 1] = pad .. 'False->' .. right:recursivetostring(pad .. '  ')
   end

   return table.concat(parts)
end

-- Creates a new CartNode carrying the same field values as this one.
-- The copy is shallow: leftChild and rightChild are passed through as-is, so
-- the new node references the same child objects as the original.
function CartNode:clone()
   local copy = CartNode(
      self.nodeId,
      self.leftChild,
      self.rightChild,
      self.splitFeatureId,
      self.splitFeatureValue,
      self.score,
      self.splitGain
   )
   return copy
end