diff options
Diffstat (limited to 'contrib/lua-torch/nn/Transpose.lua')
-rw-r--r-- | contrib/lua-torch/nn/Transpose.lua | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Transpose.lua b/contrib/lua-torch/nn/Transpose.lua new file mode 100644 index 000000000..cceb2b643 --- /dev/null +++ b/contrib/lua-torch/nn/Transpose.lua @@ -0,0 +1,35 @@ +local Transpose, parent = torch.class('nn.Transpose', 'nn.Module') + +-- transpose dimensions: +-- n = nn.Transpose({1,4},{1,3}) +-- will transpose dims 1 and 4, then 1 and 3... + +function Transpose:__init(...) + parent.__init(self) + self.permutations = {...} + self.numInputDims = nil +end + +function Transpose:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self +end + +function Transpose:updateOutput(input) + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 + for _,perm in ipairs(self.permutations) do + input = input:transpose(perm[1]+offset,perm[2]+offset) + end + self.output:resizeAs(input):copy(input) + return self.output +end + +function Transpose:updateGradInput(input, gradOutput) + for i = #self.permutations,1,-1 do + local perm = self.permutations[i] + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 + gradOutput = gradOutput:transpose(perm[1]+offset,perm[2]+offset) + end + self.gradInput:resizeAs(gradOutput):copy(gradOutput) + return self.gradInput +end |