diff options
Diffstat (limited to 'contrib/lua-torch/nn/SpatialClassNLLCriterion.lua')
-rw-r--r-- | contrib/lua-torch/nn/SpatialClassNLLCriterion.lua | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua b/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua new file mode 100644 index 000000000..fbd367410 --- /dev/null +++ b/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua @@ -0,0 +1,81 @@ +local THNN = require 'nn.THNN' +local SpatialClassNLLCriterion, parent = torch.class('nn.SpatialClassNLLCriterion', 'nn.Criterion') + +function SpatialClassNLLCriterion:__init(weights, sizeAverage) + parent.__init(self) + if sizeAverage ~= nil then + self.sizeAverage = sizeAverage + else + self.sizeAverage = true + end + if weights then + assert(weights:dim() == 1, "weights input should be 1-D Tensor") + self.weights = weights + end + + self.output_tensor = torch.zeros(1) + self.total_weight_tensor = torch.ones(1) + self.target = torch.zeros(1):long() +end + +function SpatialClassNLLCriterion:__len() + if (self.weights) then + return #self.weights + else + return 0 + end +end + +function SpatialClassNLLCriterion:updateOutput(input, target) + if type(target) == 'number' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() + else + self.target = self.target:long() + end + self.target[1] = target + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then + self.target = torch.CudaLongTensor and target:cudaLong() or target + else + self.target = target:long() + end + + input.THNN.SpatialClassNLLCriterion_updateOutput( + input:cdata(), + self.target:cdata(), + self.output_tensor:cdata(), + self.sizeAverage, + THNN.optionalTensor(self.weights), + self.total_weight_tensor:cdata() + ) + self.output = self.output_tensor[1] + return self.output, self.total_weight_tensor[1] +end + +function SpatialClassNLLCriterion:updateGradInput(input, target) + if type(target) == 'number' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() + else + self.target = self.target:long() + end + self.target[1] = target + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then + self.target = torch.CudaLongTensor and target:cudaLong() or target + else + self.target = target:long() + end + + self.gradInput:resizeAs(input):zero() + + input.THNN.SpatialClassNLLCriterion_updateGradInput( + input:cdata(), + self.target:cdata(), + self.gradInput:cdata(), + self.sizeAverage, + THNN.optionalTensor(self.weights), + self.total_weight_tensor:cdata() + ) + + return self.gradInput +end |