aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/decisiontree/LogitBoostCriterion.lua')
-rw-r--r--contrib/lua-torch/decisiontree/LogitBoostCriterion.lua45
1 files changed, 45 insertions, 0 deletions
diff --git a/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua b/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua
new file mode 100644
index 000000000..5b9eb6028
--- /dev/null
+++ b/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua
@@ -0,0 +1,45 @@
+local dt = require "decisiontree._env"
+
+-- Ref: slide 17 in https://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf
+
+-- equivalent to nn.Sigmoid() + nn.BCECriterion()
+local LogitBoostCriterion, parent = torch.class("nn.LogitBoostCriterion", "nn.Criterion")
+
+function LogitBoostCriterion:__init(sizeAverage)
+ parent.__init(self)
+ self.sizeAverage = sizeAverage
+ self.hessInput = self.gradInput.new()
+ self._output = torch.Tensor()
+end
+
+function LogitBoostCriterion:updateOutput(input, target)
+ input.nn.LogitBoostCriterion_updateOutput(input, target, self._output, self.sizeAverage)
+ self.output = self._output[1]
+ return self.output
+end
+
+function LogitBoostCriterion:updateGradInput(input, target)
+ input.nn.LogitBoostCriterion_updateGradInput(input, target, self.gradInput)
+ return self.gradInput
+end
+
+function LogitBoostCriterion:updateHessInput(input, target)
+ input.nn.LogitBoostCriterion_updateHessInput(input, target, self.hessInput)
+ return self.hessInput
+end
+
+-- returns gradInput and hessInput
+function LogitBoostCriterion:backward2(input, target)
+ return self:updateGradInput(input, target), self:updateHessInput(input, target)
+end
+
+local gradWrapper = function(input, target, grad)
+ input.nn.LogitBoostCriterion_updateGradInput(input, target, grad)
+end
+local hessianWrapper = function(input, target, hessian)
+ input.nn.LogitBoostCriterion_updateHessInput(input, target, hessian)
+end
+
+function LogitBoostCriterion:getWrappers()
+ return gradWrapper, hessianWrapper
+end