summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/CartTrainer.lua
blob: 63ae6c14874c7602653fc59719423b7eb216be6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
local dt = require "decisiontree._env"
local _ = require "moses"

-- Register the CartTrainer class under the dt (decisiontree) namespace.
local CartTrainer = torch.class("dt.CartTrainer", dt)

-- Generic CART trainer.
-- Builds CART decision trees from a dt.DataSet, bounded by a minimum
-- number of training examples per leaf and a maximum number of leaf nodes.
function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   self.dataset = dataset
   -- A bare assert() only rejects nil/false; 0 or negative sizes would pass
   -- (0 is truthy in Lua) and silently break the per-leaf size guarantee,
   -- so validate the bounds explicitly. math.huge is a valid maxLeafNodes.
   assert(torch.type(minLeafSize) == 'number' and minLeafSize >= 1,
      "minLeafSize must be a number >= 1")
   assert(torch.type(maxLeafNodes) == 'number' and maxLeafNodes >= 1,
      "maxLeafNodes must be a number >= 1")
   self.minLeafSize = minLeafSize -- min examples per leaf
   self.maxLeafNodes = maxLeafNodes -- max leaf nodes in tree

   -- by default, single thread
   self.parallelMode = 'singlethread'
end

-- Grows a CART tree breadth-wise from rootTreeState, splitting nodes until
-- the work queue empties or maxLeafNodes is reached.
-- Returns the dt.CartTree and the number of leaves.
function CartTrainer:train(rootTreeState, activeFeatures)
   assert(torch.isTypeOf(rootTreeState, 'dt.TreeState'))
   assert(torch.isTensor(activeFeatures))

   -- the root node covers the entire dataset
   local root = dt.CartNode()
   root.id = 0
   root.score = rootTreeState:score(self.dataset)

   local nleaf = 1

   -- TODO : nodeparallel: parallelize here. The queue is a workqueue.
   local queue = { {cartNode=root, treeState=rootTreeState} }

   while #queue > 0 and nleaf < self.maxLeafNodes do
      -- pop from the back; children are pushed at the front, so nodes are
      -- processed in insertion order
      local grower = table.remove(queue)
      local state = grower.treeState

      -- Note: if minLeafSize = 1 and maxLeafNode = inf, then each example will be its own leaf...
      if self:hasEnoughTrainingExamplesToSplit(state.exampleIds:size(1)) then
         nleaf = self:processNode(nleaf, queue, grower.cartNode, state, activeFeatures)
      end
   end

   -- CartTree with random branching (when feature is missing)
   local branchleft = function() return math.random() < 0.5 end
   return dt.CartTree(root, branchleft), nleaf
end

-- Attempts to split one node: finds the best split (dispatching on the
-- configured parallel mode), branches the tree state, and enqueues the two
-- resulting children. Returns the updated leaf count.
function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures)
   local bestSplit
   if self.parallelMode == 'singlethread' then
      bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures)
   elseif self.parallelMode == 'featureparallel' then
      bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   else
      error("Unrecognized parallel mode: " .. self.parallelMode)
   end

   -- no valid split found: the node remains a leaf
   if not bestSplit then
      return nleaf
   end

   local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset)
   assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!")
   self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf)

   table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState})
   table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState})

   -- splitting one leaf into two increases the leaf count by exactly one
   return nleaf + 1
end

-- Single-threaded search for the best split over all active features.
function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures)
   local timer = torch.Timer()
   -- shardId/nShard of -1,-1 means no sharding: scan every active feature
   local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1)

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeatures time="..timer:time().real)
   end
   return bestSplit
end

-- Updates the parentNode with the bestSplit information by creating scored
-- left/right child nodes and recording the chosen split on the parent.
function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf)
   assert(torch.isTypeOf(parentNode, 'dt.CartNode'))
   assert(torch.type(bestSplit) == 'table')
   assert(torch.isTypeOf(leftState, 'dt.TreeState'))
   assert(torch.isTypeOf(rightState, 'dt.TreeState'))
   assert(torch.type(nleaf) == 'number')

   -- build a scored child node with the given id
   local function newChild(state, nodeId)
      local child = dt.CartNode()
      child.score = state:score(self.dataset)
      child.nodeId = nodeId
      return child
   end

   -- the children of leaf number k receive ids 2k-1 and 2k
   parentNode.leftChild = newChild(leftState, 2 * nleaf - 1)
   parentNode.rightChild = newChild(rightState, 2 * nleaf)

   parentNode.splitFeatureId = bestSplit.splitId
   parentNode.splitFeatureValue = bestSplit.splitValue
   parentNode.splitGain = bestSplit.splitGain
end

-- We minimally need 2 * N examples in the parent to satisfy >= N examples per child
function CartTrainer:hasEnoughTrainingExamplesToSplit(count)
   local required = 2 * self.minLeafSize
   return count >= required
end

-- call before training to enable feature-parallelization
-- Accepts either a thread count (a dt.WorkPool is created) or an existing
-- dt.WorkPool instance.
function CartTrainer:featureParallel(workPool)
   assert(self.parallelMode == 'singlethread', self.parallelMode)
   self.parallelMode = 'featureparallel'

   if torch.type(workPool) == 'number' then
      self.workPool = dt.WorkPool(workPool)
   else
      self.workPool = workPool
   end
   assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))

   -- this deletes all SparseTensor hash maps so that they aren't serialized
   self.dataset:deleteIndex()

   -- require the dt package
   self.workPool:update('require', {libname='decisiontree',varname='dt'})
   -- setup worker store (each worker will have its own copy)
   self.workPool:update('storeKeysValues', {
      dataset=self.dataset,
      minLeafSize=self.minLeafSize
   })
end

-- feature parallel
-- Feature-parallel counterpart of findBestSplitForAllFeatures: shards the
-- active features across the worker pool and keeps the best candidate split
-- returned by any shard.
function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   local timer = torch.Timer()
   local bestSplit
   -- some TreeState implementations provide their own feature-parallel search;
   -- prefer it when available
   if treeState.findBestSplitFP then
      bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread)
   end

   if not bestSplit then
      -- dispatch one task per worker thread
      for i=1,self.workPool.nThread do
         -- upvalues
         -- NOTE(review): these locals are deliberately rebound so the task
         -- closure captures plain per-iteration values (presumably so they
         -- serialize cleanly to the workers) — confirm against dt.WorkPool's
         -- closure-transfer semantics.
         local treeState = treeState
         local shardId = i
         local nShard = self.workPool.nThread
         local featureIds = activeFeatures
         -- closure
         local task = function(store)
            -- store holds the per-worker copies installed by featureParallel()
            assert(store.dataset)
            assert(store.minLeafSize)
            if treeState.threadInitialize then
               treeState:threadInitialize()
            end

            -- each worker scans only its shard (shardId of nShard) of the features
            local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard)
            return bestSplit
         end

         self.workPool:writeup('execute', task)
      end

      -- collect exactly one candidate split per dispatched task
      for i=1,self.workPool.nThread do
         local taskname, candidateSplit = self.workPool:read()
         assert(taskname == 'execute')
         if candidateSplit then
            -- a LOWER splitGain wins here — NOTE(review): confirm the sign
            -- convention of splitGain in the TreeState implementations
            if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then
               bestSplit = candidateSplit
            end
         end
      end
   end

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeaturesFP time="..timer:time().real)
   end
   return bestSplit
end