diff options
Diffstat (limited to 'contrib/lua-torch/nn/CReLU.lua')
-rw-r--r-- | contrib/lua-torch/nn/CReLU.lua | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/CReLU.lua b/contrib/lua-torch/nn/CReLU.lua new file mode 100644 index 000000000..8da6e7974 --- /dev/null +++ b/contrib/lua-torch/nn/CReLU.lua @@ -0,0 +1,57 @@ +local CReLU, parent = torch.class('nn.CReLU', 'nn.Sequential') + +-- Implements the CReLU activation function as described by +-- W. Shang et al. in "Understanding and Improving Convolutional Neural Networks +-- via Concatenated Rectified Linear Units" +function CReLU:__init(nInputDims, inplace) + parent.__init(self) + self.nInputDims = nInputDims + self.inplace = inplace or false + + local concatTable = nn.ConcatTable() + concatTable:add(nn.Identity()) + concatTable:add(nn.MulConstant(-1)) + self:add(concatTable) + self:add(nn.JoinTable(2)) + self:add(nn.ReLU(self.inplace)) +end + +function CReLU:updateOutput(input) + local input_ + local batched = input:dim() == (self.nInputDims + 1) + if not batched then + input_ = input:view(1, -1) + else + input_ = input:view(input:size(1), -1) + end + parent.updateOutput(self, input_) + local osize = input:size() + if not batched then + osize[1] = osize[1] * 2 + else + osize[2] = osize[2] * 2 + end + self.output:resize(osize) + return self.output +end + +function CReLU:backward(input, gradOutput) + return self:updateGradInput(input, gradOutput) +end + +function CReLU:updateGradInput(input, gradOutput) + local batched = input:dim() == (self.nInputDims + 1) + if not batched then + parent.updateGradInput(self, input:view(1, -1), gradOutput:view(1, -1)) + else + parent.updateGradInput(self, input:view(input:size(1), -1), + gradOutput:view(input:size(1), -1)) + end + + self.gradInput:resizeAs(input) + return self.gradInput +end + +function CReLU:__tostring__() + return "CReLU()" +end |