rspamd/contrib/lua-torch/decisiontree/GiniState.lua
2018-05-23 18:14:15 +01:00

54 righe
2.0 KiB
Lua

local dt = require 'decisiontree._env'
-- used by RandomForestTrainer
local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt)
function GiniState:__init(exampleIds)
parent.__init(self, exampleIds)
self.nPositiveInLeftBranch = 0
self.nPositiveInRightBranch = 0
end
function GiniState:score(dataset)
local dt = require 'decisiontree'
local nPositive = dataset:countPositive(self.exampleIds)
return dt.calculateLogitScore(nPositive, self.exampleIds:size(1))
end
function GiniState:initialize(exampleIdsWithFeature, dataset)
assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor')
assert(torch.isTypeOf(dataset, 'dt.DataSet'))
self.nPositiveInLeftBranch = dataset:countPositive(exampleIdsWithFeature)
self.nPositiveInRightBranch = 0
self.nExampleInLeftBranch = exampleIdsWithFeature:size(1)
self.nExampleInRightBranch = 0
end
function GiniState:update(exampleId, dataset)
assert(torch.type(exampleId) == 'number')
assert(torch.isTypeOf(dataset, 'dt.DataSet'))
if dataset.target[exampleId] > 0 then
self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1
self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1
end
self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1
self.nExampleInRightBranch = self.nExampleInRightBranch + 1
end
function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue)
local dt = require 'decisiontree'
local gini = dt.computeGini(self.nExampleInLeftBranch, self.nPositiveInLeftBranch, self.nExampleInRightBranch, self.nPositiveInRightBranch)
local splitInfo = {
splitId = assert(splitFeatureId),
splitValue = assert(splitFeatureValue),
leftChildSize = assert(self.nExampleInLeftBranch),
leftPositiveCount = assert(self.nPositiveInLeftBranch),
rightChildSize = assert(self.nExampleInRightBranch),
rightPositiveCount = assert(self.nPositiveInRightBranch),
gini = assert(gini),
splitGain = gini
}
return splitInfo
end