-- rspamd/contrib/lua-torch/decisiontree/TreeState.lua
local dt = require "decisiontree._env"
local TreeState = torch.class("dt.TreeState", dt)

-- Holds the state of a subtree during decision tree training.
-- It also manages the state of candidate splits.
function TreeState:__init(exampleIds)
   assert(torch.type(exampleIds) == 'torch.LongTensor')
   self.exampleIds = exampleIds
   self.nExampleInLeftBranch = 0
   self.nExampleInRightBranch = 0
end

-- Computes and returns the score of the node based on its examples.
function TreeState:score(dataset)
   error("NotImplemented")
end

-- Initializes the split-state-updater. Initially, all examples are in the left branch.
-- exampleIdsWithFeature is a list of examples to split (those having a particular feature).
function TreeState:initialize(exampleIdsWithFeature, dataset)
   error("NotImplemented")
end

-- Update the split state. This call has the effect of shifting the example from the left to the right branch.
function TreeState:update(exampleId, dataset)
   error("NotImplemented")
end

-- Computes the SplitInfo determined by the current split state.
-- @param splitFeatureId the feature id of the split feature
-- @param splitFeatureValue the feature value of the split feature
-- @return the SplitInfo determined by the current split state
function TreeState:computeSplitInfo(splitFeatureId, splitFeatureValue)
   error("NotImplemented")
end
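
-- The four stubs above define the contract that concrete subclasses implement.
-- The block below is only an illustrative sketch of that contract, kept in comments;
-- dt.CountState and its fields are hypothetical and not part of this package.
--
--   local CountState = torch.class("dt.CountState", "dt.TreeState", dt)
--
--   function CountState:score(dataset)
--      return 0 -- a real subclass returns e.g. the mean target of self.exampleIds
--   end
--
--   function CountState:initialize(exampleIdsWithFeature, dataset)
--      -- initially, every candidate example sits in the left branch
--      self.nLeft, self.nRight = exampleIdsWithFeature:size(1), 0
--   end
--
--   function CountState:update(exampleId, dataset)
--      -- shift one example from the left branch to the right branch
--      self.nLeft, self.nRight = self.nLeft - 1, self.nRight + 1
--   end
--
--   function CountState:computeSplitInfo(splitFeatureId, splitFeatureValue)
--      -- a real subclass derives splitGain from accumulated statistics (lower is better)
--      return {
--         splitId = splitFeatureId, splitValue = splitFeatureValue,
--         leftChildSize = self.nLeft, rightChildSize = self.nRight,
--         splitGain = 0
--      }
--   end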

-- Finds the best split of this node's examples on a single feature.
-- bottleneck
function TreeState:findBestFeatureSplit(dataset, featureId, minLeafSize)
   local dt = require "decisiontree"
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureId) == 'number')
   assert(torch.type(minLeafSize) == 'number')

   -- all dataset examples having this feature, sorted by feature value
   local featureExampleIds = dataset:getSortedFeature(featureId)

   local buffer = dt.getBufferTable('TreeState')
   buffer.longtensor = buffer.longtensor or torch.LongTensor()
   local exampleIdsWithFeature = buffer.longtensor

   -- map and tensor of examples containing the feature:
   local exampleMap = {}
   local getExampleFeatureValue

   local j = 0
   if torch.type(dataset.input) == 'table' then
      -- sparse input: gather only the examples that actually have this feature
      exampleIdsWithFeature:resize(self.exampleIds:size())
      self.exampleIds:apply(function(exampleId)
         local input = dataset.input[exampleId]
         input:buildIndex() -- only builds the index the first time
         if input[featureId] then
            j = j + 1
            exampleIdsWithFeature[j] = exampleId
            exampleMap[exampleId] = j
         end
      end)
      if j == 0 then
         return
      end
      exampleIdsWithFeature:resize(j)
      getExampleFeatureValue = function(exampleId) return dataset.input[exampleId][featureId] end
   else
      -- dense input: every example has every feature
      exampleIdsWithFeature = self.exampleIds
      self.exampleIds:apply(function(exampleId)
         j = j + 1
         exampleMap[exampleId] = j
      end)
      local featureValues = dataset.input:select(2, featureId)
      getExampleFeatureValue = function(exampleId) return featureValues[exampleId] end
   end

   self:initialize(exampleIdsWithFeature, dataset)

   -- bottleneck: sweep the examples from highest to lowest feature value,
   -- moving one example at a time from the left branch to the right branch
   local bestSplit, previousSplitValue, _tictoc
   for i=featureExampleIds:size(1),1,-1 do -- loop over examples sorted (desc) by feature value
      local exampleId = featureExampleIds[i]
      local exampleIdx = exampleMap[exampleId]
      if exampleIdx then
         local splitValue = getExampleFeatureValue(exampleId)

         -- only score a candidate split between two distinct feature values
         if previousSplitValue and math.abs(splitValue - previousSplitValue) > dt.EPSILON then
            local splitInfo = self:computeSplitInfo(featureId, previousSplitValue, _tictoc)
            if (splitInfo.leftChildSize >= minLeafSize) and (splitInfo.rightChildSize >= minLeafSize) then
               if (not bestSplit) or (splitInfo.splitGain < bestSplit.splitGain) then
                  _tictoc = bestSplit or {} -- reuse the table of the previous best split
                  bestSplit = splitInfo
               end
            end
         end

         previousSplitValue = splitValue

         -- bottleneck
         self:update(exampleId, dataset, exampleIdx)
      end
   end

   return bestSplit
end
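
-- Worked illustration of the sweep above, with hypothetical values. Suppose
-- getSortedFeature returns the ids of four examples in ascending order of value:
--   exampleId : 4    2    7    1
--   value     : 0.1  0.5  0.5  0.9
-- The loop then visits example 1 first (value 0.9), followed by 7, 2 and 4. A candidate
-- split is only scored when the value changes by more than dt.EPSILON, so the tie at
-- 0.5 is never split apart; after each visit, self:update() shifts that example from
-- the left branch to the right branch.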

-- Finds the best split of the examples in this TreeState among the given featureIds.
-- When nShard > 1, only the features assigned to shard shardId are evaluated.
function TreeState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')
   assert(torch.type(minLeafSize) == 'number')
   assert(torch.type(shardId) == 'number')
   assert(torch.type(nShard) == 'number')

   local bestSplit
   for i=1,featureIds:size(1) do
      local featureId = featureIds[i]
      if (nShard <= 1) or ( (featureId % nShard) + 1 == shardId ) then -- feature sharding
         local splitCandidate = self:findBestFeatureSplit(dataset, featureId, minLeafSize)
         if splitCandidate and ((not bestSplit) or (splitCandidate.splitGain < bestSplit.splitGain)) then
            bestSplit = splitCandidate
         end
      end
   end

   return bestSplit
end
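
-- Example call (sketch only; treeState is an instance of a concrete subclass, and
-- dataset/featureIds are assumed to have been prepared elsewhere, e.g. by dt.CartTrainer):
--
--   local best = treeState:findBestSplit(dataset, featureIds, 5, 1, 1) -- no sharding
--   local mine = treeState:findBestSplit(dataset, featureIds, 5, 2, 3) -- only shard 2 of 3
--   if best then
--      print(best.splitId, best.splitValue, best.splitGain)
--   end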

-- Partitions self given a splitInfo table, producing a pair of exampleIds
-- corresponding to the left and right subtrees.
function TreeState:_branch(splitInfo, dataset)
   local leftIdx, rightIdx = 0, 0
   local nExample = self.exampleIds:size(1)
   local splitExampleIds = torch.LongTensor(nExample)

   for i=1,self.exampleIds:size(1) do
      local exampleId = self.exampleIds[i]
      local input = dataset.input[exampleId]
      local val = input[splitInfo.splitId]
      -- Note: when the feature is not present in the example, the example is dropped from all sub-trees.
      -- This means that, for most sparse data, a tree cannot reach 100% accuracy...
      if val then
         if val < splitInfo.splitValue then
            leftIdx = leftIdx + 1
            splitExampleIds[leftIdx] = exampleId
         else
            rightIdx = rightIdx + 1
            splitExampleIds[nExample-rightIdx+1] = exampleId
         end
      end
   end

   local leftExampleIds = splitExampleIds:narrow(1,1,leftIdx)
   local rightExampleIds = splitExampleIds:narrow(1,nExample-rightIdx+1,rightIdx)
   assert(leftExampleIds:size(1) + rightExampleIds:size(1) <= self.exampleIds:size(1),
      "Left and right branches contain more data than the parent!")
   return leftExampleIds, rightExampleIds
end

-- Calls _branch and encapsulates the left and right exampleIds into a pair of TreeStates.
function TreeState:branch(splitInfo, dataset)
   local leftExampleIds, rightExampleIds = self:_branch(splitInfo, dataset)
   return self.new(leftExampleIds), self.new(rightExampleIds)
end
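
-- Example (sketch, assuming bestSplit is a splitInfo table returned by findBestSplit above):
--
--   local leftState, rightState = treeState:branch(bestSplit, dataset)
--   -- leftState and rightState are new instances of the same concrete TreeState class,
--   -- each holding the exampleIds that fall on its side of the split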

-- Returns the number of examples held by this TreeState.
function TreeState:size()
   return self.exampleIds:size(1)
end

-- Returns true if exampleId belongs to this TreeState.
function TreeState:contains(exampleId)
   local found = false
   self.exampleIds:apply(function(x)
      if x == exampleId then
         found = true
      end
   end)
   return found
end