You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Bilinear.lua 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. local Bilinear, parent = torch.class('nn.Bilinear', 'nn.Module')
  2. local function isint(x) return type(x) == 'number' and x == math.floor(x) end
  3. function Bilinear:__assertInput(input)
  4. assert(input and type(input) == 'table' and #input == 2,
  5. 'input should be a table containing two data Tensors')
  6. assert(input[1]:nDimension() == 2 and input[2]:nDimension() == 2,
  7. 'input Tensors should be two-dimensional')
  8. assert(input[1]:size(1) == input[2]:size(1),
  9. 'input Tensors should have the same number of rows (instances)')
  10. assert(input[1]:size(2) == self.weight:size(2),
  11. 'dimensionality of first input is erroneous')
  12. assert(input[2]:size(2) == self.weight:size(3),
  13. 'dimensionality of second input is erroneous')
  14. end
  15. function Bilinear:__assertInputGradOutput(input, gradOutput)
  16. assert(input[1]:size(1) == gradOutput:size(1),
  17. 'number of rows in gradOutput does not match input')
  18. assert(gradOutput:size(2) == self.weight:size(1),
  19. 'number of columns in gradOutput does not output size of layer')
  20. end
  21. function Bilinear:__init(inputSize1, inputSize2, outputSize, bias)
  22. -- assertions:
  23. assert(self and inputSize1 and inputSize2 and outputSize,
  24. 'should specify inputSize1 and inputSize2 and outputSize')
  25. assert(isint(inputSize1) and isint(inputSize2) and isint(outputSize),
  26. 'inputSize1 and inputSize2 and outputSize should be integer numbers')
  27. assert(inputSize1 > 0 and inputSize2 > 0 and outputSize > 0,
  28. 'inputSize1 and inputSize2 and outputSize should be positive numbers')
  29. -- set up model:
  30. parent.__init(self)
  31. local bias = ((bias == nil) and true) or bias
  32. self.weight = torch.Tensor(outputSize, inputSize1, inputSize2)
  33. self.gradWeight = torch.Tensor(outputSize, inputSize1, inputSize2)
  34. if bias then
  35. self.bias = torch.Tensor(outputSize)
  36. self.gradBias = torch.Tensor(outputSize)
  37. end
  38. self.gradInput = {torch.Tensor(), torch.Tensor()}
  39. self:reset()
  40. end
  41. function Bilinear:reset(stdv)
  42. assert(self)
  43. if stdv then
  44. assert(stdv and type(stdv) == 'number' and stdv > 0,
  45. 'standard deviation should be a positive number')
  46. stdv = stdv * math.sqrt(3)
  47. else
  48. stdv = 1 / math.sqrt(self.weight:size(2))
  49. end
  50. self.weight:uniform(-stdv, stdv)
  51. if self.bias then self.bias:uniform(-stdv, stdv) end
  52. return self
  53. end
  54. function Bilinear:updateOutput(input)
  55. assert(self)
  56. self:__assertInput(input)
  57. -- set up buffer:
  58. self.buff2 = self.buff2 or input[1].new()
  59. self.buff2:resizeAs(input[2])
  60. -- compute output scores:
  61. self.output:resize(input[1]:size(1), self.weight:size(1))
  62. for k = 1,self.weight:size(1) do
  63. torch.mm(self.buff2, input[1], self.weight[k])
  64. self.buff2:cmul(input[2])
  65. torch.sum(self.output:narrow(2, k, 1), self.buff2, 2)
  66. end
  67. if self.bias then
  68. self.output:add(
  69. self.bias:reshape(1, self.bias:nElement()):expandAs(self.output)
  70. )
  71. end
  72. return self.output
  73. end
  74. function Bilinear:updateGradInput(input, gradOutput)
  75. assert(self)
  76. if self.gradInput then
  77. self:__assertInputGradOutput(input, gradOutput)
  78. if #self.gradInput == 0 then
  79. for i = 1, 2 do self.gradInput[i] = input[1].new() end
  80. end
  81. -- compute d output / d input:
  82. self.gradInput[1]:resizeAs(input[1]):fill(0)
  83. self.gradInput[2]:resizeAs(input[2]):fill(0)
  84. -- do first slice of weight tensor (k = 1)
  85. self.gradInput[1]:mm(input[2], self.weight[1]:t())
  86. self.gradInput[1]:cmul(gradOutput:narrow(2,1,1):expand(self.gradInput[1]:size(1),
  87. self.gradInput[1]:size(2)))
  88. self.gradInput[2]:addmm(1, input[1], self.weight[1])
  89. self.gradInput[2]:cmul(gradOutput:narrow(2,1,1):expand(self.gradInput[2]:size(1),
  90. self.gradInput[2]:size(2)))
  91. -- do remaining slices of weight tensor
  92. if self.weight:size(1) > 1 then
  93. self.buff1 = self.buff1 or input[1].new()
  94. self.buff1:resizeAs(input[1])
  95. for k = 2, self.weight:size(1) do
  96. self.buff1:mm(input[2], self.weight[k]:t())
  97. self.buff1:cmul(gradOutput:narrow(2,k,1):expand(self.gradInput[1]:size(1),
  98. self.gradInput[1]:size(2)))
  99. self.gradInput[1]:add(self.buff1)
  100. self.buff2:mm(input[1], self.weight[k])
  101. self.buff2:cmul(gradOutput:narrow(2,k,1):expand(self.gradInput[2]:size(1),
  102. self.gradInput[2]:size(2)))
  103. self.gradInput[2]:add(self.buff2)
  104. end
  105. end
  106. return self.gradInput
  107. end
  108. end
  109. function Bilinear:accGradParameters(input, gradOutput, scale)
  110. local scale = scale or 1
  111. self:__assertInputGradOutput(input, gradOutput)
  112. assert(scale and type(scale) == 'number' and scale >= 0)
  113. -- make sure we have buffer:
  114. self.buff1 = self.buff1 or input[1].new()
  115. self.buff1:resizeAs(input[1])
  116. -- accumulate parameter gradients:
  117. for k = 1,self.weight:size(1) do
  118. torch.cmul(
  119. self.buff1, input[1], gradOutput:narrow(2, k, 1):expandAs(input[1])
  120. )
  121. self.gradWeight[k]:addmm(self.buff1:t(), input[2])
  122. end
  123. if self.bias then self.gradBias:add(scale, gradOutput:sum(1)) end
  124. end
  125. function Bilinear:sharedAccUpdateGradParameters(input, gradOutput, lr)
  126. -- we do not need to accumulate parameters when sharing:
  127. self:defaultAccUpdateGradParameters(input, gradOutput, lr)
  128. end
  129. function Bilinear:__tostring__()
  130. return torch.type(self) ..
  131. string.format(
  132. '(%dx%d -> %d) %s',
  133. self.weight:size(2), self.weight:size(3), self.weight:size(1),
  134. (self.bias == nil and ' without bias' or '')
  135. )
  136. end
  137. function Bilinear:clearState()
  138. if self.buff2 then self.buff2:set() end
  139. if self.buff1 then self.buff1:set() end
  140. return parent.clearState(self)
  141. end