aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SpatialUpSamplingNearest.lua
blob: 362ae73a31e81be14d1ffdb7132773ff51f8eba1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
-- Register the class 'nn.SpatialUpSamplingNearest' with 'nn.Module' as its
-- base. `parent` is kept so the constructor can chain to the base __init.
local SpatialUpSamplingNearest, parent = torch.class('nn.SpatialUpSamplingNearest', 'nn.Module')

--[[
Applies a 2D up-sampling over an input image composed of several input planes.

The upsampling is done using the simple nearest neighbor technique.

The Y and X dimensions are assumed to be the last 2 tensor dimensions.  For
instance, if the tensor is 4D, then dim 3 is the y dimension and dim 4 is the x.

owidth  = width*scale_factor
oheight  = height*scale_factor
--]]

-- Constructor.
--
-- scale: the up-sampling factor applied to both spatial dimensions. It must
--        be a number, integer-valued, and >= 1 (a factor of exactly 1 is
--        accepted).
--
-- Raises an error when `scale` is not a number, is smaller than 1, or is not
-- integer-valued.
function SpatialUpSamplingNearest:__init(scale)
   parent.__init(self)

   -- Validate the type first so that a missing or non-numeric argument
   -- produces a readable message rather than a raw comparison error.
   if type(scale) ~= 'number' then
      error('scale_factor must be a number, got ' .. type(scale))
   end
   -- The bound is inclusive: only values strictly below 1 are rejected.
   if scale < 1 then
      error('scale_factor must be greater than or equal to 1')
   end
   if math.floor(scale) ~= scale then
      error('scale_factor must be integer')
   end

   self.scale_factor = scale
   -- Size buffers with room for four dimensions; updateOutput fills them
   -- with the sizes of the current input and of the scaled output.
   self.inputSize = torch.LongStorage(4)
   self.outputSize = torch.LongStorage(4)
end

-- Forward pass.
--
-- input: a 3D or 4D tensor; any other dimensionality raises an error.
--
-- Records the input sizes in self.inputSize and the scaled sizes in
-- self.outputSize (the last two dimensions multiplied by self.scale_factor),
-- then hands the input and self.output to the THNN backend routine.
-- Returns self.output.
function SpatialUpSamplingNearest:updateOutput(input)
   local ndim = input:dim()
   if not (ndim == 3 or ndim == 4) then
      error('SpatialUpSamplingNearest only support 3D or 4D tensors')
   end

   local inSize, outSize = self.inputSize, self.outputSize
   local factor = self.scale_factor
   -- The y and x axes are the last two dimensions; only those are scaled.
   local firstSpatial = ndim - 1
   for d = 1, ndim do
      local extent = input:size(d)
      inSize[d] = extent
      outSize[d] = (d >= firstSpatial) and (extent * factor) or extent
   end
   -- NOTE(review): for a 3D input only the first three slots of the size
   -- buffers are written; the fourth keeps whatever it held before.

   local forward = input.THNN.SpatialUpSamplingNearest_updateOutput
   forward(input:cdata(), self.output:cdata(), factor)
   return self.output
end

-- Backward pass.
--
-- input:      the tensor that was given to updateOutput.
-- gradOutput: gradient with respect to this module's output.
--
-- Resizes self.gradInput to the shape of `input`, lets the THNN backend
-- routine fill it, and returns self.gradInput.
function SpatialUpSamplingNearest:updateGradInput(input, gradOutput)
   local gradInput = self.gradInput
   gradInput:resizeAs(input)

   local backward = input.THNN.SpatialUpSamplingNearest_updateGradInput
   backward(input:cdata(), gradOutput:cdata(), gradInput:cdata(), self.scale_factor)

   return self.gradInput
end