aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Bottle.lua
blob: 6dee432f5ce38e2d7e46a4b97beb0f4e54776f0f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
local Bottle, parent = torch.class("nn.Bottle", "nn.Decorator")
local unpack = unpack or table.unpack

function Bottle:__init(module, nInputDim, nOutputDim)
   parent.__init(self, module)
   self.nInputDim = nInputDim or 2
   self.nOutputDim = nOutputDim or self.nInputDim
   self.dimDelta = self.nInputDim - self.nOutputDim
   -- Used to reshape the gradients
   self.inShape = torch.Tensor(self.nInputDim)
   self.outShape = torch.Tensor(self.nOutputDim)
end

function Bottle:updateOutput(input)
   -- first batchDims dimensions will be fused
   local batchDims = input:dim() - self.nInputDim + 1
   -- see if bottle is required
   if batchDims > 1 then
      -- bottle the first dims
      local inSize = torch.LongTensor(input:size())
      local squeezeSize = inSize[{{1, batchDims - 1}}]:prod()
      self.inShape:copy(inSize[{{batchDims, input:dim()}}])
      self.inShape[{{1}}]:mul(squeezeSize)
      -- Forward with the module's dimension
      local newInput = input:view(unpack(self.inShape:totable()))
      local output = self.modules[1]:updateOutput(newInput)
      assert(output:dim() == self.nOutputDim,
	     "Wrong number of output dims on module. Expected: " ..
		self.nOutputDim .. ' but got ' ..
		tostring(output and output:dim()))
      self.outShape:copy(torch.LongTensor(output:size()))
      if math.abs(self.dimDelta) > 0 then
         inSize:resize(inSize:size(1) - self.dimDelta)
      end
      inSize[{{batchDims, inSize:size(1)}}]:copy(self.outShape)
      inSize[{{batchDims}}]:div(squeezeSize)
      -- unbottle
      self.output:set(output:view(unpack(torch.totable(inSize))))
   else
      self.output:set(self.modules[1]:updateOutput(input))
   end
   return self.output
end

function Bottle:updateGradInput(input, gradOutput)
   if input:dim() > self.nInputDim then
      local input_ = input:view(unpack(self.inShape:totable()))
      local gradOutput_ = gradOutput:view(unpack(self.outShape:totable()))
      self.modules[1]:updateGradInput(input_, gradOutput_)
      if self.modules[1].gradInput then
         self.gradInput:set(self.modules[1].gradInput:viewAs(input))
      else
         self.gradInput = nil
      end
   else
      if self.modules[1].gradInput then
         self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
      else
         self.gradInput = nil
      end
   end
   return self.gradInput
end

function Bottle:accGradParameters(input, gradOutput, scale)
   if input:dim() > self.nInputDim then
      input = input:view(unpack(self.inShape:totable()))
      gradOutput = gradOutput:view(unpack(self.outShape:totable()))
   end
   self.modules[1]:accGradParameters(input, gradOutput, scale)
end