您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

CMinTable.lua 1.6KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. local CMinTable, parent = torch.class('nn.CMinTable', 'nn.Module')
  2. function CMinTable:__init()
  3. parent.__init(self)
  4. self.gradInput = {}
  5. self.minIdx = torch.Tensor()
  6. self.mask = torch.Tensor()
  7. self.minVals = torch.Tensor()
  8. self.gradMaxVals = torch.Tensor()
  9. end
  10. function CMinTable:updateOutput(input)
  11. self.output:resizeAs(input[1]):copy(input[1])
  12. self.minIdx:resizeAs(input[1]):fill(1)
  13. for i=2,#input do
  14. self.maskByteTensor = self.maskByteTensor or
  15. (torch.type(self.output) == 'torch.CudaTensor' and
  16. torch.CudaByteTensor() or torch.ByteTensor())
  17. self.mask:lt(input[i], self.output)
  18. self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
  19. self.minIdx:maskedFill(self.maskByteTensor, i)
  20. self.minVals:maskedSelect(input[i], self.maskByteTensor)
  21. self.output:maskedCopy(self.maskByteTensor, self.minVals)
  22. end
  23. return self.output
  24. end
  25. function CMinTable:updateGradInput(input, gradOutput)
  26. for i=1,#input do
  27. self.gradInput[i] = self.gradInput[i] or input[i].new()
  28. self.gradInput[i]:resizeAs(input[i]):fill(0.0)
  29. self.maskByteTensor = self.maskByteTensor or
  30. (torch.type(self.output) == 'torch.CudaTensor' and
  31. torch.CudaByteTensor() or torch.ByteTensor())
  32. self.mask:eq(self.minIdx, i)
  33. self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
  34. self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
  35. self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
  36. end
  37. for i=#input+1, #self.gradInput do
  38. self.gradInput[i] = nil
  39. end
  40. return self.gradInput
  41. end