You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ClassNLLCriterion.lua 2.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. local THNN = require 'nn.THNN'
  2. local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')
  3. function ClassNLLCriterion:__init(weights, sizeAverage, ignoreIndex)
  4. parent.__init(self)
  5. self.sizeAverage = (sizeAverage == nil) and true or sizeAverage
  6. self.ignoreIndex = ignoreIndex or -100 -- this target index will be ignored
  7. if weights then
  8. assert(weights:dim() == 1, "weights input should be 1-D Tensor")
  9. self.weights = weights
  10. end
  11. self.output_tensor = torch.zeros(1)
  12. self.total_weight_tensor = torch.ones(1)
  13. self.target = torch.zeros(1):long()
  14. end
  15. function ClassNLLCriterion:__len()
  16. if (self.weights) then
  17. return #self.weights
  18. else
  19. return 0
  20. end
  21. end
  22. function ClassNLLCriterion:updateOutput(input, target)
  23. if type(target) == 'number' then
  24. if torch.typename(input):find('torch%.Cuda.*Tensor') then
  25. self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
  26. else
  27. self.target = self.target:long()
  28. end
  29. self.target:resize(1)
  30. self.target[1] = target
  31. elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
  32. self.target = torch.CudaLongTensor and target:cudaLong() or target
  33. else
  34. self.target = target:long()
  35. end
  36. input.THNN.ClassNLLCriterion_updateOutput(
  37. input:cdata(),
  38. self.target:cdata(),
  39. self.output_tensor:cdata(),
  40. self.sizeAverage,
  41. THNN.optionalTensor(self.weights),
  42. self.total_weight_tensor:cdata(),
  43. self.ignoreIndex
  44. )
  45. self.output = self.output_tensor[1]
  46. return self.output, self.total_weight_tensor[1]
  47. end
  48. function ClassNLLCriterion:updateGradInput(input, target)
  49. if type(target) == 'number' then
  50. if torch.typename(input):find('torch%.Cuda.*Tensor') then
  51. self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
  52. else
  53. self.target = self.target:long()
  54. end
  55. self.target:resize(1)
  56. self.target[1] = target
  57. elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
  58. self.target = torch.CudaLongTensor and target:cudaLong() or target
  59. else
  60. self.target = target:long()
  61. end
  62. self.gradInput:resizeAs(input):zero()
  63. input.THNN.ClassNLLCriterion_updateGradInput(
  64. input:cdata(),
  65. self.target:cdata(),
  66. self.gradInput:cdata(),
  67. self.sizeAverage,
  68. THNN.optionalTensor(self.weights),
  69. self.total_weight_tensor:cdata(),
  70. self.ignoreIndex
  71. )
  72. return self.gradInput
  73. end