summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SpatialDilatedConvolution.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/SpatialDilatedConvolution.lua')
-rw-r--r--contrib/lua-torch/nn/SpatialDilatedConvolution.lua80
1 files changed, 80 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SpatialDilatedConvolution.lua b/contrib/lua-torch/nn/SpatialDilatedConvolution.lua
new file mode 100644
index 000000000..a0590c7e9
--- /dev/null
+++ b/contrib/lua-torch/nn/SpatialDilatedConvolution.lua
@@ -0,0 +1,80 @@
+local THNN = require 'nn.THNN'
+local SpatialDilatedConvolution, parent = torch.class('nn.SpatialDilatedConvolution', 'nn.SpatialConvolution')
+
+function SpatialDilatedConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+ parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+
+ self.dilationW = dilationW or 1
+ self.dilationH = dilationH or 1
+end
+
+function SpatialDilatedConvolution:updateOutput(input)
+ self.finput = self.finput or self.weight.new()
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input.THNN.SpatialDilatedConvolution_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.weight:cdata(),
+ THNN.optionalTensor(self.bias),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.dilationW, self.dilationH
+ )
+ return self.output
+end
+
+function SpatialDilatedConvolution:updateGradInput(input, gradOutput)
+ if self.gradInput then
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input.THNN.SpatialDilatedConvolution_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.weight:cdata(),
+ self.finput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.dilationW, self.dilationH
+ )
+ return self.gradInput
+ end
+end
+
+function SpatialDilatedConvolution:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
+ self.fgradInput = self.fgradInput or self.weight.new()
+ input.THNN.SpatialDilatedConvolution_accGradParameters(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradWeight:cdata(),
+ THNN.optionalTensor(self.gradBias),
+ self.finput:cdata(),
+ self.fgradInput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.dilationW, self.dilationH,
+ scale
+ )
+end
+
+function SpatialDilatedConvolution:__tostring__()
+ local s = string.format('%s(%d -> %d, %dx%d', torch.type(self),
+ self.nInputPlane, self.nOutputPlane, self.kW, self.kH)
+ if self.dW ~= 1 or self.dH ~= 1 or self.padW ~= 0 or self.padH ~= 0 then
+ s = s .. string.format(', %d,%d', self.dW, self.dH)
+ end
+ if (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
+ s = s .. ', ' .. self.padW .. ',' .. self.padH
+ end
+ s = s .. ', ' .. self.dilationW .. ',' .. self.dilationH
+ if self.bias then
+ return s .. ')'
+ else
+ return s .. ') without bias'
+ end
+end