summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/SpatialClassNLLCriterion.lua')
-rw-r--r--contrib/lua-torch/nn/SpatialClassNLLCriterion.lua81
1 files changed, 81 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua b/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua
new file mode 100644
index 000000000..fbd367410
--- /dev/null
+++ b/contrib/lua-torch/nn/SpatialClassNLLCriterion.lua
@@ -0,0 +1,81 @@
+local THNN = require 'nn.THNN'
+local SpatialClassNLLCriterion, parent = torch.class('nn.SpatialClassNLLCriterion', 'nn.Criterion')
+
+function SpatialClassNLLCriterion:__init(weights, sizeAverage)
+ parent.__init(self)
+ if sizeAverage ~= nil then
+ self.sizeAverage = sizeAverage
+ else
+ self.sizeAverage = true
+ end
+ if weights then
+ assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+ self.weights = weights
+ end
+
+ self.output_tensor = torch.zeros(1)
+ self.total_weight_tensor = torch.ones(1)
+ self.target = torch.zeros(1):long()
+end
+
+function SpatialClassNLLCriterion:__len()
+ if (self.weights) then
+ return #self.weights
+ else
+ return 0
+ end
+end
+
+function SpatialClassNLLCriterion:updateOutput(input, target)
+ if type(target) == 'number' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
+ else
+ self.target = self.target:long()
+ end
+ self.target[1] = target
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.target = torch.CudaLongTensor and target:cudaLong() or target
+ else
+ self.target = target:long()
+ end
+
+ input.THNN.SpatialClassNLLCriterion_updateOutput(
+ input:cdata(),
+ self.target:cdata(),
+ self.output_tensor:cdata(),
+ self.sizeAverage,
+ THNN.optionalTensor(self.weights),
+ self.total_weight_tensor:cdata()
+ )
+ self.output = self.output_tensor[1]
+ return self.output, self.total_weight_tensor[1]
+end
+
+function SpatialClassNLLCriterion:updateGradInput(input, target)
+ if type(target) == 'number' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
+ else
+ self.target = self.target:long()
+ end
+ self.target[1] = target
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.target = torch.CudaLongTensor and target:cudaLong() or target
+ else
+ self.target = target:long()
+ end
+
+ self.gradInput:resizeAs(input):zero()
+
+ input.THNN.SpatialClassNLLCriterion_updateGradInput(
+ input:cdata(),
+ self.target:cdata(),
+ self.gradInput:cdata(),
+ self.sizeAverage,
+ THNN.optionalTensor(self.weights),
+ self.total_weight_tensor:cdata()
+ )
+
+ return self.gradInput
+end