aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Collapse.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/Collapse.lua')
-rw-r--r--contrib/lua-torch/nn/Collapse.lua30
1 files changed, 30 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Collapse.lua b/contrib/lua-torch/nn/Collapse.lua
new file mode 100644
index 000000000..a088608ca
--- /dev/null
+++ b/contrib/lua-torch/nn/Collapse.lua
@@ -0,0 +1,30 @@
+local Collapse, parent = torch.class('nn.Collapse', 'nn.Module')
+
+-- collapses non-batch dims
+function Collapse:__init(nInputDim)
+ parent.__init(self)
+ self.nInputDim = nInputDim
+end
+
+function Collapse:updateOutput(input)
+ if not input:isContiguous() then
+ self._input = self._input or input.new()
+ self._input:resize(input:size()):copy(input)
+ input = self._input
+ end
+ if input:dim() > self.nInputDim then
+ self.output:view(input,input:size(1),-1)
+ else
+ self.output:view(input,-1)
+ end
+ return self.output
+end
+
+function Collapse:updateGradInput(input, gradOutput)
+ self.gradInput:view(gradOutput, input:size())
+ return self.gradInput
+end
+
+function Collapse:clearState()
+ self._input = nil
+end