You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

Cosine.lua 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. local Cosine, parent = torch.class('nn.Cosine', 'nn.Module')
  2. function Cosine:__init(inputSize,outputSize)
  3. parent.__init(self)
  4. self.weight = torch.Tensor(outputSize,inputSize)
  5. self.gradWeight = torch.Tensor(outputSize,inputSize)
  6. self:reset()
  7. end
  8. function Cosine:reset(stdv)
  9. if stdv then
  10. stdv = stdv * math.sqrt(3)
  11. else
  12. stdv = 1./math.sqrt(self.weight:size(1))
  13. end
  14. self.weight:uniform(-stdv, stdv)
  15. end
-- Forward pass: cosine similarity between the input and every weight row:
--   y_j = (w_j . x) / ( ||w_j|| * ||x|| )
-- Accepts a 1D input (inputSize) or a 2D batch (batchSize x inputSize).
-- Returns self.output (1D of outputSize, or 2D batchSize x outputSize).
function Cosine:updateOutput(input)
   local inputSize = self.weight:size(2)
   local outputSize = self.weight:size(1)
   -- Lazily-created buffers holding the L2 norm of each weight row and
   -- of each input sample; both are reused by the backward pass.
   self._weightNorm = self._weightNorm or self.weight.new()
   self._inputNorm = self._inputNorm or self.weight.new()
   -- y_j = (w_j * x) / ( || w_j || * || x || )
   -- per-row ||w_j||; +1e-12 guards against division by zero
   self._weightNorm:norm(self.weight,2,2):add(1e-12)
   if input:dim() == 1 then
      self.output:resize(outputSize):zero()
      -- output = W * x
      self.output:addmv(1, self.weight, input)
      -- ||x|| cached for updateGradInput/accGradParameters (1D path)
      self.__norm = input:norm()+1e-12
      self.output:cdiv(self._weightNorm:view(outputSize)):div(self.__norm)
   elseif input:dim() == 2 then
      local batchSize = input:size(1)
      local nElement = self.output:nElement()
      self.output:resize(batchSize, outputSize)
      -- only zero when the resize actually re-allocated storage
      if self.output:nElement() ~= nElement then
         self.output:zero()
      end
      -- output = X * W^T (beta=0 overwrites, so no prior zeroing needed)
      self.output:addmm(0, self.output, 1, input, self.weight:t())
      -- per-sample ||x|| (batchSize x 1), reused in backward
      self._inputNorm:norm(input,2,2):add(1e-12)
      self.output:cdiv(self._weightNorm:view(1,outputSize):expandAs(self.output))
      self.output:cdiv(self._inputNorm:expandAs(self.output))
   else
      error('input must be vector or matrix')
   end
   return self.output
end
-- Backward pass w.r.t. the input. Relies on buffers cached by
-- updateOutput (self._weightNorm, self._inputNorm / self.__norm and
-- self.output), so updateOutput must have been called on the same input.
function Cosine:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end
   local inputSize = self.weight:size(2)
   local outputSize = self.weight:size(1)
   --[[
   dy_j           w_ji                    x_i
   ---- = ------------------- - y_j ---------
   dx_i   || w_j || * || x ||       || x ||^2
   --]]
   local nElement = self.gradInput:nElement()
   self.gradInput:resizeAs(input)
   -- only zero when resizeAs re-allocated storage
   if self.gradInput:nElement() ~= nElement then
      self.gradInput:zero()
   end
   if input:dim() == 1 then
      self._weight = self._weight or input.new()
      -- _weight = W / (||w_j|| * ||x||)  (first term of dy/dx, per row)
      self._weight:resizeAs(self.weight):copy(self.weight)
      self._weight:cdiv(self._weightNorm:expandAs(self.weight))
      self._weight:div(self.__norm)
      -- subtract the outer product y * x^T / ||x||^2 (second term)
      self._weight:addr(1, self._weight, -1/(self.__norm*self.__norm), self.output, input)
      -- gradInput = J^T * gradOutput (beta=0 overwrites gradInput)
      self.gradInput:addmv(0, 1, self._weight:t(), gradOutput)
   elseif input:dim() == 2 then
      local inputNorm = self._inputNorm:expandAs(input)
      local weightNorm = self._weightNorm:view(1,outputSize):expandAs(gradOutput)
      -- second term: (x / ||x||) * sum_j(gradOutput_j * y_j), per sample
      self.gradInput:copy(input):cdiv(inputNorm)
      self._gradOutput = self._gradOutput or gradOutput.new()
      self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
      self._gradOutput:cmul(self.output)
      self._sum = self._sum or input.new()
      self._sum:sum(self._gradOutput, 2)
      self.gradInput:cmul(self._sum:expandAs(input))
      -- first term: (gradOutput / ||w||) * W; _gradOutput buffer is
      -- refilled here with different contents — order matters
      self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
      self._gradOutput:cdiv(weightNorm)
      -- gradInput = first term - second term, then divide by ||x||
      self.gradInput:addmm(-1, self.gradInput, 1, self._gradOutput, self.weight)
      self.gradInput:cdiv(inputNorm)
   end
   return self.gradInput
end
  84. function Cosine:accGradParameters(input, gradOutput, scale)
  85. scale = scale or 1
  86. local inputSize = self.weight:size(2)
  87. local outputSize = self.weight:size(1)
  88. --[[
  89. dy_j x_i w_ji
  90. ----- = ------------------- - y_j -----------
  91. dw_ji || w_j || * || x || || w_j ||^2
  92. --]]
  93. if input:dim() == 1 then
  94. self._gradOutput = self._gradOutput or gradOutput.new()
  95. self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
  96. local weightNorm = self._weightNorm:view(outputSize)
  97. self._gradOutput:cdiv(weightNorm)
  98. self.gradWeight:addr(scale/self.__norm, self._gradOutput, input)
  99. self._gradOutput:cdiv(weightNorm)
  100. self._gradOutput:cmul(self.output)
  101. self._weight = self._weight or self.weight.new()
  102. self._weight:resizeAs(self._weight):copy(self.weight)
  103. self._weight:cmul(self._gradOutput:view(outputSize, 1):expandAs(self.weight))
  104. self.gradWeight:add(-1, self._weight)
  105. elseif input:dim() == 2 then
  106. self._weight = self._weight or self.weight.new()
  107. self._weight:resizeAs(self.weight):copy(self.weight)
  108. self._gradOutput = self._gradOutput or gradOutput.new()
  109. self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
  110. self._gradOutput:cmul(self.output)
  111. self._sum = self._sum or input.new()
  112. self._sum:sum(self._gradOutput, 1)
  113. local grad = self._sum[1]
  114. grad:cdiv(self._weightNorm:select(2,1))
  115. self._weight:cmul(grad:view(outputSize,1):expandAs(self._weight))
  116. local input_ = self._gradOutput
  117. input_:resizeAs(input):copy(input)
  118. input_:cdiv(self._inputNorm:expandAs(input))
  119. self._weight:addmm(-1, self._weight, 1, gradOutput:t(), input_)
  120. self._weight:cdiv(self._weightNorm:expandAs(self._weight))
  121. self.gradWeight:add(self._weight)
  122. else
  123. error"1D or 2D input expected"
  124. end
  125. end
  126. function Cosine:type(type, tensorCache)
  127. if type then
  128. -- prevent premature memory allocations
  129. self._input = nil
  130. self._weight = nil
  131. self._inputNorm = nil
  132. self._weightNorm = nil
  133. self._gradOutput = nil
  134. self._sum = nil
  135. end
  136. return parent.type(self, type, tensorCache)
  137. end
  138. function Cosine:clearState()
  139. nn.utils.clear(self, {
  140. '_input',
  141. '_weight',
  142. '_gradOutput',
  143. '_sum',
  144. '_inputNorm',
  145. '_weightNorm',
  146. })
  147. return parent.clearState(self)
  148. end