
BatchNormalization.lua

--[[
   This file implements Batch Normalization as described in the paper:
   "Batch Normalization: Accelerating Deep Network Training
    by Reducing Internal Covariate Shift"
   by Sergey Ioffe, Christian Szegedy

   This implementation is useful for inputs NOT coming from convolution layers.
   For convolution layers, use nn.SpatialBatchNormalization.

   The operation implemented is:

              ( x - mean(x) )
      y = ----------------------- * gamma + beta
           standard-deviation(x)

   where gamma and beta are learnable parameters.
   The learning of gamma and beta is optional.

   Usage:
   with    learnable parameters: nn.BatchNormalization(N [,eps] [,momentum])
                                 where N = dimensionality of input
   without learnable parameters: nn.BatchNormalization(N [,eps] [,momentum], false)

   eps is a small value added to the standard deviation to avoid divide-by-zero.
   Defaults to 1e-5.

   During training, this layer keeps a running estimate of its computed mean and std.
   The running estimates are updated with a default momentum of 0.1 (unless overridden).
   At test time, this running mean/std is used to normalize.
]]--
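--[[
   A minimal usage sketch (illustrative only, not part of the module): the
   feature count and batch size below are assumptions chosen for the example.

   require 'nn'
   local bn = nn.BatchNormalization(10)      -- 10 features, learnable gamma/beta, eps = 1e-5, momentum = 0.1
   bn:training()
   local y = bn:forward(torch.randn(4, 10))  -- 4-sample mini-batch; running mean/var are updated
   bn:evaluate()
   local z = bn:forward(torch.randn(10))     -- single sample at test time, normalized with the running estimates
]]--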
local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module')
local THNN = require 'nn.THNN'

BN.__version = 2

-- expected dimension of input
BN.nDim = 2

function BN:__init(nOutput, eps, momentum, affine)
   parent.__init(self)
   assert(nOutput and type(nOutput) == 'number',
          'Missing argument #1: dimensionality of input. ')
   assert(nOutput ~= 0, 'To set affine=false call BatchNormalization'
      .. '(nOutput, eps, momentum, false) ')
   if affine ~= nil then
      assert(type(affine) == 'boolean', 'affine has to be true/false')
      self.affine = affine
   else
      self.affine = true
   end
   self.eps = eps or 1e-5
   self.train = true
   self.momentum = momentum or 0.1
   self.running_mean = torch.zeros(nOutput)
   self.running_var = torch.ones(nOutput)

   if self.affine then
      self.weight = torch.Tensor(nOutput)
      self.bias = torch.Tensor(nOutput)
      self.gradWeight = torch.Tensor(nOutput)
      self.gradBias = torch.Tensor(nOutput)
      self:reset()
   end
end
function BN:reset()
   if self.weight then
      self.weight:uniform()
   end
   if self.bias then
      self.bias:zero()
   end
   self.running_mean:zero()
   self.running_var:fill(1)
end

function BN:checkInputDim(input)
   local iDim = input:dim()
   assert(iDim == self.nDim or
          (iDim == self.nDim - 1 and self.train == false), string.format(
      'only mini-batch supported (%dD tensor), got %dD tensor instead',
      self.nDim, iDim))
   local featDim = (iDim == self.nDim - 1) and 1 or 2
   assert(input:size(featDim) == self.running_mean:nElement(), string.format(
      'got %d-feature tensor, expected %d',
      input:size(featDim), self.running_mean:nElement()))
end
-- copy input (and gradOutput) into cached contiguous buffers when they are
-- not already contiguous, since THNN expects contiguous tensors
local function makeContiguous(self, input, gradOutput)
   if not input:isContiguous() then
      self._input = self._input or input.new()
      self._input:resizeAs(input):copy(input)
      input = self._input
   end
   if gradOutput then
      if not gradOutput:isContiguous() then
         self._gradOutput = self._gradOutput or gradOutput.new()
         self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
         gradOutput = self._gradOutput
      end
   end
   return input, gradOutput
end

-- in evaluation mode a single (non-batched) sample is viewed as a batch of size 1
local function makeBatch(self, input)
   local iDim = input:dim()
   if self.train == false and iDim == self.nDim - 1 then
      return nn.utils.addSingletonDimension(input, input, 1)
   else
      return input
   end
end
function BN:updateOutput(input)
   self:checkInputDim(input)

   input = makeContiguous(self, input)
   input = makeBatch(self, input)

   -- save_mean/save_std hold the per-batch statistics computed by THNN;
   -- they are reused by the backward pass
   self.save_mean = self.save_mean or input.new()
   self.save_mean:resizeAs(self.running_mean)
   self.save_std = self.save_std or input.new()
   self.save_std:resizeAs(self.running_var)

   input.THNN.BatchNormalization_updateOutput(
      input:cdata(),
      self.output:cdata(),
      THNN.optionalTensor(self.weight),
      THNN.optionalTensor(self.bias),
      self.running_mean:cdata(),
      self.running_var:cdata(),
      self.save_mean:cdata(),
      self.save_std:cdata(),
      self.train,
      self.momentum,
      self.eps)

   return self.output
end
local function backward(self, input, gradOutput, scale, gradInput, gradWeight, gradBias)
   self:checkInputDim(input)
   self:checkInputDim(gradOutput)
   assert(self.save_mean and self.save_std, 'must call :updateOutput() first')

   input, gradOutput = makeContiguous(self, input, gradOutput)
   input = makeBatch(self, input)
   gradOutput = makeBatch(self, gradOutput)

   scale = scale or 1
   if gradInput then
      gradInput:resizeAs(gradOutput)
   end

   input.THNN.BatchNormalization_backward(
      input:cdata(),
      gradOutput:cdata(),
      THNN.optionalTensor(gradInput),
      THNN.optionalTensor(gradWeight),
      THNN.optionalTensor(gradBias),
      THNN.optionalTensor(self.weight),
      self.running_mean:cdata(),
      self.running_var:cdata(),
      self.save_mean:cdata(),
      self.save_std:cdata(),
      self.train,
      scale,
      self.eps)

   return self.gradInput
end

function BN:backward(input, gradOutput, scale)
   return backward(self, input, gradOutput, scale, self.gradInput, self.gradWeight, self.gradBias)
end

function BN:updateGradInput(input, gradOutput)
   return backward(self, input, gradOutput, 1, self.gradInput)
end

function BN:accGradParameters(input, gradOutput, scale)
   return backward(self, input, gradOutput, scale, nil, self.gradWeight, self.gradBias)
end
function BN:read(file, version)
   parent.read(self, file)
   if version < 2 then
      if self.running_std then
         -- version 1 models stored the inverted std (1/sqrt(var + eps));
         -- recover the running variance from it
         self.running_var = self.running_std:pow(-2):add(-self.eps)
         self.running_std = nil
      end
   end
end
function BN:clearState()
   -- first 5 buffers are not present in the current implementation,
   -- but we keep them for cleaning old saved models
   nn.utils.clear(self, {
      'buffer',
      'buffer2',
      'centered',
      'std',
      'normalized',
      '_input',
      '_gradOutput',
      'save_mean',
      'save_std',
   })
   return parent.clearState(self)
end

function BN:__tostring__()
   return string.format('%s (%dD) (%d)', torch.type(self), self.nDim, self.running_mean:nElement())
end