aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua')
-rw-r--r--contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua175
1 files changed, 175 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua b/contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua
new file mode 100644
index 000000000..f5ff58cf0
--- /dev/null
+++ b/contrib/lua-torch/nn/VolumetricFractionalMaxPooling.lua
@@ -0,0 +1,175 @@
+local VolumetricFractionalMaxPooling, parent =
+ torch.class('nn.VolumetricFractionalMaxPooling', 'nn.Module')
+
+-- Usage:
+-- nn.VolumetricFractionalMaxPooling(poolSizeT, poolSizeW, poolSizeH, outT, outW, outH)
+-- the output should be the exact size (outT x outH x outW)
+-- nn.VolumetricFractionalMaxPooling(poolSizeT, poolSizeW, poolSizeH, ratioT, ratioW, ratioH)
+-- the output should be the size (floor(inT x ratioT) x floor(inH x ratioH) x floor(inW x ratioW))
+-- ratios are numbers between (0, 1) exclusive
+function VolumetricFractionalMaxPooling:__init(poolSizeT, poolSizeW, poolSizeH, arg1, arg2, arg3)
+ parent.__init(self)
+ assert(poolSizeT >= 2)
+ assert(poolSizeW >= 2)
+ assert(poolSizeH >= 2)
+
+ -- Pool size (how wide the pooling for each output unit is)
+ self.poolSizeT = poolSizeT
+ self.poolSizeW = poolSizeW
+ self.poolSizeH = poolSizeH
+
+ -- Random samples are drawn for all
+ -- batch * plane * (time, height, width; i.e., 3) points. This determines
+ -- the 3d "pseudorandom" overlapping pooling regions for each
+ -- (batch element x input plane). A new set of random samples is
+ -- drawn every updateOutput call, unless we disable it via
+ -- :fixPoolingRegions().
+ self.randomSamples = nil
+
+ -- Flag to disable re-generation of random samples for producing
+ -- a new pooling. For testing purposes
+ self.newRandomPool = false
+
+ if arg1 >= 1 and arg2 >= 1 and arg3 >= 1 then
+ -- Desired output size: the input tensor will determine the reduction
+ -- ratio
+ self.outT = arg1
+ self.outW = arg2
+ self.outH = arg3
+ else
+ -- Reduction ratio specified per each input
+ -- This is the reduction ratio that we use
+ self.ratioT = arg1
+ self.ratioW = arg2
+ self.ratioH = arg3
+
+ -- The reduction ratio must be between 0 and 1
+ assert(self.ratioT > 0 and self.ratioT < 1)
+ assert(self.ratioW > 0 and self.ratioW < 1)
+ assert(self.ratioH > 0 and self.ratioH < 1)
+ end
+end
+
+function VolumetricFractionalMaxPooling:getBufferSize_(input)
+ local batchSize = 0
+ local planeSize = 0
+
+ if input:nDimension() == 4 then
+ batchSize = 1
+ planeSize = input:size(1)
+ elseif input:nDimension() == 5 then
+ batchSize = input:size(1)
+ planeSize = input:size(2)
+ else
+ error('input must be dim 4 or 5')
+ end
+
+ return torch.LongStorage({batchSize, planeSize, 3})
+end
+
+function VolumetricFractionalMaxPooling:initSampleBuffer_(input)
+ local sampleBufferSize = self:getBufferSize_(input)
+
+ if self.randomSamples == nil then
+ self.randomSamples = input.new():resize(sampleBufferSize):uniform()
+ elseif (self.randomSamples:size(1) ~= sampleBufferSize[1] or
+ self.randomSamples:size(2) ~= sampleBufferSize[2]) then
+ self.randomSamples:resize(sampleBufferSize):uniform()
+ else
+ if not self.newRandomPool then
+ -- Create new pooling windows, since this is a subsequent call
+ self.randomSamples:uniform()
+ end
+ end
+end
+
+function VolumetricFractionalMaxPooling:getOutputSizes_(input)
+ local outT = self.outT
+ local outW = self.outW
+ local outH = self.outH
+ if self.ratioW ~= nil and self.ratioH ~= nil then
+ if input:nDimension() == 5 then
+ outT = math.floor(input:size(5) * self.ratioT)
+ outW = math.floor(input:size(4) * self.ratioW)
+ outH = math.floor(input:size(3) * self.ratioH)
+ elseif input:nDimension() == 4 then
+ outT = math.floor(input:size(4) * self.ratioT)
+ outW = math.floor(input:size(3) * self.ratioW)
+ outH = math.floor(input:size(2) * self.ratioH)
+ else
+ error('input must be dim 4 or 5')
+ end
+
+ -- Neither can be smaller than 1
+ assert(outT > 0, 'reduction ratio or input time too small')
+ assert(outW > 0, 'reduction ratio or input width too small')
+ assert(outH > 0, 'reduction ratio or input height too small')
+ else
+ assert(outT ~= nil and outW ~= nil and outH ~= nil)
+ end
+
+ return outT, outW, outH
+end
+
+-- Call this to turn off regeneration of random pooling regions each
+-- updateOutput call.
+function VolumetricFractionalMaxPooling:fixPoolingRegions(val)
+ if val == nil then
+ val = true
+ end
+
+ self.newRandomPool = val
+ return self
+end
+
+function VolumetricFractionalMaxPooling:updateOutput(input)
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
+ self:initSampleBuffer_(input)
+ local outT, outW, outH = self:getOutputSizes_(input)
+
+ input.THNN.VolumetricFractionalMaxPooling_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ outT, outW, outH, self.poolSizeT, self.poolSizeW, self.poolSizeH,
+ self.indices:cdata(), self.randomSamples:cdata())
+ return self.output
+end
+
+function VolumetricFractionalMaxPooling:updateGradInput(input, gradOutput)
+ assert(self.randomSamples ~= nil,
+ 'must call updateOutput/forward first')
+
+ local outT, outW, outH = self:getOutputSizes_(input)
+
+ input.THNN.VolumetricFractionalMaxPooling_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ outT, outW, outH, self.poolSizeT, self.poolSizeW, self.poolSizeH,
+ self.indices:cdata())
+ return self.gradInput
+end
+
+-- backward compat
+function VolumetricFractionalMaxPooling:empty()
+ self:clearState()
+end
+
+function VolumetricFractionalMaxPooling:clearState()
+ self.indices = nil
+ self.randomSamples = nil
+ return parent.clearState(self)
+end
+
+function VolumetricFractionalMaxPooling:__tostring__()
+ return string.format('%s(%dx%dx%d, %d,%d,%d)', torch.type(self),
+ self.outT and self.outT or self.ratioT,
+ self.outW and self.outW or self.ratioW,
+ self.outH and self.outH or self.ratioH,
+ self.poolSizeT, self.poolSizeW, self.poolSizeH)
+end