You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

DistKLDivCriterion.lua 1.1KB

12345678910111213141516171819202122232425262728293031323334
  -- nn.DistKLDivCriterion: a criterion (subclass of nn.Criterion).
  -- Its forward loss and backward gradient are computed by the THNN kernels
  -- DistKLDivCriterion_updateOutput and DistKLDivCriterion_updateGradInput,
  -- which are looked up on `input.THNN`.
  -- NOTE(review): the name suggests a KL-divergence loss. The formula itself
  -- lives in the THNN kernel and is not visible here — confirm there.
  local DistKLDivCriterion, parent = torch.class('nn.DistKLDivCriterion', 'nn.Criterion')

  -- Shared precondition for updateOutput and updateGradInput.
  -- Raises 'input and target should have the same size' unless both tensors
  -- have the same number of dimensions and the same size along every dimension.
  -- Sizes are compared directly instead of wrapping both size storages in
  -- LongTensors, so no temporary tensors are allocated per call. The check also
  -- works regardless of the two tensors' types.
  local function assertSameSize(input, target)
    local nDim = input:dim()
    local same = (nDim == target:dim())
    local d = 1
    while same and d <= nDim do
      same = (input:size(d) == target:size(d))
      d = d + 1
    end
    assert(same, 'input and target should have the same size')
  end

  -- Constructor: runs the nn.Criterion constructor, then sets sizeAverage to
  -- true. sizeAverage is passed as the last argument to both THNN kernels.
  -- NOTE(review): presumably it makes the kernel average the loss over
  -- elements instead of summing — confirm in THNN.
  function DistKLDivCriterion:__init()
    parent.__init(self)
    self.sizeAverage = true
  end

  -- Forward pass.
  -- input, target: tensors of identical size (checked below). The THNN backend
  --   attached to `input` performs the computation.
  -- Returns the loss value and also stores it in self.output.
  function DistKLDivCriterion:updateOutput(input, target)
    assertSameSize(input, target)
    -- 1-element buffer the kernel writes the loss into. It is allocated lazily
    -- with input's tensor type and reused on later calls.
    self.output_tensor = self.output_tensor or input.new(1)
    input.THNN.DistKLDivCriterion_updateOutput(
      input:cdata(),
      target:cdata(),
      self.output_tensor:cdata(),
      self.sizeAverage
    )
    -- Extract the single element of the buffer as the criterion's output.
    self.output = self.output_tensor[1]
    return self.output
  end

  -- Backward pass.
  -- input, target: same requirements as updateOutput.
  -- The THNN kernel writes the gradient into self.gradInput, which is then
  -- returned.
  -- NOTE(review): self.gradInput is not allocated in this file. Presumably
  -- nn.Criterion's constructor creates it and the kernel resizes it — confirm.
  function DistKLDivCriterion:updateGradInput(input, target)
    assertSameSize(input, target)
    input.THNN.DistKLDivCriterion_updateGradInput(
      input:cdata(),
      target:cdata(),
      self.gradInput:cdata(),
      self.sizeAverage
    )
    return self.gradInput
  end