123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
-- CReLU (Concatenated Rectified Linear Unit) activation, following
-- W. Shang et al., "Understanding and Improving Convolutional Neural Networks
-- via Concatenated Rectified Linear Units".
--
-- Implemented as an nn.Sequential pipeline:
--   ConcatTable{ Identity, MulConstant(-1) } -> JoinTable(2) -> ReLU
-- so the module emits relu(x) and relu(-x) side by side.
-- `parent` is the nn.Sequential base class; the overridden methods below
-- delegate to it.
local CReLU, parent = torch.class('nn.CReLU', 'nn.Sequential')

-- nInputDims: number of dimensions of a single (unbatched) sample; the
--             forward/backward passes compare input:dim() against it to
--             detect a leading batch dimension.
-- inplace:    passed to the final nn.ReLU; treated as false when omitted.
function CReLU:__init(nInputDims, inplace)
   parent.__init(self)
   self.nInputDims = nInputDims
   self.inplace = inplace or false

   -- Two parallel branches over the same input: x and -x.
   local branches = nn.ConcatTable()
      :add(nn.Identity())
      :add(nn.MulConstant(-1))

   self:add(branches)
   -- Glue the branches together along dimension 2, then rectify.
   self:add(nn.JoinTable(2))
   self:add(nn.ReLU(self.inplace))
end
-
-- Forward pass.
--
-- `input` is treated as batched when it has exactly self.nInputDims + 1
-- dimensions (leading batch dimension). Any other rank is handled as a
-- single sample. The input is flattened to a 2D view (rows x features) so the
-- inherited Sequential pipeline builds [x, -x] along dimension 2 and applies
-- ReLU. The result is then reshaped to the input's shape with its first
-- non-batch dimension doubled (presumably the channel dimension; confirm
-- against callers).
function CReLU:updateOutput(input)
   local batched = input:dim() == (self.nInputDims + 1)
   local nRows = batched and input:size(1) or 1

   -- view() only works on contiguous memory. contiguous() returns the very
   -- same tensor when it already is contiguous, so a copy is made only for
   -- e.g. transposed or narrowed inputs, on which view() would raise.
   parent.updateOutput(self, input:contiguous():view(nRows, -1))

   -- Output shape: the input's shape with the concatenated dimension doubled.
   local osize = input:size()
   local catDim = batched and 2 or 1
   osize[catDim] = osize[catDim] * 2
   self.output:resize(osize)
   return self.output
end
-
-- Backward pass override: computes and returns only the gradient with
-- respect to the input. No accGradParameters step is performed, and any
-- `scale` argument a caller passes is ignored. None of the modules added in
-- __init appear to hold learnable parameters, which is presumably why this
-- is safe.
function CReLU:backward(input, gradOutput)
   local gradInput = self:updateGradInput(input, gradOutput)
   return gradInput
end
-
-- Gradient with respect to the input.
--
-- Flattens `input` and `gradOutput` to the same 2D (rows x features) layout
-- that updateOutput feeds through the pipeline, delegates to the inherited
-- Sequential implementation, and reshapes the resulting gradient back to the
-- shape of `input`.
function CReLU:updateGradInput(input, gradOutput)
   local batched = input:dim() == (self.nInputDims + 1)
   local nRows = batched and input:size(1) or 1

   -- view() raises on non-contiguous tensors. gradOutput coming from later
   -- layers may well be non-contiguous (e.g. after a transpose or narrow).
   -- contiguous() is a no-op when the tensor already is contiguous.
   local flatInput = input:contiguous():view(nRows, -1)
   local flatGradOutput = gradOutput:contiguous():view(nRows, -1)
   parent.updateGradInput(self, flatInput, flatGradOutput)

   -- Hand the gradient back in the caller's original input shape.
   self.gradInput:resizeAs(input)
   return self.gradInput
end
-
-- Printable description of the module; constructor arguments (nInputDims,
-- inplace) are not included.
function CReLU:__tostring__()
   local label = "CReLU()"
   return label
end
|