aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/RandomForestTrainer.lua
blob: 41040b25b2181b9e938e2e77b7b68c0f8f771ac1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
local dt = require "decisiontree._env"

-- Trainer that produces a dt.DecisionForest: it trains nTree CART trees, each on
-- a random sample (with replacement) of the features and of the examples, and
-- gives every tree the same weight.
local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt)

--- Build a random forest trainer from an options table.
-- Every option is required and must be a number:
--   opt.nTree              (> 0) number of trees in the forest
--   opt.maxLeafNodes       (> 0) maximum number of leaf nodes per tree
--   opt.minLeafSize        (> 0) minimum number of examples per leaf
--   opt.featureBaggingSize       features sampled per tree; a non-positive
--                                value falls back to round(sqrt(#features))
--   opt.activeRatio        (> 0) fraction of examples sampled per tree
-- Invalid options raise through assert. The trainer starts in
-- 'singlethread' mode; call treeParallel() to switch modes.
function RandomForestTrainer:__init(opt)
   -- {option name, must it be strictly positive?}; validated and stored in order
   local options = {
      {'nTree', true},
      {'maxLeafNodes', true},
      {'minLeafSize', true},
      {'featureBaggingSize', false},
      {'activeRatio', true},
   }
   for _, option in ipairs(options) do
      local name, positive = option[1], option[2]
      local value = opt[name]
      assert(torch.type(value) == 'number')
      if positive then
         assert(value > 0)
      end
      self[name] = value
   end

   -- training runs in the calling thread until treeParallel() is invoked
   self.parallelMode = 'singlethread'
end

--- Train a dt.DecisionForest.
-- Trees are grown sequentially (trainTrees) or, after treeParallel(), on the
-- work pool (trainTreesTP). Every tree receives the same weight 1/nTree, and
-- the forest bias is the weighted sum of the trees' root scores.
-- @param trainSet dt.DataSet to train on
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param verbose print progress messages when truthy
-- @return dt.DecisionForest
function RandomForestTrainer:train(trainSet, featureIds, verbose)
   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')

   local nTree = self.nTree
   if verbose then
      print(string.format("Begin training Decision Forest with %d trees", nTree))
   end

   -- a random forest weights all of its trees uniformly
   local weight = torch.Tensor(nTree):fill(1 / nTree)

   local mode = self.parallelMode
   local trees
   if mode == 'treeparallel' then
      trainSet:deleteIndex() -- prevents serialization bottleneck
      trees = self:trainTreesTP(trainSet, featureIds, verbose)
   elseif mode == 'singlethread' then
      trees = self:trainTrees(trainSet, featureIds, verbose)
   else
      error("Unrecognized parallel mode: " .. mode)
   end

   if verbose then
      print(string.format("Successfully trained %d trees", #trees))
   end

   -- bias = sum over i of weight[i] * (root score of tree i)
   local bias = 0
   for i = 1, #trees do
      bias = bias + trees[i].root.score * weight[i]
   end

   return dt.DecisionForest(trees, weight, bias)
end

--- Train self.nTree trees one after another in the current thread.
-- A single dt.CartTrainer is built up front and reused for every tree.
-- @param trainSet dt.DataSet to train on
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param verbose print progress messages when truthy
-- @return table (sequence) of the trees produced by RandomForestTrainer.trainTree
function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose)
   local sharedTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes)
   local trainOne = self.trainTree

   local forest = {}
   for id = 1, self.nTree do
      local cartTree = trainOne(sharedTrainer, featureIds, self.featureBaggingSize, self.activeRatio, id, verbose)
      forest[#forest + 1] = cartTree
   end
   return forest
end

--- Train one CART tree on a bootstrap sample of features and examples.
-- Static function (no self): trainTrees calls it directly and trainTreesTP
-- calls it from worker-thread tasks.
-- @param cartTrainer dt.CartTrainer; its .dataset is the data sampled from
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param baggingSize number of features to draw (with replacement); a
--   non-positive value means round(sqrt(#featureIds))
-- @param activeRatio fraction of the dataset's examples to draw (with
--   replacement); at least one example is always drawn
-- @param treeId tree number, used only in progress messages
-- @param verbose print progress messages when truthy
-- @return the tree returned by cartTrainer:train
function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose)
   assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer'))
   assert(torch.type(featureIds) == 'torch.LongTensor')

   local nFeature = featureIds:size(1)
   -- resolve the default into a new local instead of shadowing the parameter
   local nBagged = baggingSize > 0 and baggingSize or torch.round(math.sqrt(nFeature))

   if verbose then
      print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, nBagged, nFeature))
   end

   local trainSet = cartTrainer.dataset
   local nExample = trainSet:size()

   -- sample bootstrap features first, then examples (fixed RNG draw order)
   local featureIndices = torch.LongTensor(nBagged):random(1, nFeature)
   local activeFeatures = featureIds:index(1, featureIndices)

   -- sample bootstrap examples; clamp to at least one example, because a small
   -- dataset or activeRatio could otherwise round the size down to 0 and make
   -- the index() call below fail on an empty index tensor
   local sampleSize = math.max(1, torch.round(nExample * activeRatio))
   if verbose then print(string.format("Tree %d: Creating bootstrap sample of size %d", treeId, sampleSize)) end

   local exampleIndices = torch.LongTensor(sampleSize):random(1, nExample)
   local bootStrapExampleIds = torch.LongTensor()
   bootStrapExampleIds:index(trainSet:getExampleIds(), 1, exampleIndices)

   local cartTree = cartTrainer:train(dt.GiniState(bootStrapExampleIds), activeFeatures)

   if verbose then print(string.format("Complete processing tree number %d", treeId)) end

   return cartTree
end

--- Switch this trainer to tree-level parallel training.
-- Call before train(). It may only be called while the mode is still
-- 'singlethread', so a second call fails the assert.
-- @param workPool an existing dt.WorkPool, or a number, in which case a new
--   dt.WorkPool(workPool) is created
function RandomForestTrainer:treeParallel(workPool)
   assert(self.parallelMode == 'singlethread', self.parallelMode)
   self.parallelMode = 'treeparallel'

   local pool = workPool
   if torch.type(pool) == 'number' then
      pool = dt.WorkPool(pool)
   end
   self.workPool = pool
   assert(torch.isTypeOf(pool, 'dt.WorkPool'))

   -- ask the workers to load the decisiontree package (presumably bound to the
   -- name dt on each worker; see dt.WorkPool's 'require' task)
   pool:update('require', {libname='decisiontree',varname='dt'})
end

--- Train self.nTree trees on self.workPool ("TP" = tree parallel).
-- Requires treeParallel() to have been called first, since it uses
-- self.workPool. Each tree is trained by RandomForestTrainer.trainTree inside a
-- worker 'execute' task.
-- NOTE(review): the closures handed to updateup/writeup capture only plain
-- locals (never self) and re-require 'decisiontree' themselves; this looks
-- intended to keep what is shipped to the workers small and serializable
-- (compare the deleteIndex() call in train) -- confirm against dt.WorkPool.
-- @param trainSet dt.DataSet
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param verbose forwarded to trainTree in the workers
-- @return table (sequence) of dt.CartTree, in the order workPool:read() yields
--   them (presumably completion order rather than treeId order -- unverified)
function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose)
   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')
   -- presumably copied into locals so the closure below captures plain values, not self
   local minLeafSize = self.minLeafSize
   local maxLeafNodes = self.maxLeafNodes

   -- setup worker store (each worker will have its own cartTrainer)
   self.workPool:updateup('execute', function(store)
      local dt = require 'decisiontree'

      store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
      store.featureIds = featureIds
   end)

   -- submit one 'execute' task per tree
   for treeId=1,self.nTree do
      -- upvalues captured by trainTreeTask (together with treeId and verbose)
      local baggingSize = self.featureBaggingSize
      local activeRatio = self.activeRatio
      -- task closure that will be executed in worker-thread
      local function trainTreeTask(store)
         local dt = require 'decisiontree'
         return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose)
      end
      self.workPool:writeup('execute', trainTreeTask)
   end

   -- read back exactly one result per submitted task
   local trees = {}
   for treeId=1,self.nTree do
      local taskname, tree = self.workPool:read()
      assert(taskname=='execute')
      assert(torch.isTypeOf(tree, 'dt.CartTree'))
      table.insert(trees, tree)
   end
   return trees
end

--- Build a descriptive name that encodes this trainer's hyper-parameters.
-- Example: activeRatio=0.5, maxLeafNodes=100, minLeafSize=5, nTree=10 gives
-- "randomforest-aRatio-0.50-maxLeaf-100-minExample-5-nTree-10".
-- @treturn string
function RandomForestTrainer:getName()
   local template = "randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d"
   return template:format(self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree)
end