diff options
Diffstat (limited to 'contrib/lua-torch/decisiontree/LogitBoostCriterion.lua')
-rw-r--r-- | contrib/lua-torch/decisiontree/LogitBoostCriterion.lua | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua b/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua new file mode 100644 index 000000000..5b9eb6028 --- /dev/null +++ b/contrib/lua-torch/decisiontree/LogitBoostCriterion.lua @@ -0,0 +1,45 @@ +local dt = require "decisiontree._env" + +-- Ref: slide 17 in https://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf + +-- equivalent to nn.Sigmoid() + nn.BCECriterion() +local LogitBoostCriterion, parent = torch.class("nn.LogitBoostCriterion", "nn.Criterion") + +function LogitBoostCriterion:__init(sizeAverage) + parent.__init(self) + self.sizeAverage = sizeAverage + self.hessInput = self.gradInput.new() + self._output = torch.Tensor() +end + +function LogitBoostCriterion:updateOutput(input, target) + input.nn.LogitBoostCriterion_updateOutput(input, target, self._output, self.sizeAverage) + self.output = self._output[1] + return self.output +end + +function LogitBoostCriterion:updateGradInput(input, target) + input.nn.LogitBoostCriterion_updateGradInput(input, target, self.gradInput) + return self.gradInput +end + +function LogitBoostCriterion:updateHessInput(input, target) + input.nn.LogitBoostCriterion_updateHessInput(input, target, self.hessInput) + return self.hessInput +end + +-- returns gradInput and hessInput +function LogitBoostCriterion:backward2(input, target) + return self:updateGradInput(input, target), self:updateHessInput(input, target) +end + +local gradWrapper = function(input, target, grad) + input.nn.LogitBoostCriterion_updateGradInput(input, target, grad) +end +local hessianWrapper = function(input, target, hessian) + input.nn.LogitBoostCriterion_updateHessInput(input, target, hessian) +end + +function LogitBoostCriterion:getWrappers() + return gradWrapper, hessianWrapper +end |