aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/Convert.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/Convert.lua')
-rw-r--r--contrib/lua-torch/nn/Convert.lua245
1 files changed, 245 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Convert.lua b/contrib/lua-torch/nn/Convert.lua
new file mode 100644
index 000000000..855338dd6
--- /dev/null
+++ b/contrib/lua-torch/nn/Convert.lua
@@ -0,0 +1,245 @@
+------------------------------------------------------------------------
+--[ nn.Convert ]--
+-- Module to convert between different data formats
+-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f')
+-- Automatically converts input to same type as self.output
+-- Simplest use is for automatic input type converions : nn.Convert()
+------------------------------------------------------------------------
+local _ = require 'moses'
+local Convert, parent = torch.class("nn.Convert", "nn.Container")
+
+function Convert:__init(inputShape, outputShape)
+ if outputShape and not inputShape then
+ error"Expecting non-nil arg 1 when arg 2 is provided"
+ end
+ inputShape = inputShape or 'b*'
+ outputShape = outputShape or inputShape
+ self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape)
+ self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape)
+ self.inputBatchDim = self.inputShape:find('b')
+ self.outputBatchDim = self.outputShape:find('b')
+ if self.inputShape == 'b*' or self.outputShape == 'b*' then
+ assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*')
+ self.nInputDim = -1
+ self.nOutputDim = -1
+ self.transposition = true
+ else
+ -- number of dims in batch mode
+ self.nInputDim = #self.inputShape
+ self.nOutputDim = #self.outputShape
+ -- is the outputShape just a transposition of the inputShape?
+ if self.nInputDim == self.nOutputDim then
+ self.transposition = true
+ for i=1,self.nInputDim do
+ if not self.outputShape:find(self.inputShape:sub(i,i)) then
+ self.transposition = false
+ break
+ end
+ end
+ end
+ end
+ parent.__init(self)
+end
+
+-- post-initialization
+function Convert:buildConverter(input)
+ if self.transposition then
+ self.converter = self:transpose(self.outputShape)
+ else
+ if (torch.type(self[self.outputShape]) ~= 'function') then
+ error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape))
+ end
+ self.converter = self[self.outputShape](self, input)
+ end
+ assert(torch.isTensor(self.output), "Expecting Tensor output")
+
+ self.converter:type(torch.type(self.output))
+
+ self.modules[1] = self.converter
+end
+
+function Convert:updateOutput(input)
+ assert(torch.isTensor(input), "expecting Tensor")
+ if not torch.isTypeOf(input, torch.type(self.output)) then
+ -- handle different input type
+ self._input = self._input or self.output.new()
+ self._input:resize(input:size()):copy(input)
+ input = self._input
+ end
+ self.batchMode = true
+ if input:dim() < self.nInputDim then
+ -- handle non-batch mode
+ local inputSize = input:size():totable()
+ table.insert(inputSize, self.inputBatchDim, 1)
+ self.__input = self.__input or input.new()
+ self.__input:set(input):resize(table.unpack(inputSize))
+ input = self.__input
+ self.batchMode = false
+ end
+ if not self.converter then
+ self:buildConverter(input)
+ end
+
+ self.output = self.converter:updateOutput(input)
+
+ if not self.batchMode then
+ local outputSize = self.output:size():totable()
+ table.remove(outputSize, self.outputBatchDim)
+ self.__output = self.__output or self.output.new()
+ self.__output:set(self.output):resize(table.unpack(outputSize))
+ self.output = self.__output
+ end
+ return self.output
+end
+
+function Convert:updateGradInput(input, gradOutput)
+ local input_ = input
+ input = self._input or input
+ if not self.batchMode then
+ input = self.__input
+ self.__gradOutput = self.__gradOutput or gradOutput.new()
+ self.__gradOutput:set(gradOutput):resize(self.converter.output:size())
+ gradOutput = self.__gradOutput
+ end
+
+ local gradInput = self.converter:updateGradInput(input, gradOutput)
+
+ if not self.batchMode then
+ self.__gradInput = self.__gradInput or gradInput.new()
+ self.__gradInput:set(gradInput):resize(input_:size())
+ gradInput = self.__gradInput
+ end
+ if self._input then
+ self._gradInput = self._gradInput or input.new()
+ self._gradInput:resize(input:size()):copy(gradInput)
+ self.gradInput = self._gradInput
+ else
+ self.gradInput = gradInput
+ end
+
+ return self.gradInput
+end
+
+function Convert:accGradParameters(input, gradOutput, scale)
+ input = self.batchMode and self.__input or self._input or input
+ gradOutput = self.batchMode and self.__gradOutput or gradOutput
+ self.converter:accGradParameters(input, gradOutput, scale)
+end
+
+function Convert:accUpdateGradParameters(input, gradOutput, lr)
+ input = self.batchMode and self.__input or self._input or input
+ gradOutput = self.batchMode and self.__gradOutput or gradOutput
+ self.converter:accUpdateGradParameters(input, gradOutput, lr)
+end
+
+-- batch feature
+function Convert:bf(input)
+ local b_pos = self:findAxis('b', self.inputShape)
+ local dim = #self.inputShape
+ if self.inputShape == 'bt' then
+ error"Conversion of shape bt to bf not supported: open an issue on github"
+ end
+ -- was b
+ if dim == 1 then
+ return nn.Reshape(1)
+ end
+ -- was b...
+ local modula
+ if b_pos ~= 1 then
+ modula = nn.Transpose({1, b_pos})
+ end
+ if dim > 2 then
+ local transpose = modula
+ local sampleSize = input:select(self:findAxis('b'),1):nElement()
+ local reshape = nn.Reshape(sampleSize)
+ if transpose then
+ modula = nn.Sequential()
+ modula:add(transpose)
+ modula:add(reshape)
+ else
+ modula = reshape
+ end
+ end
+ return modula or nn.Identity()
+end
+
+-- each example is a scalar; batch is a vector
+function Convert:b(input)
+ local b_pos = self:findAxis('b')
+ if self.inputShape == 'bt' or self.inputShape == 'tb' then
+ local t_pos = self:findAxis('t')
+ -- select first set of classes
+ return nn.Select(t_pos, 1)
+ elseif self.inputShape == 'bf' or self.inputShape == 'fb' then
+ -- this wont work as expected with size(f) > 1
+ local f_pos = self:findAxis('f')
+ if input:size(f_pos) > 1 then
+ error("Cannot convert shape "..self.inputShape.." to b when feature > 1")
+ end
+ return nn.Select(f_pos, 1)
+ else
+ error("Cannot convert shape "..self.inputShape.." to shape b")
+ end
+end
+
+-- returns the current shape of the data
+function Convert:default()
+ return nn.Identity()
+end
+
+-- multi-class (batch target)
+function Convert:bt()
+ local b_pos = self:findAxis('b')
+ local modula
+ if self.inputShape == 'b' then
+ modula = nn.Reshape(1)
+ else
+ error("cannot convert shape '"..self.inputShape.."' to bt")
+ end
+ return modula
+end
+
+-- a generic function for transposing shape axes
+function Convert:transpose(newShape)
+ if newShape == self.inputShape then
+ return nn.Identity()
+ end
+ local inputShape = {}
+ for i=1,#self.inputShape do
+ table.insert(inputShape, self.inputShape:sub(i,i))
+ end
+ local transpositions = {}
+ for i=1,#newShape do
+ local j = _.indexOf(inputShape, newShape:sub(i,i))
+ if i ~= j then
+ local char = inputShape[i]
+ inputShape[i] = inputShape[j]
+ inputShape[j] = char
+ table.insert(transpositions, {j, i})
+ end
+ end
+ return nn.Transpose(table.unpack(transpositions))
+end
+
+function Convert:findAxis(axis_char, shape, silent)
+ shape = shape or self.inputShape
+ local axis_pos = shape:find(axis_char)
+ if (not silent) and (not axis_pos) then
+ error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2)
+ end
+ return axis_pos
+end
+
+function Convert:clearState()
+ self._input = nil
+ self._gradInput = nil
+ self.__input = nil
+ self.__output = nil
+ self.__gradInput = nil
+ self.__gradOutput = nil
+end
+
+function Convert:type(type)
+ self:clearState()
+ return parent.type(self, type)
+end