path: root/contrib/lua-torch/nn/Squeeze.lua
blob: 7d204a19d2389976e12bb31d447e8dba8e50cf2c
-- nn.Squeeze removes singleton dimensions from the input tensor: either the
-- single dimension `dim`, or every size-1 dimension when `dim` is nil.
-- `numInputDims` tells the module how many dimensions one sample has, so an
-- extra leading batch dimension can be recognised and preserved.
local Squeeze, parent = torch.class('nn.Squeeze', 'nn.Module')

function Squeeze:__init(dim, numInputDims)
    parent.__init(self)
    self.dim = dim
    self:setNumInputDims(numInputDims)
end

function Squeeze:setNumInputDims(numInputDims)
    self.numInputDims = numInputDims
    return self
end

function Squeeze:updateOutput(input)
    assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
    local dim    = self.dim
    local addone = false
    -- If the input carries one more dimension than numInputDims, treat the
    -- leading dimension as the batch dimension: shift the user-specified dim
    -- by one, or, when no dim was given and the batch size is 1, remember to
    -- re-add the leading dimension after the full squeeze.
    if self.numInputDims and input:dim()==(self.numInputDims+1) then
        if dim then
            dim = dim + 1
        elseif input:size(1) == 1 then
            addone = true -- in case of minibatch of size 1.
        end
    end
    self.output:set(dim and input:squeeze(dim) or input:squeeze())
    if addone then
        local s = self.output:size():totable()
        table.insert(s, 1, 1)
        self.output:set(self.output:view(torch.LongStorage(s)))
    end
    return self.output
end

function Squeeze:updateGradInput(input, gradOutput)
    assert(input and torch.isTensor(input), 'Squeeze only works on tensors')
    assert(gradOutput and torch.isTensor(gradOutput), 'Squeeze only works on tensors')
    assert(input:nElement() == gradOutput:nElement(),
           'input and gradOutput must have the same number of elements')
    -- Squeeze is a pure reshape, so the gradient is just gradOutput viewed
    -- with the original input size.
    self.gradInput:set(gradOutput:view(input:size()))
    return self.gradInput
end
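
A minimal usage sketch (illustrative only, not part of Squeeze.lua; it assumes the Torch 'torch' and 'nn' packages are installed):

require 'nn'

-- Squeeze dimension 1 of each 2-D sample (dim = 1, numInputDims = 2).
local m = nn.Squeeze(1, 2)

-- Single sample: a 1x5 tensor becomes a 5-element vector.
print(m:forward(torch.randn(1, 5)):size())

-- Batched input: the extra leading dimension is recognised as the batch
-- dimension, so dim is shifted by one and 3x1x5 becomes 3x5.
print(m:forward(torch.randn(3, 1, 5)):size())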