summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Threshold.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/Threshold.lua')
-rw-r--r--contrib/lua-torch/nn/Threshold.lua51
1 files changed, 51 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Threshold.lua b/contrib/lua-torch/nn/Threshold.lua
new file mode 100644
index 000000000..6fdd26408
--- /dev/null
+++ b/contrib/lua-torch/nn/Threshold.lua
@@ -0,0 +1,51 @@
+local Threshold, parent = torch.class('nn.Threshold','nn.Module')
+
+function Threshold:__init(th,v,ip)
+ parent.__init(self)
+ self.threshold = th or 1e-6
+ self.val = v or 0
+ if (th and type(th) ~= 'number') or (v and type(v) ~= 'number') then
+ error('nn.Threshold(threshold, value)')
+ end
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
+ self:validateParameters()
+end
+
+function Threshold:updateOutput(input)
+ self:validateParameters()
+ input.THNN.Threshold_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.threshold,
+ self.val,
+ self.inplace
+ )
+ return self.output
+end
+
+function Threshold:updateGradInput(input, gradOutput)
+ self:validateParameters()
+ input.THNN.Threshold_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.threshold,
+ self.val,
+ self.inplace
+ )
+ return self.gradInput
+end
+
+function Threshold:validateParameters()
+ self.inplace = self.inplace or false -- backwards compatibility pre inplace
+ if self.inplace then
+ if self.val > self.threshold then
+ error('in-place processing requires value (' .. self.val ..
+ ') not exceed threshold (' .. self.threshold .. ')')
+ end
+ end
+end