diff options
Diffstat (limited to 'contrib/lua-torch/nn/VolumetricMaxPooling.lua')
-rw-r--r-- | contrib/lua-torch/nn/VolumetricMaxPooling.lua | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/VolumetricMaxPooling.lua b/contrib/lua-torch/nn/VolumetricMaxPooling.lua new file mode 100644 index 000000000..e25c5b31c --- /dev/null +++ b/contrib/lua-torch/nn/VolumetricMaxPooling.lua @@ -0,0 +1,102 @@ +local VolumetricMaxPooling, parent = torch.class('nn.VolumetricMaxPooling', 'nn.Module') + +VolumetricMaxPooling.__version = 2 + +function VolumetricMaxPooling:__init(kT, kW, kH, dT, dW, dH, padT, padW, padH) + parent.__init(self) + + dT = dT or kT + dW = dW or kW + dH = dH or kH + + self.kT = kT + self.kH = kH + self.kW = kW + self.dT = dT + self.dW = dW + self.dH = dH + + self.padT = padT or 0 + self.padW = padW or 0 + self.padH = padH or 0 + + + self.ceil_mode = false + self.indices = torch.LongTensor() +end + +function VolumetricMaxPooling:ceil() + self.ceil_mode = true + return self +end + +function VolumetricMaxPooling:floor() + self.ceil_mode = false + return self +end + +function VolumetricMaxPooling:updateOutput(input) + local dims = input:dim() + self.itime = input:size(dims-2) + self.iheight = input:size(dims-1) + self.iwidth = input:size(dims) + + self.indices = self.indices or torch.LongTensor() + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices + else + self.indices = self.indices:long() + end + input.THNN.VolumetricMaxPooling_updateOutput( + input:cdata(), + self.output:cdata(), + self.indices:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.ceil_mode + ) + return self.output +end + +function VolumetricMaxPooling:updateGradInput(input, gradOutput) + input.THNN.VolumetricMaxPooling_updateGradInput( + input:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + self.indices:cdata(), + self.kT, self.kW, self.kH, + self.dT, self.dW, self.dH, + self.padT, self.padW, self.padH, + self.ceil_mode + ) + return self.gradInput +end + +function VolumetricMaxPooling:empty() + self:clearState() +end + +function VolumetricMaxPooling:clearState() + if self.indices then self.indices:set() end + return parent.clearState(self) +end + +function VolumetricMaxPooling:read(file, version) + parent.read(self, file) + if version < 2 then + self.ceil_mode = false + end +end + +function VolumetricMaxPooling:__tostring__() + local s = string.format('%s(%dx%dx%d, %d,%d,%d', torch.type(self), + self.kT, self.kW, self.kH, self.dT, self.dW, self.dH) + if (self.padT or self.padW or self.padH) and + (self.padT ~= 0 or self.padW ~= 0 or self.padH ~= 0) then + s = s .. ', ' .. self.padT.. ',' .. self.padW .. ','.. self.padH + end + s = s .. ')' + + return s +end |