-- CMul.lua

local CMul, parent = torch.class('nn.CMul', 'nn.Module')

function CMul:__init(...)
   parent.__init(self)

   -- accept either a single torch.LongStorage or a list of dimension sizes
   local arg = {...}

   self.size = torch.LongStorage()
   local n = #arg
   if n == 1 and torch.type(arg[1]) == 'torch.LongStorage' then
      self.size:resize(#arg[1]):copy(arg[1])
   else
      self.size:resize(n)
      for i=1,n do
         self.size[i] = arg[i]
      end
   end

   self.weight = torch.Tensor(self.size)
   self.gradWeight = torch.Tensor(self.size)

   self.output:resize(self.size)

   self:reset()
end

function CMul:reset(stdv)
   -- initialize weights uniformly in [-stdv, stdv];
   -- the default stdv is 1/sqrt(number of weight elements)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1./math.sqrt(self.weight:nElement())
   end
   self.weight:uniform(-stdv,stdv)
end

function CMul:updateOutput(input)
   -- lazy-initialize scratch buffers of the same tensor type as the input
   self._output = self._output or input.new()
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.output:resizeAs(input):copy(input)
   if input:nElement() == self.weight:nElement() then
      -- shapes match element-for-element: multiply through flat views
      self._output:view(self.output, -1)
      self._weight:view(self.weight, -1)
      self._output:cmul(self._weight)
   else
      if self.weight:dim() == input:dim() then
         self._output:set(self.output)
         self._weight:set(self.weight)
      else
         -- batch mode: view the weight as a single row and broadcast it
         -- over the first (batch) dimension
         local batchSize = input:size(1)
         self._output:view(self.output, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      self._expand:expandAs(self._weight, self._output)

      if torch.type(input) == 'torch.CudaTensor' then
         -- expand yields zero-stride views; materialize a real copy
         -- before multiplying on the GPU
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._output:cmul(self._repeat)
      else
         self._output:cmul(self._expand)
      end
   end

   return self.output
end

function CMul:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end

   self._gradOutput = self._gradOutput or input.new()
   self._gradInput = self._gradInput or input.new()

   self.gradInput:resizeAs(input):zero()
   if self.weight:nElement() == gradOutput:nElement() then
      -- d(input .* weight)/d(input) = weight
      self.gradInput:addcmul(1, self.weight, gradOutput)
   else
      -- _weight, _expand and _repeat were allocated in updateOutput
      if self.weight:dim() == input:dim() then
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         nn.utils.contiguousView(self._gradInput, self.gradInput, self.gradInput:size())
         self._weight:set(self.weight)
      else
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         nn.utils.contiguousView(self._gradInput, self.gradInput, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      self._expand:expandAs(self._weight, self._gradOutput)

      if torch.type(input) == 'torch.CudaTensor' then
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._gradInput:addcmul(1, self._repeat, self._gradOutput)
      else
         self._gradInput:addcmul(1, self._expand, self._gradOutput)
      end
   end

   return self.gradInput
end

function CMul:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   self._input = self._input or input.new()
   self._gradWeight = self._gradWeight or input.new()
   self._sum = self._sum or input.new()

   if self.weight:nElement() == gradOutput:nElement() then
      -- d(input .* weight)/d(weight) = input
      self.gradWeight:addcmul(scale, input, gradOutput)
   else
      if self.weight:dim() == input:dim() then
         nn.utils.contiguousView(self._input, input, input:size())
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         self._gradWeight:set(self.gradWeight)

         self._repeat:cmul(self._input, self._gradOutput)
         -- reduce the product over every dimension along which the weight
         -- was broadcast, ping-ponging between the two scratch buffers
         local sumInto = self._sum
         local sumFrom = self._repeat
         for i=1,self.weight:dim() do
            if self.weight:size(i) ~= input:size(i) then
               sumInto:sum(sumFrom, i)
               sumInto = sumFrom
               sumFrom = sumFrom == self._repeat and self._sum or self._repeat
            end
         end
         self._gradWeight:add(scale, sumFrom)
      else
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._input, input, batchSize, -1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         self._gradWeight:view(self.gradWeight, 1, -1)

         self._repeat:cmul(self._input, self._gradOutput)
         -- accumulate the per-sample gradients over the batch dimension
         self._sum:sum(self._repeat, 1)
         self._gradWeight:add(scale, self._sum)
      end
   end
end

function CMul:type(type, tensorCache)
   if type then
      -- drop cached buffers so they are re-allocated with the new type
      self:clearState()
   end
   return parent.type(self, type, tensorCache)
end

function CMul:clearState()
   nn.utils.clear(self, {
      '_input',
      '_output',
      '_weight',
      '_gradWeight',
      '_expand',
      '_repeat',
      '_sum',
   })
   return parent.clearState(self)
end
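
--[[
A minimal usage sketch (illustrative only, not part of the original file;
the sizes below are arbitrary example values). CMul learns one multiplicative
scale per weight element and, when the input has one more dimension than the
weight, broadcasts the weight over the leading batch dimension:

   require 'nn'
   local m = nn.CMul(5, 1)               -- one learnable scale per channel
   local input = torch.randn(2, 5, 1)    -- a batch of 2 samples
   local output = m:forward(input)       -- output[b] = input[b] .* weight
   local gradInput = m:backward(input, output:clone():fill(1))
--]]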