summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SelectTable.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/SelectTable.lua')
-rw-r--r--contrib/lua-torch/nn/SelectTable.lua71
1 files changed, 71 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SelectTable.lua b/contrib/lua-torch/nn/SelectTable.lua
new file mode 100644
index 000000000..ef26f3507
--- /dev/null
+++ b/contrib/lua-torch/nn/SelectTable.lua
@@ -0,0 +1,71 @@
+local SelectTable, parent = torch.class('nn.SelectTable', 'nn.Module')
+
+function SelectTable:__init(index)
+ parent.__init(self)
+ self.index = index
+ self.gradInput = {}
+end
+
+function SelectTable:updateOutput(input)
+
+ -- handle negative indices
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
+
+ assert(input[index], "index does not exist in the input table")
+ self.output = input[index]
+
+ return self.output
+end
+
+local function zeroTableCopy(t1, t2)
+ for k, v in pairs(t2) do
+ if (torch.type(v) == "table") then
+ t1[k] = zeroTableCopy(t1[k] or {}, t2[k])
+ elseif torch.isTensor(v) then
+ if not t1[k] then
+ t1[k] = v:clone():zero()
+ else
+ t1[k]:resizeAs(v)
+ t1[k]:zero()
+ end
+ else
+ t1[k] = nil
+ end
+ end
+ for k, v in pairs(t1) do
+ if not t2[k] then
+ t1[k] = nil
+ end
+ end
+ return t1
+end
+
+function SelectTable:updateGradInput(input, gradOutput)
+ -- make gradInput a zeroed copy of input
+ zeroTableCopy(self.gradInput, input)
+ -- handle negative indices
+ local index = self.index
+ if type(index) == "number" then
+ index = index < 0 and #input + index + 1 or index
+ end
+ -- copy into gradInput[index] (necessary for variable sized inputs)
+ assert(self.gradInput[index])
+ nn.utils.recursiveCopy(self.gradInput[index], gradOutput)
+
+ return self.gradInput
+end
+
+function SelectTable:type(type, tensorCache)
+ self.gradInput = {}
+ self.output = {}
+ return parent.type(self, type, tensorCache)
+end
+
+function SelectTable:__tostring__()
+ return torch.type(self) .. '(' .. self.index .. ')'
+end
+
+SelectTable.clearState = nn.Identity.clearState