summaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/LeakyReLU.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/torch/nn/LeakyReLU.lua')
-rw-r--r--contrib/torch/nn/LeakyReLU.lua41
1 files changed, 41 insertions, 0 deletions
diff --git a/contrib/torch/nn/LeakyReLU.lua b/contrib/torch/nn/LeakyReLU.lua
new file mode 100644
index 000000000..56b7f2542
--- /dev/null
+++ b/contrib/torch/nn/LeakyReLU.lua
@@ -0,0 +1,41 @@
+local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
+
+function LeakyReLU:__init(negval,ip)
+ parent.__init(self)
+ if type(negval) == 'boolean' then
+ local ip = negval
+ self.negval = 1/100
+ else
+ self.negval = negval or (1/100)
+ end
+ -- default for inplace is false
+ self.inplace = ip or false
+ if self.negval < 0 then
+ self.inplace = false
+ end
+end
+
+function LeakyReLU:updateOutput(input)
+ input.THNN.LeakyReLU_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.negval,
+ self.inplace
+ )
+ return self.output
+end
+
+function LeakyReLU:updateGradInput(input, gradOutput)
+ input.THNN.LeakyReLU_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.negval,
+ self.inplace
+ )
+ return self.gradInput
+end
+
+function LeakyReLU:__tostring__()
+ return torch.type(self) .. string.format('(%g)', self.negval)
+end