diff options
Diffstat (limited to 'contrib/lua-torch/nn/Convert.lua')
-rw-r--r-- | contrib/lua-torch/nn/Convert.lua | 245 |
1 files changed, 245 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/Convert.lua b/contrib/lua-torch/nn/Convert.lua new file mode 100644 index 000000000..855338dd6 --- /dev/null +++ b/contrib/lua-torch/nn/Convert.lua @@ -0,0 +1,245 @@ +------------------------------------------------------------------------ +--[ nn.Convert ]-- +-- Module to convert between different data formats +-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f') +-- Automatically converts input to same type as self.output +-- Simplest use is for automatic input type converions : nn.Convert() +------------------------------------------------------------------------ +local _ = require 'moses' +local Convert, parent = torch.class("nn.Convert", "nn.Container") + +function Convert:__init(inputShape, outputShape) + if outputShape and not inputShape then + error"Expecting non-nil arg 1 when arg 2 is provided" + end + inputShape = inputShape or 'b*' + outputShape = outputShape or inputShape + self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape) + self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape) + self.inputBatchDim = self.inputShape:find('b') + self.outputBatchDim = self.outputShape:find('b') + if self.inputShape == 'b*' or self.outputShape == 'b*' then + assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*') + self.nInputDim = -1 + self.nOutputDim = -1 + self.transposition = true + else + -- number of dims in batch mode + self.nInputDim = #self.inputShape + self.nOutputDim = #self.outputShape + -- is the outputShape just a transposition of the inputShape? + if self.nInputDim == self.nOutputDim then + self.transposition = true + for i=1,self.nInputDim do + if not self.outputShape:find(self.inputShape:sub(i,i)) then + self.transposition = false + break + end + end + end + end + parent.__init(self) +end + +-- post-initialization +function Convert:buildConverter(input) + if self.transposition then + self.converter = self:transpose(self.outputShape) + else + if (torch.type(self[self.outputShape]) ~= 'function') then + error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape)) + end + self.converter = self[self.outputShape](self, input) + end + assert(torch.isTensor(self.output), "Expecting Tensor output") + + self.converter:type(torch.type(self.output)) + + self.modules[1] = self.converter +end + +function Convert:updateOutput(input) + assert(torch.isTensor(input), "expecting Tensor") + if not torch.isTypeOf(input, torch.type(self.output)) then + -- handle different input type + self._input = self._input or self.output.new() + self._input:resize(input:size()):copy(input) + input = self._input + end + self.batchMode = true + if input:dim() < self.nInputDim then + -- handle non-batch mode + local inputSize = input:size():totable() + table.insert(inputSize, self.inputBatchDim, 1) + self.__input = self.__input or input.new() + self.__input:set(input):resize(table.unpack(inputSize)) + input = self.__input + self.batchMode = false + end + if not self.converter then + self:buildConverter(input) + end + + self.output = self.converter:updateOutput(input) + + if not self.batchMode then + local outputSize = self.output:size():totable() + table.remove(outputSize, self.outputBatchDim) + self.__output = self.__output or self.output.new() + self.__output:set(self.output):resize(table.unpack(outputSize)) + self.output = self.__output + end + return self.output +end + +function Convert:updateGradInput(input, gradOutput) + local input_ = input + input = self._input or input + if not self.batchMode then + input = self.__input + self.__gradOutput = self.__gradOutput or gradOutput.new() + self.__gradOutput:set(gradOutput):resize(self.converter.output:size()) + gradOutput = self.__gradOutput + end + + local gradInput = self.converter:updateGradInput(input, gradOutput) + + if not self.batchMode then + self.__gradInput = self.__gradInput or gradInput.new() + self.__gradInput:set(gradInput):resize(input_:size()) + gradInput = self.__gradInput + end + if self._input then + self._gradInput = self._gradInput or input.new() + self._gradInput:resize(input:size()):copy(gradInput) + self.gradInput = self._gradInput + else + self.gradInput = gradInput + end + + return self.gradInput +end + +function Convert:accGradParameters(input, gradOutput, scale) + input = self.batchMode and self.__input or self._input or input + gradOutput = self.batchMode and self.__gradOutput or gradOutput + self.converter:accGradParameters(input, gradOutput, scale) +end + +function Convert:accUpdateGradParameters(input, gradOutput, lr) + input = self.batchMode and self.__input or self._input or input + gradOutput = self.batchMode and self.__gradOutput or gradOutput + self.converter:accUpdateGradParameters(input, gradOutput, lr) +end + +-- batch feature +function Convert:bf(input) + local b_pos = self:findAxis('b', self.inputShape) + local dim = #self.inputShape + if self.inputShape == 'bt' then + error"Conversion of shape bt to bf not supported: open an issue on github" + end + -- was b + if dim == 1 then + return nn.Reshape(1) + end + -- was b... + local modula + if b_pos ~= 1 then + modula = nn.Transpose({1, b_pos}) + end + if dim > 2 then + local transpose = modula + local sampleSize = input:select(self:findAxis('b'),1):nElement() + local reshape = nn.Reshape(sampleSize) + if transpose then + modula = nn.Sequential() + modula:add(transpose) + modula:add(reshape) + else + modula = reshape + end + end + return modula or nn.Identity() +end + +-- each example is a scalar; batch is a vector +function Convert:b(input) + local b_pos = self:findAxis('b') + if self.inputShape == 'bt' or self.inputShape == 'tb' then + local t_pos = self:findAxis('t') + -- select first set of classes + return nn.Select(t_pos, 1) + elseif self.inputShape == 'bf' or self.inputShape == 'fb' then + -- this wont work as expected with size(f) > 1 + local f_pos = self:findAxis('f') + if input:size(f_pos) > 1 then + error("Cannot convert shape "..self.inputShape.." to b when feature > 1") + end + return nn.Select(f_pos, 1) + else + error("Cannot convert shape "..self.inputShape.." to shape b") + end +end + +-- returns the current shape of the data +function Convert:default() + return nn.Identity() +end + +-- multi-class (batch target) +function Convert:bt() + local b_pos = self:findAxis('b') + local modula + if self.inputShape == 'b' then + modula = nn.Reshape(1) + else + error("cannot convert shape '"..self.inputShape.."' to bt") + end + return modula +end + +-- a generic function for transposing shape axes +function Convert:transpose(newShape) + if newShape == self.inputShape then + return nn.Identity() + end + local inputShape = {} + for i=1,#self.inputShape do + table.insert(inputShape, self.inputShape:sub(i,i)) + end + local transpositions = {} + for i=1,#newShape do + local j = _.indexOf(inputShape, newShape:sub(i,i)) + if i ~= j then + local char = inputShape[i] + inputShape[i] = inputShape[j] + inputShape[j] = char + table.insert(transpositions, {j, i}) + end + end + return nn.Transpose(table.unpack(transpositions)) +end + +function Convert:findAxis(axis_char, shape, silent) + shape = shape or self.inputShape + local axis_pos = shape:find(axis_char) + if (not silent) and (not axis_pos) then + error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2) + end + return axis_pos +end + +function Convert:clearState() + self._input = nil + self._gradInput = nil + self.__input = nil + self.__output = nil + self.__gradInput = nil + self.__gradOutput = nil +end + +function Convert:type(type) + self:clearState() + return parent.type(self, type) +end |