aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Squeeze.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/Squeeze.lua')
-rw-r--r--contrib/lua-torch/nn/Squeeze.lua40
1 files changed, 40 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Squeeze.lua b/contrib/lua-torch/nn/Squeeze.lua
new file mode 100644
index 000000000..7d204a19d
--- /dev/null
+++ b/contrib/lua-torch/nn/Squeeze.lua
@@ -0,0 +1,40 @@
+local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module')
+
+function Squeeze:__init(dim, numInputDims)
+ parent.__init(self)
+ self.dim = dim
+ self:setNumInputDims(numInputDims)
+end
+
+function Squeeze:setNumInputDims(numInputDims)
+ self.numInputDims = numInputDims
+ return self
+end
+
+function Squeeze:updateOutput(input)
+ assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
+ local dim = self.dim
+ local addone = false
+ if self.numInputDims and input:dim()==(self.numInputDims+1) then
+ if dim then
+ dim = dim + 1
+ elseif input:size(1) == 1 then
+ addone = true -- in case of minibatch of size 1.
+ end
+ end
+ self.output:set(dim and input:squeeze(dim) or input:squeeze())
+ if addone then
+ local s = self.output:size():totable{}
+ table.insert(s, 1, 1)
+ self.output:set(self.output:view(torch.LongStorage(s)))
+ end
+ return self.output
+end
+
+function Squeeze:updateGradInput(input, gradOutput)
+ assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
+ assert(gradOutput and torch.isTensor(gradOutput), 'Squeeze only works on tensors')
+ assert(input:nElement() == gradOutput:nElement())
+ self.gradInput:set(gradOutput:view(input:size()))
+ return self.gradInput
+end