diff options
Diffstat (limited to 'contrib/lua-torch/nn/CrossEntropyCriterion.lua')
-rw-r--r-- | contrib/lua-torch/nn/CrossEntropyCriterion.lua | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/CrossEntropyCriterion.lua b/contrib/lua-torch/nn/CrossEntropyCriterion.lua new file mode 100644 index 000000000..2f72cf87f --- /dev/null +++ b/contrib/lua-torch/nn/CrossEntropyCriterion.lua @@ -0,0 +1,42 @@ +local CrossEntropyCriterion, Criterion = torch.class('nn.CrossEntropyCriterion', 'nn.Criterion') + +function CrossEntropyCriterion:__init(weights, sizeAverage) + Criterion.__init(self) + self.lsm = nn.LogSoftMax() + self.nll = nn.ClassNLLCriterion(weights, sizeAverage) + self.sizeAverage = self.nll.sizeAverage + self.oldSizeAverage = self.sizeAverage +end + +function CrossEntropyCriterion:updateOutput(input, target) + input = input:squeeze() + target = type(target) == 'number' and target or target:squeeze() + -- only propagate if value has changed to preserve old behavior + -- of setting nll.sizeAverage directly + if self.sizeAverage ~= self.oldSizeAverage then + self.nll.sizeAverage = self.sizeAverage + end + self.lsm:updateOutput(input) + self.nll:updateOutput(self.lsm.output, target) + self.output = self.nll.output + self.oldSizeAverage = self.sizeAverage + return self.output +end + +function CrossEntropyCriterion:updateGradInput(input, target) + local size = input:size() + input = input:squeeze() + target = type(target) == 'number' and target or target:squeeze() + -- only propagate if value has changed to preserve old behavior + -- of setting nll.sizeAverage directly + if self.sizeAverage ~= self.oldSizeAverage then + self.nll.sizeAverage = self.sizeAverage + end + self.nll:updateGradInput(self.lsm.output, target) + self.lsm:updateGradInput(input, self.nll.gradInput) + self.gradInput:view(self.lsm.gradInput, size) + self.oldSizeAverage = self.sizeAverage + return self.gradInput +end + +return nn.CrossEntropyCriterion |