aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Bottle.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/Bottle.lua')
-rw-r--r--contrib/lua-torch/nn/Bottle.lua71
1 files changed, 71 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Bottle.lua b/contrib/lua-torch/nn/Bottle.lua
new file mode 100644
index 000000000..6dee432f5
--- /dev/null
+++ b/contrib/lua-torch/nn/Bottle.lua
@@ -0,0 +1,71 @@
+local Bottle, parent = torch.class("nn.Bottle", "nn.Decorator")
+local unpack = unpack or table.unpack
+
+function Bottle:__init(module, nInputDim, nOutputDim)
+ parent.__init(self, module)
+ self.nInputDim = nInputDim or 2
+ self.nOutputDim = nOutputDim or self.nInputDim
+ self.dimDelta = self.nInputDim - self.nOutputDim
+ -- Used to reshape the gradients
+ self.inShape = torch.Tensor(self.nInputDim)
+ self.outShape = torch.Tensor(self.nOutputDim)
+end
+
+function Bottle:updateOutput(input)
+ -- first batchDims dimensions will be fused
+ local batchDims = input:dim() - self.nInputDim + 1
+ -- see if bottle is required
+ if batchDims > 1 then
+ -- bottle the first dims
+ local inSize = torch.LongTensor(input:size())
+ local squeezeSize = inSize[{{1, batchDims - 1}}]:prod()
+ self.inShape:copy(inSize[{{batchDims, input:dim()}}])
+ self.inShape[{{1}}]:mul(squeezeSize)
+ -- Forward with the module's dimension
+ local newInput = input:view(unpack(self.inShape:totable()))
+ local output = self.modules[1]:updateOutput(newInput)
+ assert(output:dim() == self.nOutputDim,
+ "Wrong number of output dims on module. Expected: " ..
+ self.nOutputDim .. ' but got ' ..
+ tostring(output and output:dim()))
+ self.outShape:copy(torch.LongTensor(output:size()))
+ if math.abs(self.dimDelta) > 0 then
+ inSize:resize(inSize:size(1) - self.dimDelta)
+ end
+ inSize[{{batchDims, inSize:size(1)}}]:copy(self.outShape)
+ inSize[{{batchDims}}]:div(squeezeSize)
+ -- unbottle
+ self.output:set(output:view(unpack(torch.totable(inSize))))
+ else
+ self.output:set(self.modules[1]:updateOutput(input))
+ end
+ return self.output
+end
+
+function Bottle:updateGradInput(input, gradOutput)
+ if input:dim() > self.nInputDim then
+ local input_ = input:view(unpack(self.inShape:totable()))
+ local gradOutput_ = gradOutput:view(unpack(self.outShape:totable()))
+ self.modules[1]:updateGradInput(input_, gradOutput_)
+ if self.modules[1].gradInput then
+ self.gradInput:set(self.modules[1].gradInput:viewAs(input))
+ else
+ self.gradInput = nil
+ end
+ else
+ if self.modules[1].gradInput then
+ self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
+ else
+ self.gradInput = nil
+ end
+ end
+ return self.gradInput
+end
+
+function Bottle:accGradParameters(input, gradOutput, scale)
+ if input:dim() > self.nInputDim then
+ input = input:view(unpack(self.inShape:totable()))
+ gradOutput = gradOutput:view(unpack(self.outShape:totable()))
+ end
+ self.modules[1]:accGradParameters(input, gradOutput, scale)
+end