mirror of
https://github.com/rspamd/rspamd.git
synced 2024-09-13 23:56:50 +02:00
192 lines
6.6 KiB
Lua
192 lines
6.6 KiB
Lua
local dt = require "decisiontree._env"
|
|
|
|
local TreeState = torch.class("dt.TreeState", dt)
|
|
|
|
-- TreeState holds the state of a subtree during decision tree training.
-- It also manages the state of candidate splits.

-- Constructor.
-- @param exampleIds torch.LongTensor holding the ids of the examples that
--        belong to this subtree (any other type fails the assertion).
-- Both branch counters start at zero.
function TreeState:__init(exampleIds)
   local idsType = torch.type(exampleIds)
   assert(idsType == 'torch.LongTensor')

   self.exampleIds = exampleIds

   -- per-branch example counters, reset for a fresh state
   self.nExampleInLeftBranch = 0
   self.nExampleInRightBranch = 0
end
|
|
|
|
-- Computes and returns the score of the node based on its examples.
-- Abstract in this base class: it always raises "NotImplemented", so a
-- subclass is expected to override it.
-- @param dataset the dataset the node's examples come from
function TreeState:score(dataset)
   error("NotImplemented")
end
|
|
|
|
|
|
-- Initializes the split-state-updater.
-- Initially all examples are in the left branch.
-- Abstract in this base class: it always raises "NotImplemented".
-- @param exampleIdsWithFeature list of examples to split (those having a
--        particular feature)
-- @param dataset the dataset being trained on
function TreeState:initialize(exampleIdsWithFeature, dataset)
   error("NotImplemented")
end
|
|
|
|
-- Updates the split state. This call has the effect of shifting the example
-- from the left to the right branch.
-- Abstract in this base class: it always raises "NotImplemented".
-- @param exampleId id of the example being moved
-- @param dataset the dataset being trained on
-- @param exampleIdx (optional) position of exampleId inside the
--        exampleIdsWithFeature tensor given to initialize();
--        findBestFeatureSplit always supplies it, so it is declared here to
--        keep the abstract signature in line with how the method is called.
function TreeState:update(exampleId, dataset, exampleIdx)
   error("NotImplemented")
end
|
|
|
|
-- Computes the SplitInfo determined by the current split state.
-- Abstract in this base class: it always raises "NotImplemented".
-- @param splitFeatureId the feature id of the split feature
-- @param splitFeatureValue the feature value of the split feature
-- @param _splitInfo (optional) table offered by the caller for reuse;
--        findBestFeatureSplit passes nil on its first call and afterwards
--        either a fresh table or a SplitInfo it no longer needs.
--        Presumably implementations may fill and return it instead of
--        allocating a new table -- confirm against the subclasses.
-- @return the SplitInfo determined by the current split state.
--         findBestFeatureSplit reads the fields leftChildSize,
--         rightChildSize and splitGain from it.
function TreeState:computeSplitInfo(splitFeatureId, splitFeatureValue, _splitInfo)
   error("NotImplemented")
end
|
|
|
|
-- Finds the best split of this node's examples on a single feature.
-- (Marked as a bottleneck by the original author.)
--
-- The node's examples are visited in descending order of feature value.
-- Each visited example is shifted to the right branch through update();
-- whenever the feature value changes by more than dt.EPSILON, the split
-- at the previous value is evaluated with computeSplitInfo().
--
-- @param dataset a dt.DataSet
-- @param featureId number, id of the feature to split on
-- @param minLeafSize number, minimum example count required in each child
-- @return the candidate SplitInfo with the smallest splitGain whose two
--         children both hold at least minLeafSize examples; nil when no
--         candidate qualifies, and nothing at all when the input is a
--         table and none of the node's examples has the feature.
function TreeState:findBestFeatureSplit(dataset, featureId, minLeafSize)
   -- NOTE(review): this local shadows the file-level `dt`
   -- (decisiontree._env) with the full decisiontree module; presumably
   -- getBufferTable and EPSILON are only reachable this way -- confirm.
   local dt = require "decisiontree"
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureId) == 'number')
   assert(torch.type(minLeafSize) == 'number')

   -- all dataset examples having this feature, sorted by value
   local sortedExampleIds = dataset:getSortedFeature(featureId)

   -- scratch tensor shared through the module's buffer table
   local scratch = dt.getBufferTable('TreeState')
   scratch.longtensor = scratch.longtensor or torch.LongTensor()
   local idsWithFeature = scratch.longtensor

   -- positionOf[exampleId] = index of that example inside idsWithFeature
   local positionOf = {}
   local valueOf
   local count = 0

   if torch.type(dataset.input) == 'table' then
      -- table input: keep only the node's examples that carry the feature
      idsWithFeature:resize(self.exampleIds:size())
      self.exampleIds:apply(function(exampleId)
         local example = dataset.input[exampleId]
         example:buildIndex() -- only builds index first time
         if example[featureId] then
            count = count + 1
            idsWithFeature[count] = exampleId
            positionOf[exampleId] = count
         end
      end)
      if count == 0 then
         return
      end
      idsWithFeature:resize(count)
      valueOf = function(exampleId)
         return dataset.input[exampleId][featureId]
      end
   else
      -- non-table input: every example of the node takes part
      idsWithFeature = self.exampleIds
      self.exampleIds:apply(function(exampleId)
         count = count + 1
         positionOf[exampleId] = count
      end)
      local featureColumn = dataset.input:select(2, featureId)
      valueOf = function(exampleId)
         return featureColumn[exampleId]
      end
   end

   self:initialize(idsWithFeature, dataset)

   -- bottleneck
   -- spareSplit is a table handed to computeSplitInfo for reuse
   local bestSplit, lastValue, spareSplit
   for rank = sortedExampleIds:size(1), 1, -1 do -- sorted (desc) by feature value
      local exampleId = sortedExampleIds[rank]
      local position = positionOf[exampleId]

      -- examples outside this node (or without the feature) are skipped
      if position then
         local value = valueOf(exampleId)

         if lastValue and math.abs(value - lastValue) > dt.EPSILON then
            local candidate = self:computeSplitInfo(featureId, lastValue, spareSplit)
            local leavesBigEnough = (candidate.leftChildSize >= minLeafSize)
               and (candidate.rightChildSize >= minLeafSize)

            if leavesBigEnough
               and ((not bestSplit) or (candidate.splitGain < bestSplit.splitGain)) then
               spareSplit = bestSplit or {} -- reuse table
               bestSplit = candidate
            end
         end

         lastValue = value

         -- bottleneck
         self:update(exampleId, dataset, position)
      end
   end

   return bestSplit
end
|
|
|
|
-- Finds the best split of the examples in this TreeState among featureIds.
-- @param dataset a dt.DataSet
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param minLeafSize number, forwarded to findBestFeatureSplit
-- @param shardId number, shard handled by this call
-- @param nShard number of shards; when greater than 1, only features with
--        (featureId % nShard) + 1 == shardId are evaluated
-- @return the candidate with the smallest splitGain, or nil if no feature
--         produced a candidate
function TreeState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard)
   assert(torch.isTypeOf(dataset, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')
   assert(torch.type(minLeafSize) == 'number')
   assert(torch.type(shardId) == 'number')
   assert(torch.type(nShard) == 'number')

   local unsharded = (nShard <= 1)
   local winner

   for k = 1, featureIds:size(1) do
      local featureId = featureIds[k]
      -- feature sharded: skip features that belong to another shard
      local mine = unsharded or ((featureId % nShard) + 1 == shardId)
      if mine then
         local candidate = self:findBestFeatureSplit(dataset, featureId, minLeafSize)
         if candidate and ((not winner) or (candidate.splitGain < winner.splitGain)) then
            winner = candidate
         end
      end
   end

   return winner
end
|
|
|
|
-- Partitions self given a splitInfo table, producing a pair of exampleIds
-- tensors corresponding to the left and right subtrees.
-- Both results are narrowed views of one freshly allocated LongTensor:
-- left ids are packed from the front, right ids from the back.
-- @param splitInfo table providing splitId (feature id) and splitValue
-- @param dataset dataset whose input is indexed by example id
-- @return leftExampleIds (feature value < splitValue), rightExampleIds
--         (all other examples that have the feature)
function TreeState:_branch(splitInfo, dataset)
   local nExample = self.exampleIds:size(1)
   local partitioned = torch.LongTensor(nExample)
   local nLeft, nRight = 0, 0

   for pos = 1, nExample do
      local exampleId = self.exampleIds[pos]
      local example = dataset.input[exampleId]
      local value = example[splitInfo.splitId]
      -- Note: when the feature is not present in the example, the example
      -- is dropped from all sub-trees. Which means that for most sparse
      -- data, a tree cannot reach 100% accuracy...
      if value then
         if value < splitInfo.splitValue then
            nLeft = nLeft + 1
            partitioned[nLeft] = exampleId
         else
            nRight = nRight + 1
            partitioned[nExample - nRight + 1] = exampleId
         end
      end
   end

   local rightStart = nExample - nRight + 1
   local leftExampleIds = partitioned:narrow(1, 1, nLeft)
   local rightExampleIds = partitioned:narrow(1, rightStart, nRight)

   local nKept = leftExampleIds:size(1) + rightExampleIds:size(1)
   assert(nKept <= nExample, "Left and right branches contain more data than the parent!")
   return leftExampleIds, rightExampleIds
end
|
|
|
|
-- Splits this state in two: calls _branch and wraps the resulting left and
-- right exampleIds into new states built with self.new.
-- @param splitInfo the split to apply (see _branch)
-- @param dataset the dataset being trained on
-- @return left state, right state
function TreeState:branch(splitInfo, dataset)
   local leftIds, rightIds = self:_branch(splitInfo, dataset)
   local leftState = self.new(leftIds)
   local rightState = self.new(rightIds)
   return leftState, rightState
end
|
|
|
|
-- Returns the number of examples held by this state, i.e. the size of the
-- first dimension of exampleIds.
function TreeState:size()
   local ids = self.exampleIds
   return ids:size(1)
end
|
|
|
|
-- Tells whether exampleId is one of this state's examples.
-- Every element of exampleIds is visited (there is no early exit).
-- @param exampleId the id to look for
-- @return true if some element equals exampleId, false otherwise
function TreeState:contains(exampleId)
   local hit = false
   self.exampleIds:apply(function(candidate)
      -- once set, the flag stays set for the rest of the scan
      hit = hit or (candidate == exampleId)
   end)
   return hit
end
|
|
|