aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/GiniState.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
commit714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch)
tree84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/decisiontree/GiniState.lua
parent220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff)
downloadrspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz
rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/decisiontree/GiniState.lua')
-rw-r--r--contrib/lua-torch/decisiontree/GiniState.lua54
1 files changed, 54 insertions, 0 deletions
diff --git a/contrib/lua-torch/decisiontree/GiniState.lua b/contrib/lua-torch/decisiontree/GiniState.lua
new file mode 100644
index 000000000..6dfed2845
--- /dev/null
+++ b/contrib/lua-torch/decisiontree/GiniState.lua
@@ -0,0 +1,54 @@
+local dt = require 'decisiontree._env'
+
+-- used by RandomForestTrainer
+local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt)
+
+function GiniState:__init(exampleIds)
+ parent.__init(self, exampleIds)
+ self.nPositiveInLeftBranch = 0
+ self.nPositiveInRightBranch = 0
+end
+
+function GiniState:score(dataset)
+ local dt = require 'decisiontree'
+ local nPositive = dataset:countPositive(self.exampleIds)
+ return dt.calculateLogitScore(nPositive, self.exampleIds:size(1))
+end
+
+function GiniState:initialize(exampleIdsWithFeature, dataset)
+ assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor')
+ assert(torch.isTypeOf(dataset, 'dt.DataSet'))
+ self.nPositiveInLeftBranch = dataset:countPositive(exampleIdsWithFeature)
+ self.nPositiveInRightBranch = 0
+
+ self.nExampleInLeftBranch = exampleIdsWithFeature:size(1)
+ self.nExampleInRightBranch = 0
+end
+
+function GiniState:update(exampleId, dataset)
+ assert(torch.type(exampleId) == 'number')
+ assert(torch.isTypeOf(dataset, 'dt.DataSet'))
+ if dataset.target[exampleId] > 0 then
+ self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1
+ self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1
+ end
+
+ self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1
+ self.nExampleInRightBranch = self.nExampleInRightBranch + 1
+end
+
+function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue)
+ local dt = require 'decisiontree'
+ local gini = dt.computeGini(self.nExampleInLeftBranch, self.nPositiveInLeftBranch, self.nExampleInRightBranch, self.nPositiveInRightBranch)
+ local splitInfo = {
+ splitId = assert(splitFeatureId),
+ splitValue = assert(splitFeatureValue),
+ leftChildSize = assert(self.nExampleInLeftBranch),
+ leftPositiveCount = assert(self.nPositiveInLeftBranch),
+ rightChildSize = assert(self.nExampleInRightBranch),
+ rightPositiveCount = assert(self.nPositiveInRightBranch),
+ gini = assert(gini),
+ splitGain = gini
+ }
+ return splitInfo
+end \ No newline at end of file