aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/MapTable.lua
blob: c79f1ea1d8a2fd13b7f237134fd1c0b8b3914b26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
-- Register the 'nn.MapTable' class with 'nn.Container' named as its parent
-- class. `MapTable` receives the methods defined in this file; `parent` is
-- used by those methods to call the inherited implementations
-- (e.g. parent.__init, parent.clearState).
local MapTable, parent = torch.class('nn.MapTable', 'nn.Container')

-- Constructor.
--   module: the single module this container wraps (registered via add()).
--   shared: optional flag stored in self.shared; when omitted (nil) it
--           defaults to true. An explicit false is preserved.
function MapTable:__init(module, shared)
   parent.__init(self)
   -- Only a missing argument selects the default; `false` must stay false.
   if shared == nil then
      self.shared = true
   else
      self.shared = shared
   end
   -- Field names handed to clone() when replicas are created in shared mode.
   self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias'}
   self.output = {}
   self.gradInput = {}
   self:add(module)
end

-- Make sure self.modules holds an entry for every index in 1..n.
-- Slot 1 is always (re)bound to self.module; any missing slot from 2..n is
-- filled with a clone of self.module. In shared mode the clone is created
-- with the names in self.sharedparams passed to clone() (presumably so those
-- fields are shared with the original -- confirm against nn.Module:clone).
-- Existing entries are left untouched.
function MapTable:_extend(n)
   -- Restore the default list if the field is absent on this object.
   if not self.sharedparams then
      self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias'}
   end
   local replicas = self.modules
   replicas[1] = self.module
   for idx = 2, n do
      if not replicas[idx] then
         local copy
         if self.shared then
            copy = self.module:clone(table.unpack(self.sharedparams))
         else
            copy = self.module:clone()
         end
         replicas[idx] = copy
      end
   end
end

-- Bring the number of held modules to n: grow via _extend(n) when needed,
-- then drop every entry stored past index n.
function MapTable:resize(n)
   self:_extend(n)
   local held = #self.modules
   for idx = n + 1, held do
      local surplus = self.modules[idx]
      -- The reason clearState must be called before discarding the entry is
      -- not understood, but doing so fixes
      -- https://github.com/torch/nn/issues/1141 .
      surplus:clearState()
      self.modules[idx] = nil
   end
end

-- Register the one module this container maps over its input.
-- Raises 'Single module required' if a module has already been set.
-- Returns self.
function MapTable:add(module)
   local existing = self.module
   assert(not existing, 'Single module required')
   -- The module is kept both as self.module and as the first entry of
   -- self.modules.
   self.module = module
   self.modules[1] = module
   return self
end

-- Forward pass: for each index i in 1..#input, run 'updateOutput' of
-- self.modules[i] on input[i] and collect the results in a fresh table that
-- is stored in self.output and returned.
-- Calls go through self:rethrowErrors (defined outside this file; presumably
-- it invokes the named method and re-raises failures -- confirm).
function MapTable:updateOutput(input)
   local results = {}
   self.output = results
   local count = #input
   -- Ensure there is one module per input element.
   self:_extend(count)
   for idx = 1, count do
      results[idx] = self:rethrowErrors(self.modules[idx], idx, 'updateOutput', input[idx])
   end
   return results
end

-- Backward pass w.r.t. the input: for each index i in 1..#input, run
-- 'updateGradInput' of self.modules[i] on (input[i], gradOutput[i]) and
-- collect the results in a fresh table that is stored in self.gradInput and
-- returned.
function MapTable:updateGradInput(input, gradOutput)
   local grads = {}
   self.gradInput = grads
   local count = #input
   -- Ensure there is one module per input element.
   self:_extend(count)
   for idx = 1, count do
      grads[idx] = self:rethrowErrors(self.modules[idx], idx, 'updateGradInput', input[idx], gradOutput[idx])
   end
   return grads
end

-- Run 'accGradParameters' of self.modules[i] on (input[i], gradOutput[i])
-- for every index i in 1..#input.
--   scale: optional; a missing (or false) value is replaced by 1.
-- Returns nothing.
function MapTable:accGradParameters(input, gradOutput, scale)
   if not scale then
      scale = 1
   end
   local count = #input
   -- Ensure there is one module per input element.
   self:_extend(count)
   for idx = 1, count do
      self:rethrowErrors(self.modules[idx], idx, 'accGradParameters', input[idx], gradOutput[idx], scale)
   end
end

-- Run 'accUpdateGradParameters' of self.modules[i] on
-- (input[i], gradOutput[i]) for every index i in 1..#input.
--   lr: optional; a missing (or false) value is replaced by 1.
-- Returns nothing.
function MapTable:accUpdateGradParameters(input, gradOutput, lr)
   if not lr then
      lr = 1
   end
   local count = #input
   -- Ensure there is one module per input element.
   self:_extend(count)
   for idx = 1, count do
      self:rethrowErrors(self.modules[idx], idx, 'accUpdateGradParameters', input[idx], gradOutput[idx], lr)
   end
end

-- Zero the gradient parameters.
-- Does nothing when no module has been set. In shared mode only
-- self.module is asked to zero its gradients; otherwise the call is handed
-- to the parent class implementation.
function MapTable:zeroGradParameters()
   local wrapped = self.module
   if not wrapped then
      return
   end
   if self.shared then
      wrapped:zeroGradParameters()
   else
      parent.zeroGradParameters(self)
   end
end

-- Apply a parameter update with the given learningRate.
-- Does nothing when no module has been set. In shared mode only
-- self.module is updated; otherwise the call is handed to the parent class
-- implementation.
function MapTable:updateParameters(learningRate)
   local wrapped = self.module
   if not wrapped then
      return
   end
   if self.shared then
      wrapped:updateParameters(learningRate)
   else
      parent.updateParameters(self, learningRate)
   end
end

-- Discard every module stored at index 2 and above, then delegate to the
-- parent class to clear the remaining state.
-- Returns whatever parent.clearState returns. Previously the result was
-- dropped, so this method returned nothing and could not be chained
-- (e.g. `m:clearState():float()`).
-- NOTE(review): nn.Container.clearState is expected to return self --
-- confirm against the parent implementation.
function MapTable:clearState()
   for i = 2, #self.modules do
      -- It's not clear why this clearState call is necessary, but it fixes
      -- https://github.com/torch/nn/issues/1141 .
      self.modules[i]:clearState()
      self.modules[i] = nil
   end
   return parent.clearState(self)
end

-- Build the printable description of this container: the type name
-- followed by the wrapped module's own description inside braces, with the
-- continuation lines of that description indented. Without a module the
-- result is the type name followed by ' { }'.
function MapTable:__tostring__()
   local typename = torch.type(self)
   if not self.module then
      return typename .. ' { }'
   end
   local line, tab, extlast = '\n', '  ', '      '
   -- Every newline inside the child's description gets extra indentation.
   -- gsub also returns a match count; the single-target assignment keeps
   -- only the rewritten string.
   local body = tostring(self.module):gsub(line, line .. tab .. extlast)
   return typename .. ' {' .. line .. tab .. body .. line .. '}'
end