You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

RandomForestTrainer.lua 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. local dt = require "decisiontree._env"
  2. local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt)
  -- Constructor. Validates the options table and copies each setting onto the trainer.
  -- opt.nTree              : number of trees, must be a number > 0
  -- opt.maxLeafNodes       : max number of leaf nodes per tree, must be a number > 0
  -- opt.minLeafSize        : min number of examples per leaf, must be a number > 0
  -- opt.featureBaggingSize : number; when non-positive, trainTree falls back to sqrt(#feature)
  -- opt.activeRatio        : number > 0; scales the bootstrap sample size in trainTree
  function RandomForestTrainer:__init(opt)
     local nTree = opt.nTree
     assert(torch.type(nTree) == 'number')
     assert(nTree > 0)
     self.nTree = nTree

     local maxLeafNodes = opt.maxLeafNodes
     assert(torch.type(maxLeafNodes) == 'number')
     assert(maxLeafNodes > 0)
     self.maxLeafNodes = maxLeafNodes

     local minLeafSize = opt.minLeafSize
     assert(torch.type(minLeafSize) == 'number')
     assert(minLeafSize > 0)
     self.minLeafSize = minLeafSize

     -- only the type is checked here: non-positive values are allowed
     local featureBaggingSize = opt.featureBaggingSize
     assert(torch.type(featureBaggingSize) == 'number')
     self.featureBaggingSize = featureBaggingSize

     local activeRatio = opt.activeRatio
     assert(torch.type(activeRatio) == 'number')
     assert(activeRatio > 0)
     self.activeRatio = activeRatio

     -- single-threaded until treeParallel() is called
     self.parallelMode = 'singlethread'
  end
  -- Train a DecisionForest.
  -- trainSet   : a dt.DataSet
  -- featureIds : torch.LongTensor of candidate feature ids
  -- verbose    : when truthy, progress messages are printed
  -- Returns dt.DecisionForest(trees, weight, bias), where weight is uniform (1/nTree)
  -- and bias is the weighted sum of the root scores of the trained trees.
  function RandomForestTrainer:train(trainSet, featureIds, verbose)
     assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
     assert(torch.type(featureIds) == 'torch.LongTensor')

     if verbose then
        print(string.format("Begin training Decision Forest with %d trees", self.nTree))
     end

     -- RF uses uniform weights
     local uniformWeights = torch.Tensor(self.nTree):fill(1 / self.nTree)

     local forestTrees
     if self.parallelMode == 'treeparallel' then
        trainSet:deleteIndex() -- prevents serialization bottleneck
        forestTrees = self:trainTreesTP(trainSet, featureIds, verbose)
     elseif self.parallelMode == 'singlethread' then
        forestTrees = self:trainTrees(trainSet, featureIds, verbose)
     else
        error("Unrecognized parallel mode: " .. self.parallelMode)
     end

     if verbose then
        print(string.format("Successfully trained %d trees", #forestTrees))
     end

     -- accumulate the bias from each tree's root score
     local forestBias = 0
     for idx, cartTree in ipairs(forestTrees) do
        forestBias = forestBias + cartTree.root.score * uniformWeights[idx]
     end

     return dt.DecisionForest(forestTrees, uniformWeights, forestBias)
  end
  -- Single-threaded training: builds self.nTree trees one after another and
  -- returns them as an array. One CartTrainer is created and shared by all trees.
  function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose)
     local sharedTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes)

     local forestTrees = {}
     for treeId = 1, self.nTree do
        -- trainTree is static, hence the dot call
        forestTrees[#forestTrees + 1] = self.trainTree(
           sharedTrainer, featureIds, self.featureBaggingSize, self.activeRatio, treeId, verbose
        )
     end
     return forestTrees
  end
  -- Static function that trains and returns a single cartTree.
  -- cartTrainer : a dt.CartTrainer; its .dataset field supplies the training examples
  -- featureIds  : torch.LongTensor of candidate feature ids
  -- baggingSize : number of features to sample; when non-positive, round(sqrt(#featureIds)) is used
  -- activeRatio : the example sample size is round(dataset size * activeRatio)
  -- treeId      : only used in the verbose messages
  -- verbose     : when truthy, progress messages are printed
  function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose)
     assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer'))
     assert(torch.type(featureIds) == 'torch.LongTensor')

     local nFeatures = featureIds:size(1)

     local nBagged
     if baggingSize > 0 then
        nBagged = baggingSize
     else
        nBagged = torch.round(math.sqrt(nFeatures))
     end

     if verbose then
        print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, nBagged, nFeatures))
     end

     local dataset = cartTrainer.dataset

     -- bootstrap the features: draw positions with replacement, then look up the ids
     local featurePositions = torch.LongTensor(nBagged):random(1, nFeatures)
     local activeFeatures = featureIds:index(1, featurePositions)

     -- bootstrap the examples (drawn after the features, so the RNG order is kept)
     local nExamples = dataset:size()
     local sampleSize = torch.round(nExamples * activeRatio)
     if verbose then
        print(string.format("Creating bootstrap sample created of size %d", sampleSize))
     end
     local examplePositions = torch.LongTensor(sampleSize):random(1, nExamples)
     local sampledExampleIds = torch.LongTensor()
     sampledExampleIds:index(dataset:getExampleIds(), 1, examplePositions)

     local grownTree = cartTrainer:train(dt.GiniState(sampledExampleIds), activeFeatures)

     if verbose then
        print(string.format("Complete processing tree number %d", treeId))
     end
     return grownTree
  end
  -- Call before training to enable tree-level parallelization.
  -- workPool : either a number (passed to dt.WorkPool to build a pool) or an
  --            existing dt.WorkPool instance.
  -- Asserts that the trainer is still in 'singlethread' mode.
  function RandomForestTrainer:treeParallel(workPool)
     assert(self.parallelMode == 'singlethread', self.parallelMode)
     self.parallelMode = 'treeparallel'

     if torch.type(workPool) == 'number' then
        self.workPool = dt.WorkPool(workPool)
     else
        self.workPool = workPool
     end
     assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))

     -- require the dt package
     local requireSpec = {libname='decisiontree', varname='dt'}
     self.workPool:update('require', requireSpec)
  end
  -- TP is for tree parallel: trains self.nTree trees through self.workPool.
  -- One 'execute' task is queued per tree, then one result is read back per tree.
  -- Returns the trees as an array, in the order the results were read.
  function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose)
     assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
     assert(torch.type(featureIds) == 'torch.LongTensor')

     local pool = self.workPool

     -- plain locals, so the closures below capture these values rather than self
     local minLeafSize = self.minLeafSize
     local maxLeafNodes = self.maxLeafNodes
     local baggingSize = self.featureBaggingSize
     local activeRatio = self.activeRatio

     -- setup worker store (each worker will have its own cartTrainer)
     pool:updateup('execute', function(store)
        local dt = require 'decisiontree'
        store.featureIds = featureIds
        store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
     end)

     -- queue one training task per tree; treeId is captured per iteration
     for treeId = 1, self.nTree do
        pool:writeup('execute', function(store)
           local dt = require 'decisiontree'
           return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose)
        end)
     end

     -- collect exactly as many results as tasks were queued
     local collected = {}
     for _ = 1, self.nTree do
        local taskname, cartTree = pool:read()
        assert(taskname=='execute')
        assert(torch.isTypeOf(cartTree, 'dt.CartTree'))
        collected[#collected + 1] = cartTree
     end
     return collected
  end
  -- Returns a descriptive name string built from the trainer's activeRatio,
  -- maxLeafNodes, minLeafSize and nTree settings.
  function RandomForestTrainer:getName()
     local template = "randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d"
     return string.format(template, self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree)
  end