You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

GradientBoostState.lua 2.7KB

(line-number gutter: lines 1–57)
-- Module environment and class declaration.
-- GradientBoostState specializes dt.TreeState with the gradient/hessian
-- tensors needed for gradient boosted decision tree training.
local dt = require 'decisiontree._env'
local GradientBoostState, parent = torch.class("dt.GradientBoostState", "dt.TreeState", dt)
  3. function GradientBoostState:__init(exampleIds, gradInput, hessInput)
  4. parent.__init(self, exampleIds)
  5. self.gradInput = gradInput
  6. self.hessInput = hessInput
  7. end
  8. function GradientBoostState:score(dataset)
  9. local dt = require 'decisiontree'
  10. local gradInput = self.gradInput:index(1, self.exampleIds)
  11. local hessInput = self.hessInput:index(1, self.exampleIds)
  12. return dt.computeNewtonScore(gradInput:sum(), hessInput:sum())
  13. end
  14. -- calls _branch and encapsulates the left and right exampleIds into a TreeStates
  15. function GradientBoostState:branch(splitInfo, dataset)
  16. local leftExampleIds, rightExampleIds = self:_branch(splitInfo, dataset)
  17. return self.new(leftExampleIds, self.gradInput, self.hessInput), self.new(rightExampleIds, self.gradInput, self.hessInput)
  18. end
  19. -- Partitions self given a splitInfo table, producing a pair of exampleIds corresponding to the left and right subtrees.
  20. function GradientBoostState:_branch(splitInfo, dataset)
  21. local input = dataset.input
  22. -- if the input is dense, we can use the optimized version
  23. if torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then
  24. return input.nn.GBDT_branch(splitInfo, input, self.exampleIds)
  25. end
  26. return parent._branch(self, splitInfo, dataset)
  27. end
  28. -- The following methods are supersets of each other. You can comment out them to re-use the lua
  29. -- version with just the provided core optimized
  30. -- THIS ONE CANNOT BE COMMENTED OUT
  31. function GradientBoostState:findBestFeatureSplit(dataset, featureId, minLeafSize)
  32. local ret = self.hessInput.nn.GBDT_findBestFeatureSplit(self.exampleIds, dataset, featureId, minLeafSize, self.gradInput, self.hessInput)
  33. return ret
  34. end
  35. -- finds the best split of examples in treeState among featureIds
  36. function GradientBoostState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard)
  37. local ret = self.hessInput.nn.GBDT_findBestSplit(self.exampleIds, dataset, featureIds, minLeafSize, shardId, nShard, self.gradInput, self.hessInput)
  38. return ret
  39. end
  40. -- finds the best split like the previous one, but performs feature parallelism. Note that the
  41. -- optimization is only applied if the input is dense
  42. function GradientBoostState:findBestSplitFP(dataset, featureIds, minLeafSize, nThread)
  43. local input = dataset.input
  44. if torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then
  45. local ret = self.hessInput.nn.GBDT_findBestSplitFP(self.exampleIds, dataset, featureIds, minLeafSize, self.gradInput, self.hessInput, nThread)
  46. return ret
  47. end
  48. end