local dt = require "decisiontree._env"
local _ = require "moses"

local CartTrainer = torch.class("dt.CartTrainer", dt)

-- Generic CART trainer
function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   self.dataset = dataset
   self.minLeafSize = assert(minLeafSize)   -- min examples per leaf
   self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree
   -- by default, single thread
   self.parallelMode = 'singlethread'
end

function CartTrainer:train(rootTreeState, activeFeatures)
   assert(torch.isTypeOf(rootTreeState, 'dt.TreeState'))
   assert(torch.isTensor(activeFeatures))
   local root = dt.CartNode()
   root.id = 0
   root.score = rootTreeState:score(self.dataset)

   local nleaf = 1

   -- TODO: nodeparallel: parallelize here. The queue is a work queue.
   local queue = {}
   table.insert(queue, 1, {cartNode=root, treeState=rootTreeState})

   while #queue > 0 and nleaf < self.maxLeafNodes do
      local treeGrowerArgs = table.remove(queue, #queue)
      local currentTreeState = treeGrowerArgs.treeState

      -- Note: if minLeafSize = 1 and maxLeafNodes = inf, then each example will be its own leaf...
      if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then
         nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures)
      end
   end

   -- CartTree with random branching (when a feature is missing)
   local branchleft = function() return math.random() < 0.5 end
   return dt.CartTree(root, branchleft), nleaf
end

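--[[
Usage sketch (hedged, not part of this module): how train() might be driven by a caller.
dt.GiniState is assumed to be one of the dt.TreeState subclasses shipped with the package,
and the dt.DataSet constructor shown is an assumption; `input`, `target`, `nExample` and
`nFeature` are hypothetical placeholders for the caller's data.

   local dt = require 'decisiontree'

   local trainSet = dt.DataSet(input, target)            -- any dt.DataSet instance (assumed constructor)
   local cartTrainer = dt.CartTrainer(trainSet, 10, 32)  -- minLeafSize=10, maxLeafNodes=32

   local exampleIds = torch.range(1, nExample):long()    -- ids of the training examples
   local rootTreeState = dt.GiniState(exampleIds)        -- assumed dt.TreeState subclass
   local activeFeatures = torch.range(1, nFeature):long()-- consider every feature id

   local cartTree, nleaf = cartTrainer:train(rootTreeState, activeFeatures)
--]]
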
function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures)
   local bestSplit
   if self.parallelMode == 'singlethread' then
      bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures)
   elseif self.parallelMode == 'featureparallel' then
      bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   else
      error("Unrecognized parallel mode: " .. self.parallelMode)
   end

   if bestSplit then
      local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset)
      assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!")
      self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf)

      table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState})
      table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState})
      return nleaf + 1
   end

   return nleaf
end

function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures)
   local timer = torch.Timer()
   local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1)

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeatures time="..timer:time().real)
   end
   return bestSplit
end

-- Updates the parent node with the best split information and creates the left/right child nodes.
function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf)
   assert(torch.isTypeOf(parentNode, 'dt.CartNode'))
   assert(torch.type(bestSplit) == 'table')
   assert(torch.isTypeOf(leftState, 'dt.TreeState'))
   assert(torch.isTypeOf(rightState, 'dt.TreeState'))
   assert(torch.type(nleaf) == 'number')

   local leftChild = dt.CartNode()
   leftChild.score = leftState:score(self.dataset)
   leftChild.nodeId = 2 * nleaf - 1

   local rightChild = dt.CartNode()
   rightChild.score = rightState:score(self.dataset)
   rightChild.nodeId = 2 * nleaf

   parentNode.splitFeatureId = bestSplit.splitId
   parentNode.splitFeatureValue = bestSplit.splitValue
   parentNode.leftChild = leftChild
   parentNode.rightChild = rightChild
   parentNode.splitGain = bestSplit.splitGain
end

-- We minimally need 2 * minLeafSize examples in the parent to satisfy >= minLeafSize examples per child.
function CartTrainer:hasEnoughTrainingExamplesToSplit(count)
   return count >= 2 * self.minLeafSize
end

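-- For example, with minLeafSize = 5 a node needs at least 10 examples before it is
-- considered for splitting; with fewer, it is left as a leaf.
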
-- Call before training to enable feature parallelization.
function CartTrainer:featureParallel(workPool)
   assert(self.parallelMode == 'singlethread', self.parallelMode)
   self.parallelMode = 'featureparallel'

   self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool
   assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))

   -- this deletes all SparseTensor hash maps so that they aren't serialized
   self.dataset:deleteIndex()

   -- require the dt package
   self.workPool:update('require', {libname='decisiontree', varname='dt'})

   -- set up the worker store (each worker will have its own copy)
   local store = {
      dataset=self.dataset,
      minLeafSize=self.minLeafSize
   }
   self.workPool:update('storeKeysValues', store)
end

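--[[
Usage sketch (hedged, not part of this module): enabling feature parallelism before training.
Passing a number builds a dt.WorkPool with that many threads (see above); an existing
dt.WorkPool instance can be passed instead.

   local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
   cartTrainer:featureParallel(4)  -- 4 worker threads; must be called while still 'singlethread'
   local cartTree, nleaf = cartTrainer:train(rootTreeState, activeFeatures)
--]]
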
-- Feature-parallel version of findBestSplitForAllFeatures: each worker finds the best split
-- over its shard of the active features, and the best of those candidates is kept.
function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures)
   local timer = torch.Timer()
   local bestSplit
   if treeState.findBestSplitFP then
      bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread)
   end

   if not bestSplit then
      for i=1,self.workPool.nThread do
         -- upvalues
         local treeState = treeState
         local shardId = i
         local nShard = self.workPool.nThread
         local featureIds = activeFeatures

         -- closure
         local task = function(store)
            assert(store.dataset)
            assert(store.minLeafSize)
            if treeState.threadInitialize then
               treeState:threadInitialize()
            end

            local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard)
            return bestSplit
         end

         self.workPool:writeup('execute', task)
      end

      for i=1,self.workPool.nThread do
         local taskname, candidateSplit = self.workPool:read()
         assert(taskname == 'execute')
         if candidateSplit then
            if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then
               bestSplit = candidateSplit
            end
         end
      end
   end

   if bestSplit then
      assert(torch.type(bestSplit) == 'table')
   end

   if dt.PROFILE then
      print("findBestSplitForAllFeaturesFP time="..timer:time().real)
   end

   return bestSplit
end