aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/GiniState.lua
blob: 6dfed28452e53363555f8aca0d434f5298d0638d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
local dt = require 'decisiontree._env'

-- used by RandomForestTrainer
-- Split-search state that tracks, for a candidate split, how many examples
-- and how many positive examples (target > 0) sit in the left and right
-- branches. Declared as dt.GiniState, a subclass of dt.TreeState; `parent`
-- is the TreeState class, used below to chain the constructor.
local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt)

--- Construct a split-search state for the examples at one tree node.
-- All four branch counters start at zero so that every field later read by
-- `update` and `computeSplitInfo` holds a number as soon as the object
-- exists; `initialize` then sets the real starting counts for a feature.
-- @param exampleIds ids of the examples at this node; forwarded unchanged
--   to dt.TreeState.__init
function GiniState:__init(exampleIds)
   parent.__init(self, exampleIds)
   self.nPositiveInLeftBranch = 0
   self.nPositiveInRightBranch = 0
   -- keep the example counters consistent with the positive counters
   self.nExampleInLeftBranch = 0
   self.nExampleInRightBranch = 0
end

--- Score this node from its example count and its positive-example count.
-- The `decisiontree` module is required here at call time rather than at
-- file load (presumably to avoid a circular require during package
-- initialisation -- TODO confirm).
-- @param dataset a dt.DataSet; its countPositive is applied to this
--   node's exampleIds
-- @return whatever decisiontree.calculateLogitScore(nPositive, nExamples)
--   returns
function GiniState:score(dataset)
   local decisiontree = require 'decisiontree'
   local ids = self.exampleIds
   local positives = dataset:countPositive(ids)
   local total = ids:size(1)
   return decisiontree.calculateLogitScore(positives, total)
end

--- Reset the counters before sweeping a feature's examples.
-- Every example that has the feature starts in the left branch and the
-- right branch starts empty; `update` then moves examples across one by one.
-- @param exampleIdsWithFeature torch.LongTensor of example ids
-- @param dataset dt.DataSet used to count the positive examples
function GiniState:initialize(exampleIdsWithFeature, dataset)
   assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor')
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))

   -- positives: all on the left, none on the right
   self.nPositiveInLeftBranch, self.nPositiveInRightBranch =
      dataset:countPositive(exampleIdsWithFeature), 0

   -- examples: all on the left, none on the right
   self.nExampleInLeftBranch, self.nExampleInRightBranch =
      exampleIdsWithFeature:size(1), 0
end

--- Move a single example from the left branch to the right branch.
-- When the example's target is > 0, one positive count moves across with it.
-- @param exampleId number indexing dataset.target
-- @param dataset dt.DataSet whose `target` is consulted
function GiniState:update(exampleId, dataset)
   assert(torch.type(exampleId) == 'number')
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))

   local isPositive = dataset.target[exampleId] > 0
   if isPositive then
      local left, right = self.nPositiveInLeftBranch, self.nPositiveInRightBranch
      self.nPositiveInLeftBranch = left - 1
      self.nPositiveInRightBranch = right + 1
   end

   -- every example, positive or not, changes sides
   local nLeft, nRight = self.nExampleInLeftBranch, self.nExampleInRightBranch
   self.nExampleInLeftBranch = nLeft - 1
   self.nExampleInRightBranch = nRight + 1
end

--- Build the split description for the current branch counts.
-- Each recorded field is passed through `assert`, so a nil/false value
-- raises an error instead of producing a partial record.
-- @param splitFeatureId id of the feature being split on
-- @param splitFeatureValue value at which the feature is split
-- @return table with splitId, splitValue, leftChildSize, leftPositiveCount,
--   rightChildSize, rightPositiveCount, gini, and splitGain (same as gini)
function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue)
   local decisiontree = require 'decisiontree'
   local leftSize, leftPositives = self.nExampleInLeftBranch, self.nPositiveInLeftBranch
   local rightSize, rightPositives = self.nExampleInRightBranch, self.nPositiveInRightBranch
   local gini = decisiontree.computeGini(leftSize, leftPositives, rightSize, rightPositives)

   return {
      splitId = assert(splitFeatureId),
      splitValue = assert(splitFeatureValue),
      leftChildSize = assert(leftSize),
      leftPositiveCount = assert(leftPositives),
      rightChildSize = assert(rightSize),
      rightPositiveCount = assert(rightPositives),
      gini = assert(gini),
      splitGain = gini,
   }
end