diff options
Diffstat (limited to 'contrib/lua-torch/nn/PartialLinear.lua')
-rw-r--r-- | contrib/lua-torch/nn/PartialLinear.lua | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/PartialLinear.lua b/contrib/lua-torch/nn/PartialLinear.lua new file mode 100644 index 000000000..6e92cfc08 --- /dev/null +++ b/contrib/lua-torch/nn/PartialLinear.lua @@ -0,0 +1,114 @@ +local PartialLinear, Module = torch.class('nn.PartialLinear', 'nn.Module') + +--[[ + +PartialLinear is a Linear layer that allows the user to a set a collection of +column indices. When the column indices are set, the layer will behave like a +Linear layer that only has those columns. Meanwhile, all parameters are +preserved, so resetting the PartialLinear layer will result in a module that +behaves just like a regular Linear layer. + +This module is useful, for instance, when you want to do forward-backward on +only a subset of a Linear layer during training but use the full Linear layer +at test time. + +]]-- + +function PartialLinear:__init(inputsize, outputsize, bias) + local bias = ((bias == nil) and true) or bias + Module.__init(self) + + -- define the layer as a small network: + local pt = nn.ParallelTable() + pt:add(nn.Identity()):add(nn.LookupTable(outputsize, inputsize)) + self.network = nn.Sequential():add(pt):add(nn.MM(false, true)) + if bias then + self.bias = torch.Tensor(1, outputsize):zero() + self.gradBias = torch.Tensor(1, outputsize):zero() + end + + -- set partition: + self.inputsize = inputsize + self.outputsize = outputsize + self.allcolumns = torch.range(1, self.outputsize) + self:resetPartition() +end + +function PartialLinear:setPartition(indices) + self.partition = indices:type(self.allcolumns:type()) +end + +function PartialLinear:resetPartition() + self.partition = self.allcolumns +end + +function PartialLinear:parameters() + return {self.network:get(1):get(2).weight, self.bias}, + {self.network:get(1):get(2).gradWeight, self.gradBias} +end -- should return only the relevant partition? + +function PartialLinear:updateOutput(input) + self.output:set(self.network:forward{input, self.partition}) + if self.bias then + self.output:add( + self.bias:index(2, self.partition:long()):expandAs(self.output) + ) + self.addBuffer = self.addBuffer or input.new() + if self.addBuffer:nElement() ~= input:size(1) then + self.addBuffer:resize(input:size(1)):fill(1) + end + end + return self.output +end + +function PartialLinear:updateGradInput(input, gradOutput) + if self.gradInput then + self.network:updateGradInput({input, self.partition}, gradOutput) + self.gradInput:set(self.network.gradInput[1]) + end + return self.gradInput +end + +function PartialLinear:accGradParameters(input, gradOutput, scale) + local scale = scale or 1 + self.network:accGradParameters({input, self.partition}, gradOutput, scale) + if self.bias then + self.buffer = self.buffer or input.new() + self.buffer:resize(gradOutput:size(2)) + self.buffer:mv(gradOutput:t(), self.addBuffer):mul(scale) + self.gradBias:indexAdd( + 2, self.partition:long(), self.buffer:view(1, self.buffer:nElement()) + ) + end +end + +function PartialLinear:accUpdateGradParameters(input, gradOutput, lr) + local gradWeight = self.network:get(1):get(2).gradWeight + local gradBias = self.gradBias + self.network:get(1):get(2).gradWeight = self.network:get(1):get(2).weight + self.gradBias = self.bias + self:accGradParameters(input, gradOutput, -lr) + self.network:get(1):get(2).gradWeight = gradWeight + self.gradBias = gradBias +end + +function PartialLinear:zeroGradParameters() + self.network:zeroGradParameters() + self.gradBias:zero() +end + +function PartialLinear:updateParameters(learningRate) + self.network:updateParameters(learningRate) + self.bias:add(-learningRate, self.gradBias) +end + +function PartialLinear:sharedAccUpdateGradParameters(input, gradOutput, lr) + -- we do not need to accumulate parameters when sharing: + self:defaultAccUpdateGradParameters(input, gradOutput, lr) +end + +function PartialLinear:__tostring__() + return torch.type(self) .. + string.format('(%d -> %d)', self.inputsize, self.outputsize) .. + (self.bias == nil and ' without bias' or '') +end |