Diffstat (limited to 'contrib/lua-torch/nn/IndexLinear.lua')
-rw-r--r-- | contrib/lua-torch/nn/IndexLinear.lua | 398 |
1 file changed, 398 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/IndexLinear.lua b/contrib/lua-torch/nn/IndexLinear.lua
new file mode 100644
index 000000000..928e5d3f2
--- /dev/null
+++ b/contrib/lua-torch/nn/IndexLinear.lua
@@ -0,0 +1,398 @@
+local ffi = require 'ffi'
+local IndexLinear, parent = torch.class('nn.IndexLinear', 'nn.Module')
+
+function IndexLinear:__init(inputSize, outputSize, doGradInput, keysOffset, weight, bias, normalize)
+   parent.__init(self)
+
+   -- We need 3 extra parameters per feature
+   -- if we normalize:
+   -- * the max-abs value
+   -- * the inverse of the max-abs value
+   -- * the per-feature bias
+   -- We keep an extra placeholder for further per-feature learning-rate manipulation,
+   -- so it's 4 in total.
+   self.normalize = normalize and 4 or 0
+
+   -- This is important to keep the possibility of sharing a weight
+   -- directly, without having to allocate it first.
+   -- The reason is that these weights can be very large.
+   self.weight = weight or torch.Tensor(inputSize, outputSize + self.normalize):zero()
+   self.bias = bias or torch.Tensor(outputSize):zero()
+   self.inputSize = self.weight and self.weight:size(1) or inputSize
+   self.outputSize = self.weight and (self.weight:size(2)-self.normalize) or outputSize
+
+   -- gradWeight is not initialized, as we do dense gradient accumulation.
+   -- This is more efficient and avoids allocating a giant, useless gradWeight.
+   self.gradWeight = torch.Tensor()
+
+   -- gradBias still works the same way, as it's already dense
+   self.gradBias = torch.Tensor(self.outputSize):zero()
+
+   -- Buffers
+   self.gradWeightBuffer = torch.Tensor()
+   self.valuesBuffer = torch.Tensor()
+   self.normalizedValues = torch.Tensor()
+
+   -- This is used to accumulate keys and gradWeight
+   -- when accumulating gradients.
+   self.running = {
+      cumSumSizes = {},
+      keys = {},
+      gradWeight = {},
+      counter = 1,
+   }
+
+   -- self.sizes and self.cumSumSizes are computed on the CPU even when using CUDA.
+   -- These two tables make it easier to resize these buffers instead of re-allocating them.
+   -- self.*Cache[1] always contains values on the CPU.
+   -- If CUDA is being used, self.*Cache[2] contains values on the GPU.
+   self.sizesCache = {}
+   self.cumSumSizesCache = {}
+
+   -- A few options
+   self.weightDecay = 0
+   self.doGradInput = doGradInput or false
+   self.offset = keysOffset and keysOffset-1 or -1 -- self.offset is added to the keys (default -1 maps 1-based keys to 0-based)
+end
+
+-- Reset the weights and biases; the extra per-feature
+-- normalization parameters are reset to 0.
+function IndexLinear:reset(stdv)
+   if stdv then
+      stdv = stdv * math.sqrt(3)
+   else
+      stdv = 1./math.sqrt(self.weight:size(2))
+   end
+   self.weight:uniform(-stdv, stdv)
+   self.bias:uniform(-stdv, stdv):mul(0.000001)
+   if self.normalize and self.normalize > 0 then
+      self.weight[{{}, {1,self.normalize}}]:zero()
+   end
+end
+
+function IndexLinear:reshapeInput(input)
+   assert(type(input) == 'table')
+
+   local ninputs = 0
+   for _, v in ipairs(input) do
+      ninputs = ninputs + 1
+   end
+
+   assert(ninputs == 2 or ninputs == 3)
+
+   -- If the format is:
+   -- {
+   --   torch.LongTensor(size1+size2+...+sizeN), -- concatenated batch of keys
+   --   torch.Tensor(size1+size2+...+sizeN),     -- concatenated batch of values
+   --   torch.LongTensor(N),                     -- keys/values sizes (values are {size1, ..., sizeN})
+   -- }
+   if ninputs == 3 then
+      local fkeys = input[1]
+      local fvals = input[2]
+      local fsizes = torch.isTensor(input[3]) and input[3] or fkeys.new{input[3]}
+      assert(fkeys:nElement() == fvals:nElement(), 'Keys and values should be of same size')
+      assert(fkeys:dim() == 1, 'Keys and values should be 1D')
+      self.isFlat = true
+      self.noBatch = false
+      return fkeys, fvals, fsizes
+   end
+
+   local keys = input[1]
+   local values = input[2]
+   local lkeys, lvalues
+
+   -- If the format is:
+   -- {
+   --   { torch.LongTensor(size1), torch.LongTensor(size2), ..., torch.LongTensor(sizeN) }, -- batch of keys
+   --   { torch.Tensor(size1), torch.Tensor(size2), ..., torch.Tensor(sizeN) },             -- batch of values
+   -- }
+   if type(keys) == 'table' and type(values) == 'table' then
+      lkeys, lvalues = keys, values
+      self.isFlat = false
+      self.noBatch = false
+
+   -- If the format is not a batch:
+   -- {
+   --   torch.LongTensor(size1), -- keys
+   --   torch.Tensor(size1),     -- values
+   -- }
+   elseif torch.isTensor(keys) and torch.isTensor(values) then
+      lkeys, lvalues = {keys}, {values}
+      self.isFlat = false
+      self.noBatch = true
+   else
+      error('Wrong input format.')
+   end
+
+   for i=1,#lkeys do
+      assert(lvalues[i]:dim() == 1 and lkeys[i]:dim() == 1, "keys and values should be 1D")
+   end
+
+   return lkeys, lvalues
+end
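For reference, a minimal sketch of the three accepted input formats (the sizes and key ids here are illustrative, not from the commit):

local m = nn.IndexLinear(1000, 5)

-- Flat format: concatenated keys and values, plus per-sample sizes.
local out1 = m:forward({torch.LongTensor{1, 4, 9, 2, 7},
                        torch.Tensor{0.5, 0.1, 0.2, 1.0, 0.3},
                        torch.LongTensor{3, 2}})  -- sample 1 holds 3 pairs, sample 2 holds 2

-- Batched table format: one keys/values tensor pair per sample.
local out2 = m:forward({{torch.LongTensor{1, 4, 9}, torch.LongTensor{2, 7}},
                        {torch.Tensor{0.5, 0.1, 0.2}, torch.Tensor{1.0, 0.3}}})

-- Single-sample format: one keys tensor and one values tensor (noBatch mode).
local out3 = m:forward({torch.LongTensor{1, 4, 9}, torch.Tensor{0.5, 0.1, 0.2}})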
+
+function IndexLinear:longTensor(...)
+   if (self:type() == 'torch.CudaTensor') then
+      return torch.CudaLongTensor(...)
+   else
+      return torch.LongTensor(...)
+   end
+end
+
+function IndexLinear:flattenInputs(input)
+   local lkeys, lvalues, sizes = self:reshapeInput(input)
+
+   local counter = self.running.counter
+
+   -- Ensure everything is of the right type
+   local isCuda = (self:type() == 'torch.CudaTensor')
+   self.running.keys[counter] = self.running.keys[counter] or self:longTensor()
+   self.keys = self.running.keys[counter]
+
+   if self.isFlat then
+      self.values = self.values or lvalues.new()
+      self.sizes = self.sizes or self:longTensor()
+
+      self.keys:resize(lkeys:size()):copy(lkeys)
+      self.values:resize(lvalues:size()):copy(lvalues)
+      self.sizes = sizes
+      self.cumSumSizes = self.cumSumSizes or self.sizes.new()
+      self.cumSumSizes:cumsum(self.sizes)
+   else
+      self.values = self.values or lvalues[1].new()
+
+      self.lkeys = lkeys
+      self.lvalues = lvalues
+      local batchSize = #self.lkeys
+
+      self.sizesCache[1] = self.sizesCache[1] or torch.LongTensor(batchSize)
+      self.cumSumSizesCache[1] = self.cumSumSizesCache[1] or torch.LongTensor(batchSize)
+
+      self.sizes = self.sizesCache[1]
+      self.cumSumSizes = self.cumSumSizesCache[1]
+
+      self.sizes:resize(batchSize)
+      self.cumSumSizes:resize(batchSize)
+
+      for i = 1,batchSize do
+         self.sizes[i] = self.lkeys[i]:size(1)
+      end
+      self.cumSumSizes:cumsum(self.sizes)
+
+      self.keys:cat(self.lkeys, 1)
+      self.values:cat(self.lvalues, 1)
+
+      if isCuda then
+         -- Get the GPU cache
+         self.sizesCache[2] = self.sizesCache[2] or torch.CudaLongTensor()
+         self.cumSumSizesCache[2] = self.cumSumSizesCache[2] or torch.CudaLongTensor()
+
+         self.sizes = self.sizesCache[2]
+         self.cumSumSizes = self.cumSumSizesCache[2]
+
+         -- Resize and copy to the GPU
+         self.sizes:resize(batchSize):copy(self.sizesCache[1])
+         self.cumSumSizes:resize(batchSize):copy(self.cumSumSizesCache[1])
+      end
+   end
+   self.running.cumSumSizes[counter] = self.cumSumSizes
+end
+
+function IndexLinear:updateOutput(input)
+
+   self:flattenInputs(input)
+
+   self.values.THNN.IndexLinear_updateOutput(
+      self.keys:cdata(),
+      self.offset,
+      self.values:cdata(),
+      self.sizes:cdata(),
+      self.cumSumSizes:cdata(),
+      self.output:cdata(),
+      self.weight:cdata(),
+      self.bias:cdata(),
+      self.normalizedValues:cdata(),
+      self.train and 1 or 0
+   )
+
+   if self.noBatch then
+      self.output:resize(self.output:size(2))
+   end
+   return self.output
+end
+
+function IndexLinear:accUpdateGradParameters(input, gradOutput, scale)
+   self.values.THNN.IndexLinear_accUpdateGradParameters(
+      self.keys:cdata(),
+      self.offset,
+      self.normalize > 0 and self.normalizedValues:cdata() or self.values:cdata(),
+      self.sizes:cdata(),
+      self.cumSumSizes:cdata(),
+      gradOutput:cdata(),
+      self.weight:cdata(),
+      self.bias:cdata(),
+      self.weightDecay or 0,
+      scale or 1
+   )
+end
+
+function IndexLinear:accGradParameters(input, gradOutput, scale)
+
+   local counter = self.running.counter
+
+   -- As with running.keys in updateOutput, keep a
+   -- table of dense running.gradWeight buffers.
+   self.running.gradWeight[counter] = self.running.gradWeight[counter] or self.values.new()
+   self.values.THNN.IndexLinear_accGradParameters(
+      self.keys:cdata(),
+      self.offset,
+      self.normalize > 0 and self.normalizedValues:cdata() or self.values:cdata(),
+      self.sizes:cdata(),
+      self.cumSumSizes:cdata(),
+      gradOutput:cdata(),
+      self.running.gradWeight[counter]:cdata(),
+      self.gradBias:cdata(),
+      self.weight:cdata(),
+      self.bias:cdata(),
+      self.valuesBuffer:cdata(),
+      self.weightDecay or 0,
+      scale or 1
+   )
+
+   -- Increment the running counter to create a new buffer
+   -- if we don't flush them in zeroGradParameters.
+   self.running.counter = self.running.counter + 1
+end
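A minimal sketch of the accumulation path these buffers enable (learning rate and sizes are illustrative): each backward() fills a fresh running.gradWeight slot, and a single updateParameters() call merges and applies them all:

local m = nn.IndexLinear(1000, 5)
for i = 1, 3 do
   local x = {torch.LongTensor{2*i, 2*i + 1}, torch.Tensor{0.5, 0.25}}
   local out = m:forward(x)
   m:backward(x, out:clone():fill(1))  -- accGradParameters bumps running.counter
end
m:updateParameters(0.01)  -- concatenates the 3 accumulated buffers and applies them
m:zeroGradParameters()    -- flushes the buffers: running.counter goes back to 1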
+
+function IndexLinear:updateGradInput(input, gradOutput)
+   self.gradInput = {}
+   -- Revamped from nn.SparseLinear.updateGradInput
+   if self.doGradInput and self.normalize > 0 then
+      error('updateGradInput is not implemented in max-normalize mode')
+   end
+
+   local ini = self.weight:size(1)
+
+   if self.doGradInput then
+      local gi = gradOutput.new()
+      if gradOutput:dim() == 1 then
+         gi:resize(self.weight:size(1))
+         gi:mv(self.weight, gradOutput)
+         gi:resize(1, self.weight:size(1))
+      elseif gradOutput:dim() == 2 then
+         gi:resize(gradOutput:size(1), self.weight:size(1))
+         gi:mm(gradOutput, self.weight:t())
+      end
+
+      local indices = self.running.keys[1].new(ini):range(1, ini)
+
+      if self.isFlat then
+         self.gradInput[1] = torch.repeatTensor(indices, gi:size(1), 1)
+         self.gradInput[2] = gi
+      else
+         self.gradInput[1] = {}
+         self.gradInput[2] = {}
+         for i = 1,gi:size(1) do
+            self.gradInput[1][i] = self.running.keys[1].new(ini)
+            self.gradInput[1][i]:copy(indices)
+            self.gradInput[2][i] = gradOutput.new(ini)
+            self.gradInput[2][i]:copy(gi[i])
+         end
+      end
+   end
+
+   -- Only unwrap the batch dimension when gradInput was actually
+   -- computed; otherwise gradInput is the empty table from above.
+   if self.noBatch and self.doGradInput then
+      if self.isFlat then
+         self.gradInput = {self.gradInput[1]:resize(ini), self.gradInput[2]:resize(ini)}
+      else
+         self.gradInput = {self.gradInput[1][1], self.gradInput[2][1]}
+      end
+   end
+   return self.gradInput
+end
+
+function IndexLinear:updateParameters(lr)
+   local counter = self.running.counter
+   if counter > 1 then
+      if counter == 2 then
+         self.updateKeys = self.running.keys[1]
+         self.gradWeight = self.running.gradWeight[1]
+      else
+         self.updateKeysBuffer = self.updateKeysBuffer or self:longTensor()
+         local lkeys = {}
+         local lgweights = {}
+         local totalSize = 0
+         local lCumSumSizes = {}
+         for i=1,counter-1 do
+            lkeys[i] = self.running.keys[i]
+            -- Change the layout to take advantage of the 1-D contiguous torch.cat
+            lgweights[i] = self.running.gradWeight[i]:contiguous()
+            lgweights[i]:resize(lgweights[i]:nElement())
+            lCumSumSizes[i] = totalSize + self.running.cumSumSizes[i]
+            totalSize = totalSize + lkeys[i]:size(1)
+         end
+
+         self.updateKeysBuffer:cat(lkeys, 1)
+         self.gradWeightBuffer:cat(lgweights, 1)
+         self.cumSumSizes:cat(lCumSumSizes, 1)
+         self.gradWeightBuffer:resize(totalSize, self.outputSize)
+         self.gradWeight = self.gradWeightBuffer
+         self.updateKeys = self.updateKeysBuffer
+      end
+      self.values.THNN.IndexLinear_updateParameters(
+         self.gradWeight:cdata(),
+         self.gradBias:cdata(),
+         self.weight:cdata(),
+         self.bias:cdata(),
+         self.updateKeys:cdata(),
+         self.cumSumSizes:cdata(),
+         self.offset,
+         self.weightDecay or 0,
+         lr or error('You must specify a learning rate')
+      )
+   end
+end
+
+function IndexLinear:zeroGradParameters()
+   -- No need to do anything here, as gradWeight is dense
+   self.gradBias:zero()
+
+   -- The piece of code below would reset the smart-scaling
+   -- parameters for each feature every time we call zeroGradParameters.
+   -- TODO: decide what to do with this piece of code.
+   -- NB: this should stay commented out in sync with the corresponding
+   -- piece of code in lib/THNN/generic/IndexLinear.c, in the accUpdateGradParameters function.
+   --[[
+   local w = self.weight:select(2, 3)
+   if self.updateKeys and self.updateKeys:nElement() > 0 then
+      self.updateKeysBuffer:resizeAs(self.updateKeys):copy(self.updateKeys):add(self.offset+1)
+      w:indexFill(1, self.updateKeysBuffer, 0)
+   end
+   ]]--
+   self.running.counter = 1
+end
+
+function IndexLinear:parameters()
+   return {self.weight, self.bias}, {self.running, self.gradBias}
+end
+
+function IndexLinear:clearState()
+   self.running.keys = {}
+   self.running.gradWeight = {}
+   self.keys = nil
+   self.zerokeys = nil
+   self.updateKeys = nil
+   self.values = nil
+   self.sizes = nil
+   self.lkeys = {}
+   self.lvalues = {}
+   self.gradWeightBuffer = self.gradWeightBuffer.new()
+   self.valuesBuffer = self.valuesBuffer.new()
+   self.updateKeysBuffer = nil
+   return parent.clearState(self)
+end
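Finally, a sketch of the doGradInput option (sizes illustrative), which returns a dense gradient over all input features rather than a sparse one:

local m = nn.IndexLinear(100, 4, true)  -- doGradInput = true
local x = {torch.LongTensor{3, 5}, torch.Tensor{1.0, 2.0}}
local y = m:forward(x)
local gi = m:backward(x, torch.ones(4))
-- gi[1] holds the indices 1..inputSize, gi[2] the dense gradient over all inputs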