blob: 722d7c8020f2487c6b83813b571474b2910bbb06 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
-- Reference: https://arxiv.org/pdf/1607.06450.pdf (Section 3)
local LayerNormalization, parent = torch.class('nn.LayerNormalization', 'nn.Sequential')

--- Assemble the layer-normalization pipeline as a chain of nn modules.
-- The container holds, in order: mean subtraction, nn.Normalize(2, eps),
-- a constant multiplication by sqrt(nOutput) and, when `affine` is enabled,
-- a learnable element-wise gain followed by a learnable bias.
-- @param nOutput size passed to nn.Replicate, nn.CMul and nn.Add, and used
--        for the sqrt(nOutput) rescaling factor
--        (presumably the feature dimension of the input; confirm with callers)
-- @param bias initial value written into every bias entry (defaults to 0)
-- @param eps value forwarded to nn.Normalize (defaults to 1e-10)
-- @param affine when nil or true, append the gain and bias modules;
--        pass false to leave them out
function LayerNormalization:__init(nOutput, bias, eps, affine)
   parent.__init(self)

   -- Fill in defaults for the optional arguments.
   if not eps then
      eps = 1e-10
   end
   if affine == nil then
      -- Only a missing argument enables the affine part; an explicit
      -- false must be kept as is.
      affine = true
   end
   if not bias then
      bias = 0
   end

   -- Branch that computes the mean and replicates it nOutput times so that
   -- it can be subtracted from the untouched input
   -- (assumes the replicated shape lines up with the input; TODO confirm).
   local meanBranch = nn.Sequential()
   meanBranch:add(nn.Mean(1, 1))
   meanBranch:add(nn.Replicate(nOutput, 1, 1))

   -- Produces the table {input, replicated mean} consumed by CSubTable.
   local inputAndMean = nn.ConcatTable()
   inputAndMean:add(nn.Identity())
   inputAndMean:add(meanBranch)

   self:add(inputAndMean)
   self:add(nn.CSubTable())
   self:add(nn.Normalize(2, eps))
   -- Rescale the normalized output by sqrt(nOutput)
   -- (presumably so the result has unit scale rather than unit norm; verify).
   self:add(nn.MulConstant(torch.sqrt(nOutput)))

   if not affine then
      return
   end

   -- Learnable shift, initialised to the requested bias value.
   local biasTransform = nn.Add(nOutput, false)
   biasTransform.bias:fill(bias)

   -- Learnable element-wise gain, initialised to 1.
   local gainTransform = nn.CMul(nOutput)
   gainTransform.weight:fill(1.)

   -- The gain is applied before the bias.
   self:add(gainTransform)
   self:add(biasTransform)
end
|