123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- local dt = require "decisiontree._env"
-
- local TreeState = torch.class("dt.TreeState", dt)
-
-- TreeState holds the state of a subtree during decision tree training and
-- also manages the state of candidate splits.
--
-- Constructor.
-- @param exampleIds torch.LongTensor of the example ids owned by this node
--        (any other type fails the assertion)
function TreeState:__init(exampleIds)
   local idType = torch.type(exampleIds)
   assert(idType == 'torch.LongTensor')

   -- both branch counters start at zero
   self.nExampleInLeftBranch = 0
   self.nExampleInRightBranch = 0

   self.exampleIds = exampleIds
end
-
-- Computes and returns the score of the node based on its examples.
-- This base implementation always raises "NotImplemented"; presumably
-- concrete TreeState subclasses override it.
-- @param dataset the dataset holding the node's examples
function TreeState:score(dataset)
   error("NotImplemented")
end
-
-
-- Initializes the split-state-updater. Initially all examples are in the
-- left branch.
-- This base implementation always raises "NotImplemented".
-- @param exampleIdsWithFeature list of examples to split (those having a
--        particular feature)
-- @param dataset the dataset the examples belong to
function TreeState:initialize(exampleIdsWithFeature, dataset)
   error("NotImplemented")
end
-
-- Update the split state. This call has the effect of shifting the example
-- from the left to the right branch.
-- This base implementation always raises "NotImplemented".
-- @param exampleId id of the example being moved
-- @param dataset the dataset the example belongs to
-- @param exampleIdx (optional) third argument supplied by
--        findBestFeatureSplit: the position of exampleId within the tensor
--        that was handed to initialize(). Declared here so the abstract
--        signature matches its only visible call site; implementations that
--        ignore it keep working unchanged.
function TreeState:update(exampleId, dataset, exampleIdx)
   error"NotImplemented"
end
-
-- Computes the SplitInfo determined by the current split state.
-- This base implementation always raises "NotImplemented".
-- @param splitFeatureId the feature id of the split feature
-- @param splitFeatureValue the feature value of the split feature
-- @param _splitInfo (optional) third argument supplied by
--        findBestFeatureSplit: either nil or a table that is no longer
--        needed by the caller, presumably so an implementation can fill it
--        in instead of allocating a new one. Declared here so the abstract
--        signature matches its only visible call site.
-- @return the SplitInfo determined by the current split state. The caller
--         reads its leftChildSize, rightChildSize and splitGain fields.
function TreeState:computeSplitInfo(splitFeatureId, splitFeatureValue, _splitInfo)
   error"NotImplemented"
end
-
-- Finds the best split of this node's examples on a single feature.
-- (Marked as a bottleneck by the original author.)
--
-- The examples are walked in the order given by dataset:getSortedFeature,
-- from the last entry to the first (descending feature value). Each visited
-- example is moved to the right branch via update(); whenever the feature
-- value changes by more than dt.EPSILON, the split at the previous value is
-- evaluated with computeSplitInfo().
--
-- @param dataset a dt.DataSet
-- @param featureId number, the feature to split on
-- @param minLeafSize number, minimum size required of both children
-- @return the candidate SplitInfo with the smallest splitGain among those
--         whose children both reach minLeafSize, or nil when there is none.
--         Returns nothing when dataset.input is a table and no example of
--         this node has the feature.
function TreeState:findBestFeatureSplit(dataset, featureId, minLeafSize)
   local dt = require "decisiontree"
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureId) == 'number')
   assert(torch.type(minLeafSize) == 'number')

   -- every dataset example having this feature, sorted by value
   local sortedExampleIds = dataset:getSortedFeature(featureId)

   -- scratch tensor shared through the package-level buffer table
   local scratch = dt.getBufferTable('TreeState')
   scratch.longtensor = scratch.longtensor or torch.LongTensor()
   local candidateIds = scratch.longtensor

   local positionOf = {} -- exampleId -> index into candidateIds
   local valueOf         -- exampleId -> feature value
   local count = 0

   if torch.type(dataset.input) == 'table' then
      -- table input: keep only the examples that actually carry the feature
      candidateIds:resize(self.exampleIds:size())
      self.exampleIds:apply(function(id)
         local example = dataset.input[id]
         example:buildIndex() -- only builds index first time
         if example[featureId] then
            count = count + 1
            candidateIds[count] = id
            positionOf[id] = count
         end
      end)
      if count == 0 then
         return
      end
      candidateIds:resize(count)
      valueOf = function(id)
         return dataset.input[id][featureId]
      end
   else
      -- non-table input: every example of this node is a candidate
      candidateIds = self.exampleIds
      self.exampleIds:apply(function(id)
         count = count + 1
         positionOf[id] = count
      end)
      local column = dataset.input:select(2, featureId)
      valueOf = function(id)
         return column[id]
      end
   end

   self:initialize(candidateIds, dataset)

   -- bottleneck
   local best, lastValue, spare
   for rank = sortedExampleIds:size(1), 1, -1 do
      local id = sortedExampleIds[rank]
      local position = positionOf[id]

      -- examples that do not belong to this node (or lack the feature) are skipped
      if position then
         local value = valueOf(id)

         if lastValue and math.abs(value - lastValue) > dt.EPSILON then
            local candidate = self:computeSplitInfo(featureId, lastValue, spare)
            local bigEnough = (candidate.leftChildSize >= minLeafSize)
               and (candidate.rightChildSize >= minLeafSize)

            if bigEnough and ((not best) or (candidate.splitGain < best.splitGain)) then
               -- the displaced best (or a fresh table) is offered for reuse
               spare = best or {}
               best = candidate
            end
         end

         lastValue = value

         -- bottleneck
         self:update(id, dataset, position)
      end
   end

   return best
end
-
-- Finds the best split of the examples in this tree state among featureIds.
--
-- When nShard > 1 only the features satisfying
-- (featureId % nShard) + 1 == shardId are examined; otherwise all are.
--
-- @param dataset a dt.DataSet
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param minLeafSize number, forwarded to findBestFeatureSplit
-- @param shardId number, the shard handled by this call
-- @param nShard number, total number of shards
-- @return the candidate with the smallest splitGain, or nil if no examined
--         feature yielded a candidate
function TreeState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')
   assert(torch.type(minLeafSize) == 'number')
   assert(torch.type(shardId) == 'number')
   assert(torch.type(nShard) == 'number')

   local unsharded = nShard <= 1
   local winner

   for k = 1, featureIds:size(1) do
      local fid = featureIds[k]
      -- feature sharded
      local mine = unsharded or ((fid % nShard) + 1 == shardId)
      if mine then
         local candidate = self:findBestFeatureSplit(dataset, fid, minLeafSize)
         if candidate then
            if (not winner) or (candidate.splitGain < winner.splitGain) then
               winner = candidate
            end
         end
      end
   end

   return winner
end
-
-- Partitions self given a splitInfo table, producing a pair of exampleIds
-- tensors corresponding to the left and right subtrees.
--
-- Examples whose split-feature value is below splitInfo.splitValue go left,
-- all others go right. Both results are views (narrow) into one freshly
-- allocated LongTensor: left ids are written from the front in visiting
-- order, right ids from the back (so in reverse visiting order).
--
-- @param splitInfo table providing splitId and splitValue
-- @param dataset dataset whose input is indexed by example id, then by
--        splitInfo.splitId
-- @return leftExampleIds, rightExampleIds
function TreeState:_branch(splitInfo, dataset)
   local total = self.exampleIds:size(1)
   local partitioned = torch.LongTensor(total)
   local nLeft, nRight = 0, 0

   for pos = 1, total do
      local id = self.exampleIds[pos]
      local value = dataset.input[id][splitInfo.splitId]
      -- Note: when the feature is not present in the example, the example is
      -- dropped from all sub-trees. Which means that for most sparse data, a
      -- tree cannot reach 100% accuracy...
      if value then
         if value < splitInfo.splitValue then
            nLeft = nLeft + 1
            partitioned[nLeft] = id
         else
            nRight = nRight + 1
            partitioned[total - nRight + 1] = id
         end
      end
   end

   -- NOTE(review): narrow() is called with a length of 0 when one side is
   -- empty -- confirm the Torch version in use accepts that.
   local rightStart = total - nRight + 1
   local leftExampleIds = partitioned:narrow(1, 1, nLeft)
   local rightExampleIds = partitioned:narrow(1, rightStart, nRight)

   local nKept = leftExampleIds:size(1) + rightExampleIds:size(1)
   assert(nKept <= self.exampleIds:size(1), "Left and right branches contain more data than the parent!")
   return leftExampleIds, rightExampleIds
end
-
-- Calls _branch and wraps the resulting left and right exampleIds into new
-- tree states created through self.new.
-- @param splitInfo the split to apply (see _branch)
-- @param dataset the dataset the examples belong to
-- @return leftState, rightState
function TreeState:branch(splitInfo, dataset)
   local leftIds, rightIds = self:_branch(splitInfo, dataset)
   local leftState = self.new(leftIds)
   local rightState = self.new(rightIds)
   return leftState, rightState
end
-
-- Returns the size of the first dimension of self.exampleIds, i.e. the
-- number of example ids held by this state.
function TreeState:size()
   local ids = self.exampleIds
   return ids:size(1)
end
-
-- Tells whether exampleId is one of this state's example ids.
-- Every element of self.exampleIds is visited; there is no early exit.
-- @param exampleId the id to look for
-- @return true if at least one element equals exampleId, false otherwise
function TreeState:contains(exampleId)
   local hit = false
   self.exampleIds:apply(function(id)
      -- once set, the flag stays set
      hit = hit or (id == exampleId)
   end)
   return hit
end
|