diff options
Diffstat (limited to 'contrib/lua-torch/nn/NarrowTable.lua')
-rw-r--r-- | contrib/lua-torch/nn/NarrowTable.lua | 43 |
1 file changed, 43 insertions, 0 deletions
local NarrowTable, parent = torch.class('nn.NarrowTable', 'nn.Module')

--- nn.NarrowTable(offset, length)
-- Table-to-table module: forwards the contiguous slice
-- input[offset .. offset+length-1] of an input table, and routes gradients
-- back to the matching positions (zero-filling the rest).
function NarrowTable:__init(offset, length)
   parent.__init(self)
   self.offset = offset
   self.length = length or 1   -- default: narrow to a single element
   if not offset then
      error('nn.NarrowTable(offset, length)')
   end

   -- output/gradInput are tables (of whatever the input table holds),
   -- not tensors as in the base class.
   self.output = {}
   self.gradInput = {}
end

function NarrowTable:updateOutput(input)
   -- Empty the recycled output table before refilling it, so entries from
   -- a previous (longer) forward pass cannot linger.
   for j = #self.output, 1, -1 do
      self.output[j] = nil
   end
   for j = 1, self.length do
      self.output[j] = input[self.offset + j - 1]
   end
   return self.output
end

function NarrowTable:updateGradInput(input, gradOutput)
   local first = self.offset
   local last = self.offset + self.length - 1

   -- Scatter the incoming gradients back to their original positions.
   for j = 1, #gradOutput do
      self.gradInput[first + j - 1] = gradOutput[j]
   end

   -- Positions outside the narrowed window received no gradient: give them
   -- zero tensors shaped like the corresponding inputs.
   for j = 1, #input do
      local in_window = (j >= first) and (j <= last)
      if not in_window then
         self.gradInput[j] = nn.utils.recursiveResizeAs(self.gradInput[j], input[j])
         nn.utils.recursiveFill(self.gradInput[j], 0)
      end
   end

   -- Trim stale entries left over from an earlier, larger input table.
   for j = #self.gradInput, #input + 1, -1 do
      self.gradInput[j] = nil
   end
   return self.gradInput
end

function NarrowTable:type(type, tensorCache)
   -- Discard the cached tables; they will be rebuilt with the new tensor
   -- type on the next forward/backward pass.
   self.output = {}
   self.gradInput = {}
   return parent.type(self, type, tensorCache)
end

-- Identity's clearState already does the right thing for table state.
NarrowTable.clearState = nn.Identity.clearState