peilaus alkaen
https://github.com/rspamd/rspamd.git
synced 2024-07-29 20:17:47 +02:00
54 rivejä
2.0 KiB
Lua
54 rivejä
2.0 KiB
Lua
local dt = require 'decisiontree._env'
|
|
|
|
-- used by RandomForestTrainer
|
|
local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt)
|
|
|
|
function GiniState:__init(exampleIds)
|
|
parent.__init(self, exampleIds)
|
|
self.nPositiveInLeftBranch = 0
|
|
self.nPositiveInRightBranch = 0
|
|
end
|
|
|
|
function GiniState:score(dataset)
|
|
local dt = require 'decisiontree'
|
|
local nPositive = dataset:countPositive(self.exampleIds)
|
|
return dt.calculateLogitScore(nPositive, self.exampleIds:size(1))
|
|
end
|
|
|
|
function GiniState:initialize(exampleIdsWithFeature, dataset)
|
|
assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor')
|
|
assert(torch.isTypeOf(dataset, 'dt.DataSet'))
|
|
self.nPositiveInLeftBranch = dataset:countPositive(exampleIdsWithFeature)
|
|
self.nPositiveInRightBranch = 0
|
|
|
|
self.nExampleInLeftBranch = exampleIdsWithFeature:size(1)
|
|
self.nExampleInRightBranch = 0
|
|
end
|
|
|
|
function GiniState:update(exampleId, dataset)
|
|
assert(torch.type(exampleId) == 'number')
|
|
assert(torch.isTypeOf(dataset, 'dt.DataSet'))
|
|
if dataset.target[exampleId] > 0 then
|
|
self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1
|
|
self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1
|
|
end
|
|
|
|
self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1
|
|
self.nExampleInRightBranch = self.nExampleInRightBranch + 1
|
|
end
|
|
|
|
function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue)
|
|
local dt = require 'decisiontree'
|
|
local gini = dt.computeGini(self.nExampleInLeftBranch, self.nPositiveInLeftBranch, self.nExampleInRightBranch, self.nPositiveInRightBranch)
|
|
local splitInfo = {
|
|
splitId = assert(splitFeatureId),
|
|
splitValue = assert(splitFeatureValue),
|
|
leftChildSize = assert(self.nExampleInLeftBranch),
|
|
leftPositiveCount = assert(self.nPositiveInLeftBranch),
|
|
rightChildSize = assert(self.nExampleInRightBranch),
|
|
rightPositiveCount = assert(self.nPositiveInRightBranch),
|
|
gini = assert(gini),
|
|
splitGain = gini
|
|
}
|
|
return splitInfo
|
|
end |