You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

benchmark.lua 6.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  -- Benchmarks for the decisiontree package; run them through dt.benchmark().
  local dt = require("decisiontree._env")

  -- Registry of benchmark functions keyed by name; dt.benchmark calls each
  -- selected entry with a per-benchmark options table.
  local bm = {}
  --- Benchmark dt.CartTrainer on sparse dummy data.
  -- Reports four timings:
  --   1. sparse dataset creation,
  --   2. a single-threaded training run,
  --   3. the feature-parallel setup,
  --   4. a feature-parallel training run.
  -- Each timing reads the timer exactly once, so the samples/sec and the
  -- seconds printed on one line come from the same measurement.
  -- @tparam table opt benchmark options; reads nExample, minLeafSize,
  --   maxLeafNodes and nThread, and is passed to dt.getSparseDummyData
  function bm.CartTrainer(opt)
    local timer = torch.Timer()
    local trainSet = dt.getSparseDummyData(opt)
    local elapsed = timer:time().real
    print(string.format("CartTrainer: sparse dataset create: %f samples/sec; %f sec", opt.nExample/elapsed, elapsed))

    -- Trainer/state construction is deliberately left out of the timings.
    local cartTrainer = dt.CartTrainer(trainSet, opt.minLeafSize, opt.maxLeafNodes)
    local treeState = dt.GiniState(trainSet:getExampleIds())

    timer:reset()
    cartTrainer:train(treeState, trainSet.featureIds)
    elapsed = timer:time().real
    print(string.format("CartTrainer: train single-thread : %f samples/sec; %f sec", opt.nExample/elapsed, elapsed))

    timer:reset()
    cartTrainer:featureParallel(opt.nThread)
    elapsed = timer:time().real
    print(string.format("CartTrainer: setup feature-parallel : %f samples/sec; %f sec", opt.nExample/elapsed, elapsed))

    -- NOTE(review): reuses the treeState from the single-thread run, as the
    -- original did; confirm train() does not mutate it.
    timer:reset()
    cartTrainer:train(treeState, trainSet.featureIds)
    elapsed = timer:time().real
    print(string.format("CartTrainer: train feature-parallel : %f samples/sec; %f sec", opt.nExample/elapsed, elapsed))
  end
  --- Time two consecutive GradientBoostState:findBestSplit calls on the
  -- same sparse dummy data, using a LogitBoost criterion.
  -- The first call is reported separately because it also calls
  -- SparseTensor:buildIndex().
  -- @tparam table opt options passed to dt.getSparseDummyData
  function bm.GradientBoostState(opt)
    local data = dt.getSparseDummyData(opt)
    data:initScore()
    local state = dt.GradientBoostState(data:getExampleIds(), nn.LogitBoostCriterion(false))

    -- The positional arguments 10, 1, 3 are passed unchanged; their meaning
    -- is defined by findBestSplit.
    local function split()
      state:findBestSplit(data, data.featureIds, 10, 1, 3)
    end

    local clock = torch.Timer()
    split()
    print(string.format("GradientBoostState: findBestSplit (first) : %f sec", clock:time().real))

    clock:reset()
    split()
    print(string.format("GradientBoostState: findBestSplit (second) : %f sec", clock:time().real))
  end
  --- Report whether a file can be opened for reading.
  -- @tparam string name path to test
  -- @treturn boolean true if io.open(name, "r") succeeded, false otherwise
  local function file_exists(name)
    local handle = io.open(name, "r")
    if handle == nil then
      return false
    end
    handle:close()
    return true
  end
  --- Benchmark dt.GradientBoostTrainer, first single-threaded and then with a
  -- feature-parallel CART tree trainer.
  -- Generated datasets are cached under /tmp. The cache file names include
  -- every option that feeds data generation, so a run with different
  -- settings (e.g. sparse=false, or a different nExample) never reuses a
  -- dataset built for another configuration.
  -- Side effect: sets opt.lossFunction and opt.treeTrainer before building
  -- the forest trainer.
  -- @tparam table opt benchmark options; reads sparse, nExample, nCluster,
  --   nFeature, overlap, nValid, nActive, minLeafSize, maxLeafNodes, nTree,
  --   nThread, and is passed to the dummy-data and trainer constructors
  function bm.GradientBoostTrainer(opt)
    -- These are the options dt.benchmark lists as data-generation inputs.
    -- NOTE(review): presumably getDenseDummyData reads a subset of them;
    -- confirm nothing else affects the generated data.
    local keyParts = {
      opt.sparse and "sparse" or "dense",
      tostring(opt.nExample), tostring(opt.nCluster), tostring(opt.nFeature),
      tostring(opt.overlap), tostring(opt.nValid), tostring(opt.nActive),
    }
    local cacheKey = table.concat(keyParts, "_")
    local trainPath = "/tmp/dt_benchmark_train_" .. cacheKey .. ".bin"
    local validPath = "/tmp/dt_benchmark_valid_" .. cacheKey .. ".bin"

    local trainSet, validSet
    if file_exists(trainPath) and file_exists(validPath) then
      trainSet = torch.load(trainPath)
      validSet = torch.load(validPath)
    else
      if opt.sparse then
        trainSet, validSet = dt.getSparseDummyData(opt)
      else
        trainSet, validSet = dt.getDenseDummyData(opt)
      end
      torch.save(trainPath, trainSet)
      torch.save(validPath, validSet)
    end

    local cartTrainer = dt.CartTrainer(trainSet, opt.minLeafSize, opt.maxLeafNodes)
    opt.lossFunction = nn.LogitBoostCriterion(false)
    opt.treeTrainer = cartTrainer
    local forestTrainer = dt.GradientBoostTrainer(opt)

    local timer = torch.Timer()
    forestTrainer:train(trainSet, trainSet.featureIds, validSet)
    local elapsed = timer:time().real
    print(string.format("GradientBoostTrainer: train single-thread : %f samples/sec; %f sec/tree, %f sec", opt.nExample/elapsed, elapsed/opt.nTree, elapsed))

    -- The forest trainer holds cartTrainer as its tree trainer, so switching
    -- cartTrainer to feature-parallel mode affects the next train() call.
    cartTrainer:featureParallel(opt.nThread)
    timer:reset()
    forestTrainer:train(trainSet, trainSet.featureIds, validSet)
    elapsed = timer:time().real
    print(string.format("GradientBoostTrainer: train feature-parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/elapsed, elapsed/opt.nTree, elapsed))
  end
  --- Benchmark dt.RandomForestTrainer on sparse dummy data.
  -- Times a single-threaded training run, the tree-parallel setup, and a
  -- tree-parallel training run. Each print uses one timer reading, so the
  -- throughput and the seconds on a line come from the same measurement.
  -- @tparam table opt benchmark options; reads nExample, nTree, nThread, and
  --   is passed to dt.getSparseDummyData and dt.RandomForestTrainer
  function bm.RandomForestTrainer(opt)
    local trainSet = dt.getSparseDummyData(opt)
    local forestTrainer = dt.RandomForestTrainer(opt)

    -- Untimed warm-up run. NOTE(review): presumably this keeps one-time costs
    -- (such as sparse index building) out of the timed run; confirm.
    forestTrainer:train(trainSet, trainSet.featureIds)

    local timer = torch.Timer()
    forestTrainer:train(trainSet, trainSet.featureIds)
    local elapsed = timer:time().real
    print(string.format("RandomForestTrainer: train single-thread : %f samples/sec; %f sec/tree, %f sec", opt.nExample/elapsed, elapsed/opt.nTree, elapsed))

    timer:reset()
    forestTrainer:treeParallel(opt.nThread)
    elapsed = timer:time().real
    print(string.format("RandomForestTrainer: setup tree-parallel : %f samples/sec; %f sec", opt.nExample/elapsed, elapsed))

    timer:reset()
    forestTrainer:train(trainSet, trainSet.featureIds)
    elapsed = timer:time().real
    print(string.format("RandomForestTrainer: train tree-parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/elapsed, elapsed/opt.nTree, elapsed))
  end
  --- Benchmark nn.DFD forward passes over a random forest.
  -- Trains a tree-parallel random forest on dense dummy data, then times
  -- opt.nloop forward passes of nn.DFD on one batch of opt.batchsize rows.
  -- @tparam table opt benchmark options; it is cloned, so the override of
  --   nExample below does not leak to the caller. Also reads nThread, nTree,
  --   batchsize and nloop.
  function bm.DFD(opt)
    local _ = require 'moses'
    local opt = _.clone(opt)
    -- A small training set keeps forest training quick. It must still hold
    -- at least one full batch for the :sub(1, batchsize) below, which
    -- previously failed whenever batchsize > 200.
    opt.nExample = math.max(200, opt.batchsize)
    local trainSet = dt.getDenseDummyData(opt)

    local forestTrainer = dt.RandomForestTrainer(opt)
    forestTrainer:treeParallel(opt.nThread)
    local timer = torch.Timer()
    local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds)
    local elapsed = timer:time().real
    print(string.format("DFD: train random forest in parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/elapsed, elapsed/opt.nTree, elapsed))

    -- benchmark nn.DFD: one untimed forward pass, then opt.nloop timed ones
    local input = trainSet.input:sub(1, opt.batchsize)
    local dfd = nn.DFD(decisionForest)
    dfd:forward(input)
    timer:reset()
    for i = 1, opt.nloop do
      dfd:forward(input)
    end
    elapsed = timer:time().real
    print(string.format("DFD: updateOutput : %f samples/sec; %f sec", opt.nloop*opt.batchsize/elapsed, elapsed))
  end
  --- Benchmark nn.Sparse2Dense forward passes on one batch of sparse examples.
  -- @tparam table opt benchmark options; it is cloned, so the override of
  --   nExample does not leak to the caller. Also reads batchsize, nFeature
  --   and nloop.
  function bm.Sparse2Dense(opt)
    local _ = require 'moses'
    local opt = _.clone(opt)
    -- Generate exactly one batch worth of sparse examples.
    opt.nExample = opt.batchsize
    local trainSet = dt.getSparseDummyData(opt)

    -- Build the {keys-list, values-list} pair passed to nn.Sparse2Dense.
    local keys, values = {}, {}
    for i = 1, opt.batchsize do
      keys[i] = trainSet.input[i].keys
      values[i] = trainSet.input[i].values
    end
    local input = {keys, values}
    assert(#input[1] == opt.batchsize)

    -- benchmark nn.Sparse2Dense: one untimed forward pass, then opt.nloop
    -- timed ones
    local s2d = nn.Sparse2Dense(torch.LongTensor():range(1, opt.nFeature))
    s2d:forward(input)
    local timer = torch.Timer()
    for i = 1, opt.nloop do
      s2d:forward(input)
    end
    -- Read the timer once so throughput and seconds agree.
    local elapsed = timer:time().real
    print(string.format("Sparse2Dense: updateOutput : %f samples/sec; %f sec", opt.nloop*opt.batchsize/elapsed, elapsed))
  end
  --- Run benchmarks from the `bm` registry.
  -- @param benchmarks one of:
  --   * a list of benchmark names,
  --   * a single benchmark name (string),
  --   * nil/false to run every registered benchmark in sorted name order.
  -- @tparam[opt] table opt2 option overrides applied on top of the defaults
  --   for every benchmark
  -- @raise assertion error when a name is not a string or is not a
  --   registered benchmark
  function dt.benchmark(benchmarks, opt2)
    local opt = {
      nExample=10000, nCluster=2, nFeature=1000, overlap=0, nValid=100, -- getSparseDummyData
      nTree=20, featureBaggingSize=-1, sparse=true, -- GradientBoostTrainer and RandomForestTrainer
      nThread=2, shrinkage=0.1, downsampleRatio=0.1, evalFreq=5, earlyStop=0, -- GradientBoostTrainer
      activeRatio=0.5, -- RandomForestTrainer
      batchsize=32, nloop=10
    }
    local _ = require 'moses'
    if not benchmarks then
      -- The keys of a hash table come out in an unspecified order; sort them
      -- so repeated runs execute the benchmarks in the same order.
      benchmarks = _.keys(bm)
      table.sort(benchmarks)
    elseif type(benchmarks) == 'string' then
      benchmarks = {benchmarks}
    end
    assert(torch.type(benchmarks) == 'table')
    for i, benchmark in ipairs(benchmarks) do
      -- Validate the name before doing any per-benchmark work.
      assert(torch.type(benchmark) == 'string', benchmark)
      assert(bm[benchmark], "unknown benchmark: " .. benchmark)
      -- Fresh copy per benchmark: some benchmarks write into their options
      -- (e.g. GradientBoostTrainer sets lossFunction and treeTrainer).
      local opt1 = _.clone(opt)
      for key, value in pairs(opt2 or {}) do
        opt1[key] = value
      end
      -- Derived defaults, unless overridden through opt2.
      opt1.nActive = opt1.nActive or torch.round(opt1.nFeature/10)
      opt1.maxLeafNodes = opt1.maxLeafNodes or (opt1.nExample/10)
      opt1.minLeafSize = opt1.minLeafSize or (opt1.nExample/100)
      bm[benchmark](opt1)
    end
  end