You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CosineDistance.lua 2.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. local CosineDistance, parent = torch.class('nn.CosineDistance', 'nn.Module')
  2. function CosineDistance:__init()
  3. parent.__init(self)
  4. self.gradInput = {torch.Tensor(), torch.Tensor()}
  5. end
  6. local function makeContiguous(self, input1, input2)
  7. if not input1:isContiguous() then
  8. self._input1 = self._input1 or input1.new()
  9. self._input1:resizeAs(input1):copy(input1)
  10. input1 = self._input1
  11. end
  12. if not input2:isContiguous() then
  13. self._input2 = self._input2 or input2.new()
  14. self._input2:resizeAs(input2):copy(input2)
  15. input2 = self._input2
  16. end
  17. return input1, input2
  18. end
  19. function CosineDistance:updateOutput(input)
  20. local input1, input2 = input[1], input[2]
  21. input1, input2 = makeContiguous(self, input1, input2)
  22. if input1:dim() == 1 then
  23. input1 = input1:view(1,-1)
  24. input2 = input2:view(1,-1)
  25. end
  26. if not self.buffer then
  27. self.buffer = input1.new()
  28. self.w1 = input1.new()
  29. self.w22 = input1.new()
  30. self.w = input1.new()
  31. self.w32 = input1.new()
  32. self.ones = input1.new()
  33. end
  34. self.buffer:cmul(input1,input2)
  35. self.w1:sum(self.buffer,2)
  36. local epsilon = 1e-12
  37. self.buffer:cmul(input1,input1)
  38. self.w22:sum(self.buffer,2):add(epsilon)
  39. self.ones:resizeAs(self.w22):fill(1)
  40. self.w22:cdiv(self.ones, self.w22)
  41. self.w:resizeAs(self.w22):copy(self.w22)
  42. self.buffer:cmul(input2,input2)
  43. self.w32:sum(self.buffer,2):add(epsilon)
  44. self.w32:cdiv(self.ones, self.w32)
  45. self.w:cmul(self.w32)
  46. self.w:sqrt()
  47. self.output:cmul(self.w1,self.w)
  48. self.output:resize(input1:size(1))
  49. return self.output
  50. end
  51. function CosineDistance:updateGradInput(input, gradOutput)
  52. local v1 = input[1]
  53. local v2 = input[2]
  54. local not_batch = false
  55. v1, v2 = makeContiguous(self, v1, v2)
  56. if v1:dim() == 1 then
  57. v1 = v1:view(1,-1)
  58. v2 = v2:view(1,-1)
  59. not_batch = true
  60. end
  61. if #self.gradInput ~= 2 then
  62. self.gradInput[1] = self.gradInput[1] or v1.new()
  63. self.gradInput[2] = self.gradInput[2] or v1.new()
  64. end
  65. local gw1 = self.gradInput[1]
  66. local gw2 = self.gradInput[2]
  67. gw1:resizeAs(v1):copy(v2)
  68. gw2:resizeAs(v1):copy(v1)
  69. self.buffer:cmul(self.w1,self.w22)
  70. gw1:addcmul(-1,self.buffer:expandAs(v1),v1)
  71. gw1:cmul(self.w:expandAs(v1))
  72. self.buffer:cmul(self.w1,self.w32)
  73. gw2:addcmul(-1,self.buffer:expandAs(v1),v2)
  74. gw2:cmul(self.w:expandAs(v1))
  75. local go = gradOutput:view(-1,1):expandAs(v1)
  76. gw1:cmul(go)
  77. gw2:cmul(go)
  78. if not_batch then
  79. self.gradInput[1]:resize(gw1:size(2))
  80. self.gradInput[2]:resize(gw2:size(2))
  81. end
  82. return self.gradInput
  83. end
  84. function CosineDistance:clearState()
  85. nn.utils.clear(self, {
  86. 'buffer',
  87. 'w1',
  88. 'w22',
  89. 'w',
  90. 'w32',
  91. 'ones',
  92. })
  93. return parent.clearState(self)
  94. end