aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/SplitTable.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/SplitTable.lua')
-rw-r--r--contrib/lua-torch/nn/SplitTable.lua43
1 files changed, 43 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/SplitTable.lua b/contrib/lua-torch/nn/SplitTable.lua
new file mode 100644
index 000000000..7c4f968e6
--- /dev/null
+++ b/contrib/lua-torch/nn/SplitTable.lua
@@ -0,0 +1,43 @@
+local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')
+
+function SplitTable:__init(dimension, nInputDims)
+ parent.__init(self)
+ self.dimension = dimension
+ self.nInputDims = nInputDims
+end
+
+function SplitTable:_getPositiveDimension(input)
+ local dimension = self.dimension
+ if dimension < 0 then
+ dimension = input:dim() + dimension + 1
+ elseif self.nInputDims and input:dim()==(self.nInputDims+1) then
+ dimension = dimension + 1
+ end
+ return dimension
+end
+
+function SplitTable:updateOutput(input)
+ local dimension = self:_getPositiveDimension(input)
+ local slices = input:size(dimension)
+
+ local currentOutput= {}
+ for i=1,slices do
+ currentOutput[#currentOutput+1] = input:select(dimension,i)
+ end
+ self.output = currentOutput
+ return self.output
+end
+
+function SplitTable:updateGradInput(input, gradOutput)
+ local dimension = self:_getPositiveDimension(input)
+ local slices = input:size(dimension)
+ if self.gradInput then
+ self.gradInput:resizeAs(input)
+
+ for i=1,slices do
+ local currentGradInput = gradOutput[i];
+ self.gradInput:select(dimension,i):copy(currentGradInput)
+ end
+ end
+ return self.gradInput
+end