aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/SelectTable.lua
blob: ef26f3507bdf298c006b8d65fba059b0dbf8c2dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
local SelectTable, parent = torch.class('nn.SelectTable', 'nn.Module')

--- Create a module whose forward pass returns one entry of its input table.
-- @param index key of the entry to select; a negative number counts back
--   from the end of the input table (-1 selects the last element)
function SelectTable:__init(index)
   parent.__init(self)
   -- gradInput is filled as a mirror of the (table) input during the
   -- backward pass, so it starts out as an empty table
   self.gradInput = {}
   self.index = index
end

--- Forward pass: return the entry of `input` selected by self.index.
-- A negative numeric index counts back from the end of `input`.
-- @param input table whose entry is selected
-- @return the selected entry (also stored in self.output, by reference)
function SelectTable:updateOutput(input)
   local idx = self.index
   if type(idx) == "number" and idx < 0 then
      idx = #input + idx + 1
   end

   local selected = input[idx]
   assert(selected, "index does not exist in the input table")
   self.output = selected
   return self.output
end

--- Make `t1` a zero-filled structural copy of `t2`, reusing storage in `t1`
-- where possible.
-- Nested tables in `t2` are mirrored recursively. Every tensor leaf of `t2`
-- gets a zeroed tensor of the same size in `t1`. That tensor is reused when
-- `t1` already holds a tensor of the same torch type at that key, and is
-- freshly cloned otherwise. Non-tensor, non-table leaves of `t2`, and keys
-- absent from `t2`, are removed from `t1`.
-- @param t1 destination table, modified in place
-- @param t2 source (nested) table of tensors
-- @return t1
local function zeroTableCopy(t1, t2)
   for k, v in pairs(t2) do
      local old = t1[k]
      if torch.type(v) == "table" then
         -- Only recurse into an existing plain table. Anything else (e.g. a
         -- tensor left over from an earlier, differently structured input)
         -- is replaced, since it cannot be walked as a table.
         if torch.type(old) ~= "table" then
            old = {}
         end
         t1[k] = zeroTableCopy(old, v)
      elseif torch.isTensor(v) then
         -- Reuse the existing buffer only if it is a tensor of the same type;
         -- a plain table or a tensor of another type cannot be resizeAs'ed.
         if torch.type(old) == torch.type(v) then
            old:resizeAs(v)
            old:zero()
         else
            t1[k] = v:clone():zero()
         end
      else
         t1[k] = nil
      end
   end
   -- Drop stale entries whose key no longer exists in t2 (clearing existing
   -- keys during a pairs() traversal is allowed).
   for k in pairs(t1) do
      if t2[k] == nil then
         t1[k] = nil
      end
   end
   return t1
end

--- Backward pass: route gradOutput to the slot selected in updateOutput.
-- gradInput is rebuilt as a zero-filled mirror of `input`, so every slot
-- other than the selected one receives a zero gradient. gradOutput is then
-- copied into the selected slot.
-- @param input the (nested) table of tensors passed to updateOutput
-- @param gradOutput gradient with respect to this module's output
-- @return self.gradInput
function SelectTable:updateGradInput(input, gradOutput)
   -- make gradInput a zeroed copy of input
   zeroTableCopy(self.gradInput, input)

   -- handle negative indices exactly as updateOutput does
   local index = self.index
   if type(index) == "number" and index < 0 then
      index = #input + index + 1
   end

   -- Copy (rather than alias) gradOutput into gradInput[index]; this is
   -- necessary for variable sized inputs.
   assert(self.gradInput[index], "index does not exist in the input table")
   -- recursiveCopy returns the destination it filled, which may be a new
   -- container if the existing one was of a different kind; keep it.
   self.gradInput[index] = nn.utils.recursiveCopy(self.gradInput[index], gradOutput)

   return self.gradInput
end

--- Convert the module to another tensor type.
-- Both cached buffers are reset to empty tables before delegating to
-- nn.Module's implementation. self.output is only a reference into the
-- input table (see updateOutput), and gradInput is refilled on every
-- backward pass.
-- NOTE(review): the reset is presumably so that parent.type does not
-- convert tensors this module does not own; confirm.
-- @param dstType target tensor type name
-- @param tensorCache optional cache forwarded to parent.type
function SelectTable:type(dstType, tensorCache)
   self.output, self.gradInput = {}, {}
   return parent.type(self, dstType, tensorCache)
end

--- Human-readable description, e.g. "nn.SelectTable(2)".
-- tostring() is applied to the index so that keys which are neither strings
-- nor numbers (e.g. a boolean key) are printed instead of raising an error.
-- For string and number indices the output is unchanged.
function SelectTable:__tostring__()
   return torch.type(self) .. '(' .. tostring(self.index) .. ')'
end

-- Reuse nn.Identity's clearState instead of the one inherited from
-- nn.Module. As with nn.Identity, this module's output is a reference to part
-- of its input (self.output = input[index] in updateOutput) rather than a
-- buffer it owns.
-- NOTE(review): presumably Identity's version drops that reference without
-- resizing/zeroing the shared tensor; confirm against nn/Identity.lua.
SelectTable.clearState = nn.Identity.clearState