You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

GiniState.lua 2.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  local dt = require 'decisiontree._env'

  -- GiniState is used by RandomForestTrainer.
  -- The class is registered with torch under the name "dt.GiniState", with
  -- "dt.TreeState" as its parent class; the `dt` table is handed to
  -- torch.class as its third argument. `parent` is kept so the methods
  -- below can chain to the parent implementation (see __init).
  local GiniState, parent = torch.class(
     "dt.GiniState",
     "dt.TreeState",
     dt
  )
  -- Constructor.
  -- Forwards `exampleIds` to the parent class constructor, then resets the
  -- positive-example counters of both branches to zero.
  -- Note: the nExampleIn*Branch counters are not set here; they are first
  -- assigned in GiniState:initialize.
  --   exampleIds : passed through unchanged to parent.__init
  function GiniState:__init(exampleIds)
     parent.__init(self, exampleIds)
     self.nPositiveInLeftBranch, self.nPositiveInRightBranch = 0, 0
  end
  -- Scores this state against `dataset`.
  -- Asks the dataset how many of self.exampleIds are positive, then hands
  -- that count together with the first-dimension size of self.exampleIds
  -- to calculateLogitScore from the 'decisiontree' package and returns
  -- whatever that function returns.
  --   dataset : object providing countPositive(exampleIds)
  -- NOTE(review): self.exampleIds is presumably stored by the parent
  -- constructor -- it is not assigned anywhere in this file.
  function GiniState:score(dataset)
     local decisiontree = require 'decisiontree'
     local positives = dataset:countPositive(self.exampleIds)
     local total = self.exampleIds:size(1)
     return decisiontree.calculateLogitScore(positives, total)
  end
  -- Resets the branch counters before a split scan.
  -- Every example in `exampleIdsWithFeature` is counted in the left branch;
  -- the right branch starts out empty (zero examples, zero positives).
  --   exampleIdsWithFeature : must be a torch.LongTensor (asserted)
  --   dataset               : must be a dt.DataSet (asserted)
  function GiniState:initialize(exampleIdsWithFeature, dataset)
     assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor')
     assert(torch.isTypeOf(dataset, 'dt.DataSet'))

     local positives = dataset:countPositive(exampleIdsWithFeature)
     local total = exampleIdsWithFeature:size(1)

     self.nPositiveInLeftBranch, self.nExampleInLeftBranch = positives, total
     self.nPositiveInRightBranch, self.nExampleInRightBranch = 0, 0
  end
  -- Moves one example from the left-branch counters to the right-branch
  -- counters.
  -- The example counters always shift by one; the positive counters shift
  -- only when dataset.target[exampleId] is greater than zero.
  --   exampleId : must be a number (asserted); used to index dataset.target
  --   dataset   : must be a dt.DataSet (asserted)
  function GiniState:update(exampleId, dataset)
     assert(torch.type(exampleId) == 'number')
     assert(torch.isTypeOf(dataset, 'dt.DataSet'))

     local isPositive = dataset.target[exampleId] > 0
     if isPositive then
        -- a positive example leaves the left branch and joins the right one
        self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1
        self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1
     end

     self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1
     self.nExampleInRightBranch = self.nExampleInRightBranch + 1
  end
  -- Builds the split description for the current counter values.
  -- The gini value comes from computeGini in the 'decisiontree' package,
  -- called with (left size, left positives, right size, right positives).
  -- Returns a table with the fields:
  --   splitId, splitValue                : the two arguments
  --   leftChildSize, leftPositiveCount   : left-branch counters
  --   rightChildSize, rightPositiveCount : right-branch counters
  --   gini, splitGain                    : both hold the computed gini value
  -- Every field except splitGain goes through assert, so a nil or false
  -- argument, counter or gini raises an error.
  function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue)
     local decisiontree = require 'decisiontree'

     local leftSize, leftPositive = self.nExampleInLeftBranch, self.nPositiveInLeftBranch
     local rightSize, rightPositive = self.nExampleInRightBranch, self.nPositiveInRightBranch
     local giniValue = decisiontree.computeGini(leftSize, leftPositive, rightSize, rightPositive)

     return {
        splitId = assert(splitFeatureId),
        splitValue = assert(splitFeatureValue),
        leftChildSize = assert(leftSize),
        leftPositiveCount = assert(leftPositive),
        rightChildSize = assert(rightSize),
        rightPositiveCount = assert(rightPositive),
        gini = assert(giniValue),
        splitGain = giniValue
     }
  end