Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

CrossEntropyCriterion.lua 1.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  -- Register the class 'nn.CrossEntropyCriterion' as a subclass of 'nn.Criterion'.
  -- Two values are kept from torch.class:
  --   CrossEntropyCriterion: the table the methods of this file are defined on.
  --   Criterion: the parent class, whose __init is invoked by this class's __init.
  local CrossEntropyCriterion, Criterion = torch.class(
    'nn.CrossEntropyCriterion',
    'nn.Criterion'
  )
  -- Constructor: composes an nn.LogSoftMax module (self.lsm) with an
  -- nn.ClassNLLCriterion (self.nll).
  -- weights:     forwarded unchanged to nn.ClassNLLCriterion
  --              (presumably optional per-class weights; confirm against nn docs).
  -- sizeAverage: forwarded unchanged to nn.ClassNLLCriterion.
  function CrossEntropyCriterion:__init(weights, sizeAverage)
    Criterion.__init(self)
    self.lsm = nn.LogSoftMax()
    local nll = nn.ClassNLLCriterion(weights, sizeAverage)
    self.nll = nll
    -- sizeAverage is read back from the NLL criterion instead of being taken
    -- from the argument. NOTE(review): presumably this is so that a nil argument
    -- resolves to ClassNLLCriterion's default; confirm.
    local averaging = nll.sizeAverage
    self.sizeAverage = averaging
    -- Remember the value last pushed to self.nll, so the forward and backward
    -- passes can detect later changes to self.sizeAverage.
    self.oldSizeAverage = averaging
  end
  -- Forward pass: computes the loss of `input` against `target` by running
  -- self.lsm (nn.LogSoftMax) followed by self.nll (nn.ClassNLLCriterion).
  -- All singleton dimensions are squeezed out of `input`. They are also squeezed
  -- out of `target` when it is a tensor; a numeric target is used unchanged.
  -- Returns the loss, which is also stored in self.output.
  function CrossEntropyCriterion:updateOutput(input, target)
    local squeezedInput = input:squeeze()
    local squeezedTarget
    if type(target) == 'number' then
      squeezedTarget = target
    else
      squeezedTarget = target:squeeze()
    end
    -- Push self.sizeAverage into the NLL criterion only when it changed since
    -- the last pass. Unchanged values leave self.nll.sizeAverage alone, which
    -- keeps working the older practice of setting that field directly.
    local requested = self.sizeAverage
    if requested ~= self.oldSizeAverage then
      self.nll.sizeAverage = requested
    end
    local logSoftMax = self.lsm
    logSoftMax:updateOutput(squeezedInput)
    local nll = self.nll
    nll:updateOutput(logSoftMax.output, squeezedTarget)
    self.output = nll.output
    self.oldSizeAverage = self.sizeAverage
    return self.output
  end
  -- Backward pass: gradient of the loss with respect to `input`.
  -- It reuses self.lsm.output, which only updateOutput produces. It therefore
  -- expects updateOutput to have been called first with the same arguments.
  -- The returned self.gradInput has the original, pre-squeeze size of `input`.
  function CrossEntropyCriterion:updateGradInput(input, target)
    local originalSize = input:size()
    local squeezedInput = input:squeeze()
    local squeezedTarget
    if type(target) == 'number' then
      squeezedTarget = target
    else
      squeezedTarget = target:squeeze()
    end
    -- Same change-only propagation of sizeAverage as in the forward pass.
    local requested = self.sizeAverage
    if requested ~= self.oldSizeAverage then
      self.nll.sizeAverage = requested
    end
    local nll, logSoftMax = self.nll, self.lsm
    nll:updateGradInput(logSoftMax.output, squeezedTarget)
    logSoftMax:updateGradInput(squeezedInput, nll.gradInput)
    -- Reshape the LogSoftMax gradient back to the caller's original input size.
    self.gradInput:view(logSoftMax.gradInput, originalSize)
    self.oldSizeAverage = self.sizeAverage
    return self.gradInput
  end
  -- The chunk's return value is what require/dofile hand back to the loader.
  -- The global nn.CrossEntropyCriterion is returned rather than the local
  -- CrossEntropyCriterion method table obtained from torch.class.
  local exported = nn.CrossEntropyCriterion
  return exported