summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/AddConstant.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/AddConstant.lua')
-rw-r--r--contrib/lua-torch/nn/AddConstant.lua50
1 files changed, 50 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/AddConstant.lua b/contrib/lua-torch/nn/AddConstant.lua
new file mode 100644
index 000000000..b686d719c
--- /dev/null
+++ b/contrib/lua-torch/nn/AddConstant.lua
@@ -0,0 +1,50 @@
+local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module')
+
+function AddConstant:__init(constant_scalar,ip)
+ parent.__init(self)
+ self.constant_scalar = constant_scalar
+
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
+end
+
+function AddConstant:updateOutput(input)
+ assert(type(self.constant_scalar) == 'number' or
+ (torch.isTensor(self.constant_scalar) and input:nDimension() <= 2 and
+ input:size(input:nDimension()) == self.constant_scalar:size(1)),
+ 'input is not scalar or doesn\'t match with the dimension of constant!')
+ local tmp
+ if torch.isTensor(self.constant_scalar) and input:nDimension() == 2 then
+ local nOutput = self.constant_scalar:size(1)
+ tmp = self.constant_scalar.new()
+ tmp:resize(1,nOutput)
+ tmp:copy(self.constant_scalar)
+ tmp = tmp:expand(input:size(1),nOutput)
+ else
+ tmp = self.constant_scalar
+ end
+ if self.inplace then
+ input:add(tmp)
+ self.output:set(input)
+ else
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:add(tmp)
+ end
+ return self.output
+end
+
+function AddConstant:updateGradInput(input, gradOutput)
+ if self.inplace then
+ self.gradInput:set(gradOutput)
+ -- restore previous input value
+ input:add(-self.constant_scalar)
+ else
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ end
+ return self.gradInput
+end