diff options
Diffstat (limited to 'contrib/lua-torch/nn/Narrow.lua')
-rw-r--r-- | contrib/lua-torch/nn/Narrow.lua | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Narrow.lua b/contrib/lua-torch/nn/Narrow.lua new file mode 100644 index 000000000..a6ebaa321 --- /dev/null +++ b/contrib/lua-torch/nn/Narrow.lua @@ -0,0 +1,45 @@ +local Narrow, parent = torch.class('nn.Narrow', 'nn.Module') + +function Narrow:__init(dimension,offset,length) + parent.__init(self) + self.dimension=dimension + self.index=offset + self.length=length or 1 + if not dimension or not offset then + error('nn.Narrow(dimension, offset, length)') + end +end + +function Narrow:updateOutput(input) + local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension + local length = self.length + if length < 0 then + length = input:size(dim) - self.index + self.length + 2 + end + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length + end + local output=input:narrow(dim, index, length) + self.output = self.output:typeAs(output) + self.output:resizeAs(output):copy(output) + return self.output +end + +function Narrow:updateGradInput(input, gradOutput) + local dim = self.dimension < 0 and input:dim() + self.dimension + 1 or self.dimension + local length = self.length + if length < 0 then + length = input:size(dim) - self.index + self.length + 2 + end + local index = self.index + if self.index < 0 then + index = 1 + length = input:size(dim) - length + end + self.gradInput = self.gradInput:typeAs(input) + self.gradInput:resizeAs(input):zero() + self.gradInput:narrow(dim,index,length):copy(gradOutput) + return self.gradInput +end |