
Euclidean.lua

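-- nn.Euclidean: a distance layer. The weight matrix holds outputSize
-- templates w_j (one per column, each in R^inputSize); the forward pass
-- outputs y_j = || w_j - x || for an input x.
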
local Euclidean, parent = torch.class('nn.Euclidean', 'nn.Module')

function Euclidean:__init(inputSize, outputSize)
   parent.__init(self)

   self.weight = torch.Tensor(inputSize, outputSize)
   self.gradWeight = torch.Tensor(inputSize, outputSize)

   -- state
   self.gradInput:resize(inputSize)
   self.output:resize(outputSize)

   self.fastBackward = true

   self:reset()
end

function Euclidean:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1./math.sqrt(self.weight:size(1))
   end
   if nn.oldSeed then
      for i=1,self.weight:size(2) do
         self.weight:select(2, i):apply(function()
            return torch.uniform(-stdv, stdv)
         end)
      end
   else
      self.weight:uniform(-stdv, stdv)
   end
end

local function view(res, src, ...)
   local args = {...}
   if src:isContiguous() then
      res:view(src, table.unpack(args))
   else
      res:reshape(src, table.unpack(args))
   end
end

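-- Note: the helper above makes `res` a view that shares storage with
-- `src` when `src` is contiguous; for a non-contiguous `src` it falls
-- back to reshape, which copies the data into `res`.
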
function Euclidean:updateOutput(input)
   -- lazy initialize buffers
   self._input = self._input or input.new()
   self._weight = self._weight or self.weight.new()
   self._expand = self._expand or self.output.new()
   self._expand2 = self._expand2 or self.output.new()
   self._repeat = self._repeat or self.output.new()
   self._repeat2 = self._repeat2 or self.output.new()

   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)

   -- y_j = || w_j - x || = || x - w_j ||
   if input:dim() == 1 then
      view(self._input, input, inputSize, 1)
      self._expand:expandAs(self._input, self.weight)
      self._repeat:resizeAs(self._expand):copy(self._expand)
      self._repeat:add(-1, self.weight)
      self.output:norm(self._repeat, 2, 1)
      self.output:resize(outputSize)
   elseif input:dim() == 2 then
      local batchSize = input:size(1)

      view(self._input, input, batchSize, inputSize, 1)
      self._expand:expand(self._input, batchSize, inputSize, outputSize)

      -- make the expanded tensor contiguous (requires lots of memory)
      self._repeat:resizeAs(self._expand):copy(self._expand)

      self._weight:view(self.weight, 1, inputSize, outputSize)
      self._expand2:expandAs(self._weight, self._repeat)

      if torch.type(input) == 'torch.CudaTensor' then
         -- requires lots of memory, but minimizes cudaMallocs and loops
         self._repeat2:resizeAs(self._expand2):copy(self._expand2)
         self._repeat:add(-1, self._repeat2)
      else
         self._repeat:add(-1, self._expand2)
      end

      self.output:norm(self._repeat, 2, 2)
      self.output:resize(batchSize, outputSize)
   else
      error"1D or 2D input expected"
   end

   return self.output
end

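--[[
Illustrative forward-pass sketch (assumes the `nn` package is loaded;
the tensor sizes are examples only):

   local m = nn.Euclidean(5, 3)           -- 3 templates w_j in R^5
   local y1 = m:forward(torch.rand(5))    -- 1D input -> size 3
   local y2 = m:forward(torch.rand(4, 5)) -- 2D batch -> size 4x3
--]]
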
function Euclidean:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end

   self._div = self._div or input.new()
   self._output = self._output or self.output.new()
   self._gradOutput = self._gradOutput or input.new()
   self._expand3 = self._expand3 or input.new()

   if not self.fastBackward then
      self:updateOutput(input)
   end

   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)

   --[[
   dy_j   -2 * (w_j - x)    x - w_j
   ---- = --------------- = -------
    dx    2 || w_j - x ||     y_j
   --]]

   -- to prevent div by zero (NaN) bugs
   self._output:resizeAs(self.output):copy(self.output):add(0.0000001)
   view(self._gradOutput, gradOutput, gradOutput:size())
   self._div:cdiv(gradOutput, self._output)
   if input:dim() == 1 then
      self._div:resize(1, outputSize)
      self._expand3:expandAs(self._div, self.weight)

      if torch.type(input) == 'torch.CudaTensor' then
         self._repeat2:resizeAs(self._expand3):copy(self._expand3)
         self._repeat2:cmul(self._repeat)
      else
         self._repeat2:cmul(self._repeat, self._expand3)
      end

      self.gradInput:sum(self._repeat2, 2)
      self.gradInput:resizeAs(input)
   elseif input:dim() == 2 then
      local batchSize = input:size(1)

      self._div:resize(batchSize, 1, outputSize)
      self._expand3:expand(self._div, batchSize, inputSize, outputSize)

      if torch.type(input) == 'torch.CudaTensor' then
         self._repeat2:resizeAs(self._expand3):copy(self._expand3)
         self._repeat2:cmul(self._repeat)
      else
         self._repeat2:cmul(self._repeat, self._expand3)
      end

      self.gradInput:sum(self._repeat2, 3)
      self.gradInput:resizeAs(input)
   else
      error"1D or 2D input expected"
   end

   return self.gradInput
end

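--[[
Illustrative backward-pass sketch (shapes follow updateOutput above):

   local m = nn.Euclidean(5, 3)
   local x = torch.rand(4, 5)
   local y = m:forward(x)          -- 4x3 distances; fills the buffers
   local dy = torch.rand(4, 3)     -- gradient w.r.t. the output
   local dx = m:backward(x, dy)    -- 4x5 gradient w.r.t. the input
--]]
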
function Euclidean:accGradParameters(input, gradOutput, scale)
   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
   scale = scale or 1

   --[[
   dy_j    2 * (w_j - x)    w_j - x
   ---- = --------------- = -------
   dw_j   2 || w_j - x ||     y_j
   --]]
   -- assumes a preceding call to updateGradInput

   if input:dim() == 1 then
      self.gradWeight:add(-scale, self._repeat2)
   elseif input:dim() == 2 then
      self._sum = self._sum or input.new()
      self._sum:sum(self._repeat2, 1)
      self._sum:resize(inputSize, outputSize)
      self.gradWeight:add(-scale, self._sum)
   else
      error"1D or 2D input expected"
   end
end

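--[[
Illustrative training-step sketch (plain SGD via the standard nn.Module
API; the learning rate 0.1 is an arbitrary example):

   m:zeroGradParameters()
   m:forward(x)
   m:backward(x, dy)               -- accumulates into gradWeight
   m:updateParameters(0.1)         -- w <- w - 0.1 * gradWeight
--]]
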
function Euclidean:type(type, tensorCache)
   if type then
      -- prevent premature memory allocations
      self:clearState()
   end
   return parent.type(self, type, tensorCache)
end

function Euclidean:clearState()
   nn.utils.clear(self, {
      '_input',
      '_output',
      '_gradOutput',
      '_weight',
      '_div',
      '_sum',
      '_expand',
      '_expand2',
      '_expand3',
      '_repeat',
      '_repeat2',
   })
   return parent.clearState(self)
end
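
--[[
End-to-end usage sketch (illustrative; the file name is an example):

   require 'nn'

   local m = nn.Euclidean(10, 4)   -- 4 templates in R^10
   local x = torch.rand(8, 10)     -- batch of 8 inputs
   local y = m:forward(x)          -- 8x4 matrix of distances

   m:clearState()                  -- drop buffers before serializing
   torch.save('euclidean.t7', m)
--]]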