
CAdd.lua

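-- nn.CAdd adds a learnable, per-element bias to its input. When the bias has
-- fewer elements than the input, it is expanded (broadcast) across the batch
-- dimension before the addition.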
local CAdd, parent = torch.class("nn.CAdd", "nn.Module")
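
-- The constructor accepts either a single torch.LongStorage or a list of
-- dimension sizes; both of the following build a 4x5 bias:
--   nn.CAdd(torch.LongStorage{4, 5})
--   nn.CAdd(4, 5)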
function CAdd:__init(...)
   parent.__init(self)

   local arg = {...}

   self.size = torch.LongStorage()
   local n = #arg
   if n == 1 and torch.type(arg[1]) == 'torch.LongStorage' then
      self.size:resize(#arg[1]):copy(arg[1])
   else
      self.size:resize(n)
      for i=1,n do
         self.size[i] = arg[i]
      end
   end

   self.bias = torch.Tensor(self.size)
   self.gradBias = torch.Tensor(self.size)
   self.output:resize(self.size)

   self:reset()
end
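
-- With no argument, reset() draws the bias uniformly from
-- [-1/sqrt(n), 1/sqrt(n)], where n is the number of bias elements; an explicit
-- stdv is rescaled so that the uniform draw has that standard deviation.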
function CAdd:reset(stdv)
   if stdv then
      --std of uniform distribution on interval [-a,a] = a/sqrt(3)
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1.0/math.sqrt(self.bias:nElement())
   end
   self.bias:uniform(-stdv, stdv)
end
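
-- Three cases in the forward pass:
--   * bias and input have the same number of elements: add directly;
--   * same dimensionality but different sizes: expand the bias over the input;
--   * otherwise: view both as (batchSize, -1) and expand the bias row-wise.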
function CAdd:updateOutput(input)
   self._output = self._output or input.new()
   self._bias = self._bias or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.output:resizeAs(input):copy(input)
   if input:nElement() == self.bias:nElement() then
      self.output:add(self.bias)
   else
      if self.bias:dim() == input:dim() then
         self._output:set(self.output)
         self._bias:set(self.bias)
      else
         local batchSize = input:size(1)
         self._output:view(self.output, batchSize, -1)
         self._bias:view(self.bias, 1, -1)
      end

      self._expand:expandAs(self._bias, self._output)

      --expandAs uses stride 0 and self._expand is not contiguous
      --cuda ops may assume contiguous input
      if torch.type(input) == 'torch.CudaTensor' then
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._output:add(self._repeat)
      else
         self._output:add(self._expand)
      end
   end

   return self.output
end
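
-- Addition is element-wise with derivative 1, so the gradient passes through
-- unchanged: gradInput is a copy of gradOutput.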
function CAdd:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or input.new()
   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
   return self.gradInput
end
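
-- gradBias must accumulate gradOutput summed over the broadcast (batch)
-- entries. The CPU path adds into the stride-0 expanded view, whose
-- sequential read-modify-writes accumulate that sum; the CudaTensor branch
-- instead materializes the expansion into the contiguous _repeat buffer,
-- since CUDA kernels cannot rely on that write ordering.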
function CAdd:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   self._gradBias = self._gradBias or gradOutput.new()
   self._gradOutput = self._gradOutput or gradOutput.new()
   self._repeat = self._repeat or gradOutput.new()

   if self.bias:nElement() == gradOutput:nElement() then
      self.gradBias:add(scale, gradOutput)
   else
      if self.bias:dim() == gradOutput:dim() then
         self._gradBias:set(self.gradBias)
         self._gradOutput:set(gradOutput)
      else
         local batchSize = input:size(1)
         self._gradBias:view(self.gradBias, 1, -1)
         self._gradOutput:view(gradOutput, batchSize, -1)
      end

      self._gradBias:expandAs(self._gradBias, self._gradOutput)

      --expandAs uses stride 0 and self._gradBias is not contiguous
      --cuda ops may assume contiguous input
      if torch.type(self._gradBias) == 'torch.CudaTensor' then
         self._repeat:resizeAs(self._gradBias):copy(self._gradBias)
         self._repeat:add(scale, self._gradOutput)
         self._gradBias:copy(self._repeat)
      else
         self._gradBias:add(scale, self._gradOutput)
      end
   end
end
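
-- Cached buffers are tied to a tensor type, so they are dropped before a type
-- conversion and lazily re-created with the new type on the next call.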
function CAdd:type(type, tensorCache)
   if type then
      self:clearState()
   end
   return parent.type(self, type, tensorCache)
end

function CAdd:clearState()
   nn.utils.clear(self, {
      '_gradBias',
      '_expand',
      '_output',
      '_bias',
      '_repeat'
   })
   return parent.clearState(self)
end
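
-- A minimal usage sketch (illustrative only; it assumes the torch and nn
-- packages are installed, and the 3x5 input and 0.5 fill are arbitrary):
--
--   local m = nn.CAdd(1, 5)                    -- 1x5 bias, broadcast over the batch
--   m.bias:fill(0.5)
--   local input = torch.Tensor(3, 5):fill(1)   -- batch of 3 samples
--   local output = m:forward(input)            -- every element is now 1.5
--   m:zeroGradParameters()
--   m:backward(input, torch.Tensor(3, 5):fill(1))
--   print(m.gradBias)                          -- each of the 5 entries is 3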