diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
commit | 714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch) | |
tree | 84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/decisiontree/GiniState.lua | |
parent | 220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff) | |
download | rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip |
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/decisiontree/GiniState.lua')
-rw-r--r-- | contrib/lua-torch/decisiontree/GiniState.lua | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/contrib/lua-torch/decisiontree/GiniState.lua b/contrib/lua-torch/decisiontree/GiniState.lua new file mode 100644 index 000000000..6dfed2845 --- /dev/null +++ b/contrib/lua-torch/decisiontree/GiniState.lua @@ -0,0 +1,54 @@ +local dt = require 'decisiontree._env' + +-- used by RandomForestTrainer +local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt) + +function GiniState:__init(exampleIds) + parent.__init(self, exampleIds) + self.nPositiveInLeftBranch = 0 + self.nPositiveInRightBranch = 0 +end + +function GiniState:score(dataset) + local dt = require 'decisiontree' + local nPositive = dataset:countPositive(self.exampleIds) + return dt.calculateLogitScore(nPositive, self.exampleIds:size(1)) +end + +function GiniState:initialize(exampleIdsWithFeature, dataset) + assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor') + assert(torch.isTypeOf(dataset, 'dt.DataSet')) + self.nPositiveInLeftBranch = dataset:countPositive(exampleIdsWithFeature) + self.nPositiveInRightBranch = 0 + + self.nExampleInLeftBranch = exampleIdsWithFeature:size(1) + self.nExampleInRightBranch = 0 +end + +function GiniState:update(exampleId, dataset) + assert(torch.type(exampleId) == 'number') + assert(torch.isTypeOf(dataset, 'dt.DataSet')) + if dataset.target[exampleId] > 0 then + self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1 + self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1 + end + + self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1 + self.nExampleInRightBranch = self.nExampleInRightBranch + 1 +end + +function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue) + local dt = require 'decisiontree' + local gini = dt.computeGini(self.nExampleInLeftBranch, self.nPositiveInLeftBranch, self.nExampleInRightBranch, self.nPositiveInRightBranch) + local splitInfo = { + splitId = assert(splitFeatureId), + splitValue = assert(splitFeatureValue), + leftChildSize = assert(self.nExampleInLeftBranch), + leftPositiveCount = assert(self.nPositiveInLeftBranch), + rightChildSize = assert(self.nExampleInRightBranch), + rightPositiveCount = assert(self.nPositiveInRightBranch), + gini = assert(gini), + splitGain = gini + } + return splitInfo +end
\ No newline at end of file |