aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/LayerNormalization.lua
blob: 722d7c8020f2487c6b83813b571474b2910bbb06 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
-- Reference: https://arxiv.org/pdf/1607.06450.pdf (Section 3)

-- nn.LayerNormalization: an nn.Sequential whose constructor wires together
-- stock nn modules to centre, L2-normalise and rescale its input, optionally
-- followed by a learnable gain and bias.
local LayerNormalization, parent = torch.class('nn.LayerNormalization', 'nn.Sequential')

--- Assemble the normalization pipeline.
-- @param nOutput  size handed to nn.Replicate, nn.CMul and nn.Add; its square
--                 root is the final scaling constant (presumably the feature
--                 count of the input -- confirm against callers)
-- @param bias     value the nn.Add bias is filled with (0 when nil/false)
-- @param eps      second argument to nn.Normalize (1e-10 when nil/false)
-- @param affine   nil or truthy: append the gain (nn.CMul) and bias (nn.Add)
--                 stages; false: stop after the rescaling step
function LayerNormalization:__init(nOutput, bias, eps, affine)
   parent.__init(self)

   -- Resolve optional arguments.
   if not eps then eps = 1e-10 end
   if affine == nil then affine = true end
   if not bias then bias = 0 end

   -- Fan the input out into { input, mean branch } so the two can be
   -- subtracted element-wise by nn.CSubTable below.
   local fanOut = nn.ConcatTable()
   fanOut:add(nn.Identity())

   -- NOTE(review): nn.Mean(1, 1) + nn.Replicate(nOutput, 1, 1) presumably
   -- yield the per-sample mean expanded back to nOutput entries -- confirm
   -- the dimension arguments against the nn documentation.
   local meanBranch = nn.Sequential()
   meanBranch:add(nn.Mean(1, 1))
   meanBranch:add(nn.Replicate(nOutput, 1, 1))
   fanOut:add(meanBranch)

   self:add(fanOut)
   self:add(nn.CSubTable())                       -- input - mean
   self:add(nn.Normalize(2, eps))                 -- L2 normalisation
   self:add(nn.MulConstant(torch.sqrt(nOutput)))  -- scale by sqrt(nOutput)

   if not affine then
      return
   end

   -- Learnable affine stages: bias starts at `bias`, gain starts at 1.
   local shift = nn.Add(nOutput, false)
   shift.bias:fill(bias)
   local gain = nn.CMul(nOutput)
   gain.weight:fill(1.)

   -- Gain is applied before the bias.
   self:add(gain)
   self:add(shift)
end