Diffstat (limited to 'contrib/lua-torch/nn/SpatialCrossMapLRN.lua')
-rw-r--r-- | contrib/lua-torch/nn/SpatialCrossMapLRN.lua | 153
1 file changed, 153 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SpatialCrossMapLRN.lua b/contrib/lua-torch/nn/SpatialCrossMapLRN.lua
new file mode 100644
index 000000000..088eb07f0
--- /dev/null
+++ b/contrib/lua-torch/nn/SpatialCrossMapLRN.lua
@@ -0,0 +1,153 @@
+local SpatialCrossMapLRN, parent = torch.class('nn.SpatialCrossMapLRN', 'nn.Module')
+
+function SpatialCrossMapLRN:__init(size, alpha, beta, k)
+  parent.__init(self)
+
+  self.size = size
+  self.alpha = alpha or 0.0001
+  self.beta = beta or 0.75
+  self.k = k or 1
+end
+
+function SpatialCrossMapLRN:updateOutput(input)
+  assert(input:dim() == 3 or input:dim() == 4,
+         'Input must be 3D or 4D')
+
+  self.scale = self.scale or input.new()
+
+  if torch.typename(input):find('torch%.Cuda.*Tensor') then
+    input.THNN.SpatialCrossMapLRN_updateOutput(
+      input:cdata(),
+      self.output:cdata(),
+      self.scale:cdata(),
+      self.size,
+      self.alpha,
+      self.beta,
+      self.k
+    )
+  else
+    local isBatch = true
+    if input:dim() == 3 then
+      input = nn.utils.addSingletonDimension(input)
+      isBatch = false
+    end
+
+    local batchSize = input:size(1)
+    local channels = input:size(2)
+    local inputHeight = input:size(3)
+    local inputWidth = input:size(4)
+
+    self.output:resizeAs(input)
+    self.scale:resizeAs(input)
+
+    -- use output storage as temporary buffer
+    local inputSquare = self.output
+    inputSquare:pow(input, 2)
+
+    local prePad = (self.size - 1)/2 + 1
+    local prePadCrop = prePad > channels and channels or prePad
+
+    local scaleFirst = self.scale:select(2, 1)
+    scaleFirst:zero()
+    -- compute first feature map normalization
+    for c = 1, prePadCrop do
+      scaleFirst:add(inputSquare:select(2, c))
+    end
+
+    -- reuse computations for next feature maps normalization
+    -- by adding the next feature map and removing the previous
+    for c = 2, channels do
+      local scalePrevious = self.scale:select(2, c - 1)
+      local scaleCurrent = self.scale:select(2, c)
+      scaleCurrent:copy(scalePrevious)
+      if c < channels - prePad + 2 then
+        local squareNext = inputSquare:select(2, c + prePad - 1)
+        scaleCurrent:add(1, squareNext)
+      end
+      if c > prePad then
+        local squarePrevious = inputSquare:select(2, c - prePad)
+        scaleCurrent:add(-1, squarePrevious)
+      end
+    end
+
+    self.scale:mul(self.alpha/self.size):add(self.k)
+
+    self.output:pow(self.scale, -self.beta)
+    self.output:cmul(input)
+
+    if not isBatch then
+      self.output = self.output[1]
+    end
+  end
+
+  return self.output
+end
+
+function SpatialCrossMapLRN:updateGradInput(input, gradOutput)
+  assert(input:dim() == 3 or input:dim() == 4,
+         'Input must be 3D or 4D')
+
+  if torch.typename(input):find('torch%.Cuda.*Tensor') then
+    input.THNN.SpatialCrossMapLRN_updateGradInput(
+      input:cdata(),
+      gradOutput:cdata(),
+      self.gradInput:cdata(),
+      self.scale:cdata(),
+      self.output:cdata(),
+      self.size,
+      self.alpha,
+      self.beta,
+      self.k
+    )
+  else
+    local isBatch = true
+    if input:dim() == 3 then
+      input = nn.utils.addSingletonDimension(input)
+      gradOutput = nn.utils.addSingletonDimension(gradOutput)
+      self.output = nn.utils.addSingletonDimension(self.output)
+      isBatch = false
+    end
+
+    local batchSize = input:size(1)
+    local channels = input:size(2)
+    local inputHeight = input:size(3)
+    local inputWidth = input:size(4)
+
+    self.paddedRatio = self.paddedRatio or input.new()
+    self.accumRatio = self.accumRatio or input.new()
+    self.paddedRatio:resize(channels + self.size - 1, inputHeight, inputWidth)
+    self.accumRatio:resize(inputHeight, inputWidth)
+
+    local cacheRatioValue = 2*self.alpha*self.beta/self.size
+    local inversePrePad = self.size - (self.size - 1) / 2
+
+    self.gradInput:resizeAs(input)
+    self.gradInput:pow(self.scale, -self.beta):cmul(gradOutput)
+
+    self.paddedRatio:zero()
+    local paddedRatioCenter = self.paddedRatio:narrow(1, inversePrePad, channels)
+    for n = 1, batchSize do
+      paddedRatioCenter:cmul(gradOutput[n], self.output[n])
+      paddedRatioCenter:cdiv(self.scale[n])
+      self.accumRatio:sum(self.paddedRatio:narrow(1, 1, self.size - 1), 1)
+      for c = 1, channels do
+        self.accumRatio:add(self.paddedRatio[c + self.size - 1])
+        self.gradInput[n][c]:addcmul(-cacheRatioValue, input[n][c], self.accumRatio)
+        self.accumRatio:add(-1, self.paddedRatio[c])
+      end
+    end
+
+    if not isBatch then
+      self.gradInput = self.gradInput[1]
+      self.output = self.output[1]
+    end
+  end
+
+  return self.gradInput
+end
+
+
+function SpatialCrossMapLRN:clearState()
+  nn.utils.clear(self, 'scale', 'paddedRatio', 'accumRatio')
+  return parent.clearState(self)
+end
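For context, a minimal usage sketch of the module added above, assuming a standard Lua Torch environment where the 'torch' and 'nn' packages are installed; the window size and tensor shapes below are illustrative only and are not taken from the diff.

-- Hypothetical usage example (not part of the commit above).
require 'nn'

-- size=5, alpha=0.0001, beta=0.75, k=1: normalize each activation across a
-- window of 5 neighbouring feature maps. These values match the constructor
-- defaults in __init, so nn.SpatialCrossMapLRN(5) would be equivalent.
local lrn = nn.SpatialCrossMapLRN(5, 0.0001, 0.75, 1)

-- updateOutput accepts 3D (channels x height x width) or 4D
-- (batch x channels x height x width) inputs, per the asserts above.
local x = torch.rand(2, 8, 16, 16)
local y = lrn:forward(x)                            -- same shape as x

-- backward reuses the 'scale' buffer cached by the preceding forward pass.
local gradInput = lrn:backward(x, torch.rand(2, 8, 16, 16))

On the CPU path the module computes scale = k + (alpha/size) * sum of squared activations over the cross-map window, then output = input * scale^(-beta), i.e. the usual cross-map local response normalization; the CUDA path delegates the same computation to THNN.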