diff options
Diffstat (limited to 'contrib/lua-torch/nn/Squeeze.lua')
-rw-r--r-- | contrib/lua-torch/nn/Squeeze.lua | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Squeeze.lua b/contrib/lua-torch/nn/Squeeze.lua new file mode 100644 index 000000000..7d204a19d --- /dev/null +++ b/contrib/lua-torch/nn/Squeeze.lua @@ -0,0 +1,40 @@ +local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module') + +function Squeeze:__init(dim, numInputDims) + parent.__init(self) + self.dim = dim + self:setNumInputDims(numInputDims) +end + +function Squeeze:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +function Squeeze:updateOutput(input) + assert(input and torch.isTensor(input), 'Squeeze only works on tensors') + local dim = self.dim + local addone = false + if self.numInputDims and input:dim()==(self.numInputDims+1) then + if dim then + dim = dim + 1 + elseif input:size(1) == 1 then + addone = true -- in case of minibatch of size 1. + end + end + self.output:set(dim and input:squeeze(dim) or input:squeeze()) + if addone then + local s = self.output:size():totable{} + table.insert(s, 1, 1) + self.output:set(self.output:view(torch.LongStorage(s))) + end + return self.output +end + +function Squeeze:updateGradInput(input, gradOutput) + assert(input and torch.isTensor(input), 'Squeeze only works on tensors') + assert(gradOutput and torch.isTensor(gradOutput), 'Squeeze only works on tensors') + assert(input:nElement() == gradOutput:nElement()) + self.gradInput:set(gradOutput:view(input:size())) + return self.gradInput +end |