12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- local THNN = require 'nn.THNN'
-- Declare the class; `parent` is the base class table, used by __init to chain construction.
local ClassNLLCriterion, parent = torch.class(
   'nn.ClassNLLCriterion', -- name of the class being defined
   'nn.Criterion'          -- base class
)
-
--- Constructs the criterion.
-- @param weights     optional 1-D Tensor of per-class weights; stored by reference, not copied
-- @param sizeAverage flag forwarded to the THNN kernels; `nil` means `true`
-- @param ignoreIndex target index the kernels ignore; any falsy value means -100
function ClassNLLCriterion:__init(weights, sizeAverage, ignoreIndex)
   parent.__init(self)

   if sizeAverage == nil then
      sizeAverage = true
   end
   self.sizeAverage = sizeAverage

   if not ignoreIndex then
      ignoreIndex = -100
   end
   self.ignoreIndex = ignoreIndex -- this target index will be ignored

   if weights then
      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
      self.weights = weights
   end

   -- One-element tensors that updateOutput/updateGradInput pass to the THNN kernels.
   -- self.target is only the initial buffer; those methods may rebind it.
   self.output_tensor = torch.zeros(1)
   self.total_weight_tensor = torch.ones(1)
   self.target = torch.zeros(1):long()
end
-
--- Length metamethod.
-- Returns 0 when no class weights were given. Otherwise it returns `#self.weights`, which
-- for a torch Tensor is presumably its size object rather than a plain number (confirm
-- before relying on it).
function ClassNLLCriterion:__len()
   local classWeights = self.weights
   if not classWeights then
      return 0
   end
   return #classWeights
end
-
--- Returns true when `input` is one of the torch.Cuda*Tensor types.
local function isCudaInput(input)
   return torch.typename(input):find('torch%.Cuda.*Tensor') ~= nil
end

--- Normalizes `target` into an index tensor matching `input`'s device and stores it in self.target.
-- A number target is written into a private scratch buffer (self._scalarTarget), never into
-- whatever self.target currently references. After a tensor-target call, self.target holds
-- the result of `target:long()`, `target:cudaLong()`, or `target` itself. For a target that is
-- already of that type, this may be the caller's own tensor object. Resizing and writing it
-- would then silently overwrite the caller's data.
-- The buffer is created lazily, so objects built or deserialized without it still work.
-- @return self.target
local function prepareTarget(self, input, target)
   local cuda = isCudaInput(input)
   if type(target) == 'number' then
      local buffer = self._scalarTarget or torch.zeros(1):long()
      if cuda then
         buffer = torch.CudaLongTensor and buffer:cudaLong() or buffer:cuda()
      else
         buffer = buffer:long()
      end
      buffer:resize(1)
      buffer[1] = target
      self._scalarTarget = buffer
      self.target = buffer
   elseif cuda then
      self.target = torch.CudaLongTensor and target:cudaLong() or target
   else
      self.target = target:long()
   end
   return self.target
end

--- Computes the (optionally weighted) NLL loss via the THNN kernel for `input`'s type.
-- @param input  tensor whose `.THNN` table supplies the kernel
-- @param target a class index (number) or a tensor of class indices
-- @return the loss (also stored in self.output) and the total weight the kernel wrote
function ClassNLLCriterion:updateOutput(input, target)
   prepareTarget(self, input, target)

   input.THNN.ClassNLLCriterion_updateOutput(
      input:cdata(),
      self.target:cdata(),
      self.output_tensor:cdata(),
      self.sizeAverage,
      THNN.optionalTensor(self.weights),
      self.total_weight_tensor:cdata(),
      self.ignoreIndex
   )
   self.output = self.output_tensor[1]
   return self.output, self.total_weight_tensor[1]
end

--- Computes the gradient of the loss w.r.t. `input` into self.gradInput.
-- Reads self.total_weight_tensor, which updateOutput fills. Presumably updateOutput must run
-- first on the same input/target (confirm against the THNN kernel).
-- @param input  tensor whose `.THNN` table supplies the kernel
-- @param target a class index (number) or a tensor of class indices
-- @return self.gradInput, resized to `input` and zeroed before the kernel writes into it
function ClassNLLCriterion:updateGradInput(input, target)
   prepareTarget(self, input, target)

   self.gradInput:resizeAs(input):zero()

   input.THNN.ClassNLLCriterion_updateGradInput(
      input:cdata(),
      self.target:cdata(),
      self.gradInput:cdata(),
      self.sizeAverage,
      THNN.optionalTensor(self.weights),
      self.total_weight_tensor:cdata(),
      self.ignoreIndex
   )

   return self.gradInput
end
|