diff options
Diffstat (limited to 'contrib/lua-torch/nn/SelectTable.lua')
-rw-r--r-- | contrib/lua-torch/nn/SelectTable.lua | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SelectTable.lua b/contrib/lua-torch/nn/SelectTable.lua new file mode 100644 index 000000000..ef26f3507 --- /dev/null +++ b/contrib/lua-torch/nn/SelectTable.lua @@ -0,0 +1,71 @@ +local SelectTable, parent = torch.class('nn.SelectTable', 'nn.Module') + +function SelectTable:__init(index) + parent.__init(self) + self.index = index + self.gradInput = {} +end + +function SelectTable:updateOutput(input) + + -- handle negative indices + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end + + assert(input[index], "index does not exist in the input table") + self.output = input[index] + + return self.output +end + +local function zeroTableCopy(t1, t2) + for k, v in pairs(t2) do + if (torch.type(v) == "table") then + t1[k] = zeroTableCopy(t1[k] or {}, t2[k]) + elseif torch.isTensor(v) then + if not t1[k] then + t1[k] = v:clone():zero() + else + t1[k]:resizeAs(v) + t1[k]:zero() + end + else + t1[k] = nil + end + end + for k, v in pairs(t1) do + if not t2[k] then + t1[k] = nil + end + end + return t1 +end + +function SelectTable:updateGradInput(input, gradOutput) + -- make gradInput a zeroed copy of input + zeroTableCopy(self.gradInput, input) + -- handle negative indices + local index = self.index + if type(index) == "number" then + index = index < 0 and #input + index + 1 or index + end + -- copy into gradInput[index] (necessary for variable sized inputs) + assert(self.gradInput[index]) + nn.utils.recursiveCopy(self.gradInput[index], gradOutput) + + return self.gradInput +end + +function SelectTable:type(type, tensorCache) + self.gradInput = {} + self.output = {} + return parent.type(self, type, tensorCache) +end + +function SelectTable:__tostring__() + return torch.type(self) .. '(' .. self.index .. ')' +end + +SelectTable.clearState = nn.Identity.clearState |