summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SpatialCrossMapLRN.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/SpatialCrossMapLRN.lua')
-rw-r--r--contrib/lua-torch/nn/SpatialCrossMapLRN.lua153
1 files changed, 153 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SpatialCrossMapLRN.lua b/contrib/lua-torch/nn/SpatialCrossMapLRN.lua
new file mode 100644
index 000000000..088eb07f0
--- /dev/null
+++ b/contrib/lua-torch/nn/SpatialCrossMapLRN.lua
@@ -0,0 +1,153 @@
+local SpatialCrossMapLRN, parent = torch.class('nn.SpatialCrossMapLRN', 'nn.Module')
+
+function SpatialCrossMapLRN:__init(size, alpha, beta, k)
+ parent.__init(self)
+
+ self.size = size
+ self.alpha = alpha or 0.0001
+ self.beta = beta or 0.75
+ self.k = k or 1
+end
+
+function SpatialCrossMapLRN:updateOutput(input)
+ assert(input:dim() == 3 or input:dim() == 4,
+ 'Input must be 3D or 4D')
+
+ self.scale = self.scale or input.new()
+
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ input.THNN.SpatialCrossMapLRN_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.scale:cdata(),
+ self.size,
+ self.alpha,
+ self.beta,
+ self.k
+ )
+ else
+ local isBatch = true
+ if input:dim() == 3 then
+ input = nn.utils.addSingletonDimension(input)
+ isBatch = false
+ end
+
+ local batchSize = input:size(1)
+ local channels = input:size(2)
+ local inputHeight = input:size(3)
+ local inputWidth = input:size(4)
+
+ self.output:resizeAs(input)
+ self.scale:resizeAs(input)
+
+ -- use output storage as temporary buffer
+ local inputSquare = self.output
+ inputSquare:pow(input, 2)
+
+ local prePad = (self.size - 1)/2 + 1
+ local prePadCrop = prePad > channels and channels or prePad
+
+ local scaleFirst = self.scale:select(2,1)
+ scaleFirst:zero()
+ -- compute first feature map normalization
+ for c = 1, prePadCrop do
+ scaleFirst:add(inputSquare:select(2, c))
+ end
+
+ -- reuse computations for next feature maps normalization
+ -- by adding the next feature map and removing the previous
+ for c = 2, channels do
+ local scalePrevious = self.scale:select(2, c -1)
+ local scaleCurrent = self.scale:select(2, c)
+ scaleCurrent:copy(scalePrevious)
+ if c < channels - prePad + 2 then
+ local squareNext = inputSquare:select(2, c + prePad - 1)
+ scaleCurrent:add(1, squareNext)
+ end
+ if c > prePad then
+ local squarePrevious = inputSquare:select(2, c - prePad )
+ scaleCurrent:add(-1, squarePrevious)
+ end
+ end
+
+ self.scale:mul(self.alpha/self.size):add(self.k)
+
+ self.output:pow(self.scale,-self.beta)
+ self.output:cmul(input)
+
+ if not isBatch then
+ self.output = self.output[1]
+ end
+ end
+
+ return self.output
+end
+
+function SpatialCrossMapLRN:updateGradInput(input, gradOutput)
+ assert(input:dim() == 3 or input:dim() == 4,
+ 'Input must be 3D or 4D')
+
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ input.THNN.SpatialCrossMapLRN_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.scale:cdata(),
+ self.output:cdata(),
+ self.size,
+ self.alpha,
+ self.beta,
+ self.k
+ )
+ else
+ local isBatch = true
+ if input:dim() == 3 then
+ input = nn.utils.addSingletonDimension(input)
+ gradOutput = nn.utils.addSingletonDimension(gradOutput)
+ self.output = nn.utils.addSingletonDimension(self.output)
+ isBatch = false
+ end
+
+ local batchSize = input:size(1)
+ local channels = input:size(2)
+ local inputHeight = input:size(3)
+ local inputWidth = input:size(4)
+
+ self.paddedRatio = self.paddedRatio or input.new()
+ self.accumRatio = self.accumRatio or input.new()
+ self.paddedRatio:resize(channels + self.size - 1, inputHeight, inputWidth)
+ self.accumRatio:resize(inputHeight,inputWidth)
+
+ local cacheRatioValue = 2*self.alpha*self.beta/self.size
+ local inversePrePad = self.size - (self.size - 1) / 2
+
+ self.gradInput:resizeAs(input)
+ self.gradInput:pow(self.scale,-self.beta):cmul(gradOutput)
+
+ self.paddedRatio:zero()
+ local paddedRatioCenter = self.paddedRatio:narrow(1, inversePrePad, channels)
+ for n = 1, batchSize do
+ paddedRatioCenter:cmul(gradOutput[n],self.output[n])
+ paddedRatioCenter:cdiv(self.scale[n])
+ self.accumRatio:sum(self.paddedRatio:narrow(1,1,self.size-1), 1)
+ for c = 1, channels do
+ self.accumRatio:add(self.paddedRatio[c+self.size-1])
+ self.gradInput[n][c]:addcmul(-cacheRatioValue, input[n][c], self.accumRatio)
+ self.accumRatio:add(-1, self.paddedRatio[c])
+ end
+ end
+
+ if not isBatch then
+ self.gradInput = self.gradInput[1]
+ self.output = self.output[1]
+ end
+ end
+
+ return self.gradInput
+end
+
+
+function SpatialCrossMapLRN:clearState()
+ nn.utils.clear(self, 'scale', 'paddedRatio', 'accumRatio')
+ return parent.clearState(self)
+end