aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Concat.lua
blob: d7e3ee711f5d2ef60fae4196475a9cd3e3ca08e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
-- nn.Concat: a container class registered under the name 'nn.Concat' with
-- 'nn.Container' as its parent. It feeds the same input to each of its child
-- modules and joins their outputs along a single dimension (self.dimension).
-- `parent` is the second value returned by torch.class and is used below to
-- invoke the base-class constructor.
local Concat, parent = torch.class('nn.Concat', 'nn.Container')

-- Constructor.
--   dimension: index of the dimension along which the outputs of the child
--              modules are concatenated. It is stored as-is, without checks.
function Concat:__init(dimension)
   parent.__init(self)
   self.dimension = dimension
   -- Scratch storage for the size of the concatenated output; it is refilled
   -- on every call to updateOutput.
   self.outputSize = torch.LongStorage()
end

-- Forward pass.
-- Runs every child module on the same `input`, then copies the individual
-- results side by side into self.output along self.dimension.
--   input:   passed unchanged to each child's updateOutput.
--   returns: self.output, resized to hold all child outputs.
function Concat:updateOutput(input)
   local dim = self.dimension

   -- The size storage may be missing on this instance; recreate it lazily.
   if not self.outputSize then
      self.outputSize = torch.LongStorage()
   end
   local totalSize = self.outputSize

   -- Pass 1: evaluate the children and accumulate the final output size.
   local branchOutputs = {}
   for i = 1, #self.modules do
      local out = self:rethrowErrors(self.modules[i], i, 'updateOutput', input)
      branchOutputs[i] = out
      if i > 1 then
         -- Later children only extend the concatenation dimension.
         totalSize[dim] = totalSize[dim] + out:size(dim)
      else
         -- The first child fixes the number of dimensions and the base size.
         totalSize:resize(out:dim()):copy(out:size())
      end
   end
   self.output:resize(totalSize)

   -- Pass 2: copy each child's result into its slice of the output.
   local start = 1
   for i = 1, #branchOutputs do
      local out = branchOutputs[i]
      local len = out:size(dim)
      self.output:narrow(dim, start, len):copy(out)
      start = start + len
   end
   return self.output
end

-- Walks the array part of `t2` and mirrors its (possibly nested) structure
-- into `t1`.
--   t1: destination table, modified in place and returned.
--   t2: source table; nested tables are descended into recursively.
--   f:  callback invoked as f(t1, index, value) for every non-table value of
--       t2; it is responsible for storing/merging the value into t1.
-- Entries of t1 beyond the length of t2 are removed.
local function retable(t1, t2, f)
   for idx, item in ipairs(t2) do
      if torch.type(item) ~= "table" then
         f(t1, idx, item)
      else
         -- Reuse the existing sub-table when there is one.
         t1[idx] = retable(t1[idx] or {}, item, f)
      end
   end

   -- Drop whatever t1 still holds past the end of t2.
   local firstStale = #t2 + 1
   local lastStale = #t1
   for idx = firstStale, lastStale do
      t1[idx] = nil
   end
   return t1
end

-- Shared implementation behind Concat:updateGradInput and Concat:backward.
--   self:       the Concat container.
--   method:     name of the child method to invoke ('updateGradInput' or
--               'backward').
--   input:      the input given to the forward pass (tensor or table).
--   gradOutput: gradient w.r.t. the concatenated output; every child receives
--               the slice along self.dimension that matches the size of its
--               own `output`.
--   scale:      forwarded to the children; defaults to 1.
--   returns:    self.gradInput, holding the sum of the children's gradInputs.
-- Raises an error (table input only) when a child does not return a table or
-- returns a table whose length differs from that of `input`.
local function backward(self, method, input, gradOutput, scale)
   local isTable = torch.type(input) == 'table'
   local wasTable = torch.type(self.gradInput) == 'table'
   scale = scale or 1

   if isTable then
      local offset = 1
      for i,module in ipairs(self.modules) do
         local currentOutput = module.output
         local currentGradInput = self:rethrowErrors(module, i, method, input,
                                                     gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), scale)
         if torch.type(currentGradInput) ~= 'table' then
            error"currentGradInput is not a table!"
         end
         if #input ~= #currentGradInput then
            error("table size mismatch: "..#input.." ~= "..#currentGradInput)
         end
         if i == 1 then
            -- The first child (re)initialises the gradInput table: existing
            -- buffers are reused and overwritten, missing ones are cloned.
            self.gradInput = wasTable and self.gradInput or {}
            retable(self.gradInput, currentGradInput,
                    function(t, k, v)
                       t[k] = t[k] or v:clone()
                       t[k]:resizeAs(v)
                       t[k]:copy(v)
                    end
            )
         else
            -- Later children are accumulated onto what is already there.
            retable(self.gradInput, currentGradInput,
                    function(t, k, v)
                       if t[k] then
                          t[k]:add(v)
                       else
                          t[k] = v:clone()
                       end
                    end
            )
         end
         offset = offset + currentOutput:size(self.dimension)
      end
   else
      -- Reuse the existing tensor buffer when possible; if gradInput was a
      -- table (left over from a table input) start from a clone of input.
      self.gradInput = (not wasTable) and self.gradInput:resizeAs(input) or input:clone()
      local offset = 1
      -- Tracks whether any child has written into self.gradInput yet. Keying
      -- the copy on the module index would let later children add onto stale
      -- buffer contents whenever the first child returns no gradInput.
      local initialised = false
      for i,module in ipairs(self.modules) do
         local currentOutput = module.output
         local currentGradInput = self:rethrowErrors(module, i, method, input,
                                                     gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), scale)
         if currentGradInput then -- if the module does not produce a gradInput (for example first layer), then ignore it and move on.
            if initialised then
               self.gradInput:add(currentGradInput)
            else
               -- First gradient actually produced: overwrite the buffer.
               self.gradInput:copy(currentGradInput)
               initialised = true
            end
         end
         offset = offset + currentOutput:size(self.dimension)
      end
      -- NOTE: if no child returned a gradInput, the buffer contents are left
      -- exactly as the resize/clone above produced them.
   end
   return self.gradInput
end

-- Computes the gradient w.r.t. `input` by invoking each child's
-- 'updateGradInput' on its slice of `gradOutput` and summing the results.
-- No scale is supplied, so the shared helper falls back to its default.
--   returns: self.gradInput
function Concat:updateGradInput(input, gradOutput)
   local gradInput = backward(self, 'updateGradInput', input, gradOutput)
   return gradInput
end

-- Full backward pass: invokes each child's 'backward' on its slice of
-- `gradOutput`, forwarding `scale`, and sums the returned gradInputs.
--   returns: self.gradInput
function Concat:backward(input, gradOutput, scale)
   local gradInput = backward(self, 'backward', input, gradOutput, scale)
   return gradInput
end

-- Asks every child to accumulate its parameter gradients.
-- Each child is given the slice of `gradOutput` (along self.dimension) whose
-- length equals the size of that child's own output.
--   scale: forwarded to the children; defaults to 1.
function Concat:accGradParameters(input, gradOutput, scale)
   if not scale then
      scale = 1
   end
   local dim = self.dimension
   local start = 1
   for idx, branch in ipairs(self.modules) do
      local len = branch.output:size(dim)
      local gradSlice = gradOutput:narrow(dim, start, len)
      self:rethrowErrors(branch, idx, 'accGradParameters', input, gradSlice, scale)
      start = start + len
   end
end

-- Calls 'accUpdateGradParameters' on every child, passing `lr` through.
-- Each child is given the slice of `gradOutput` (along self.dimension) whose
-- length equals the size of that child's own output.
function Concat:accUpdateGradParameters(input, gradOutput, lr)
   local dim = self.dimension
   local start = 1
   for idx, branch in ipairs(self.modules) do
      local len = branch.output:size(dim)
      local gradSlice = gradOutput:narrow(dim, start, len)
      self:rethrowErrors(branch, idx, 'accUpdateGradParameters', input, gradSlice, lr)
      start = start + len
   end
end

-- Builds a multi-line, tree-like description of the container: a header with
-- the class name, one numbered branch per child module (the child's own
-- description is re-indented so nested containers line up), and a closing
-- "output" line.
function Concat:__tostring__()
   local tab = '  '
   local line = '\n'
   local count = #self.modules

   local pieces = { torch.type(self), ' {', line, tab, 'input' }
   for i = 1, count do
      -- The last branch uses a different connector and a blank continuation
      -- prefix, since nothing follows it.
      local isLast = (i == count)
      local arrow = isLast and '   `-> ' or '  |`-> '
      local pad = isLast and '       ' or '  |    '
      -- Assigning to a local keeps only gsub's first result (the string).
      local body = tostring(self.modules[i]):gsub(line, line .. tab .. pad)
      pieces[#pieces + 1] = line .. tab .. arrow .. '(' .. i .. '): ' .. body
   end
   pieces[#pieces + 1] = line .. tab .. '   ... -> ' .. 'output'
   pieces[#pieces + 1] = line .. '}'
   return table.concat(pieces)
end