aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/PartialLinear.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/PartialLinear.lua')
-rw-r--r--contrib/lua-torch/nn/PartialLinear.lua114
1 files changed, 114 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/PartialLinear.lua b/contrib/lua-torch/nn/PartialLinear.lua
new file mode 100644
index 000000000..6e92cfc08
--- /dev/null
+++ b/contrib/lua-torch/nn/PartialLinear.lua
@@ -0,0 +1,114 @@
+local PartialLinear, Module = torch.class('nn.PartialLinear', 'nn.Module')
+
+--[[
+
+PartialLinear is a Linear layer that allows the user to a set a collection of
+column indices. When the column indices are set, the layer will behave like a
+Linear layer that only has those columns. Meanwhile, all parameters are
+preserved, so resetting the PartialLinear layer will result in a module that
+behaves just like a regular Linear layer.
+
+This module is useful, for instance, when you want to do forward-backward on
+only a subset of a Linear layer during training but use the full Linear layer
+at test time.
+
+]]--
+
+function PartialLinear:__init(inputsize, outputsize, bias)
+ local bias = ((bias == nil) and true) or bias
+ Module.__init(self)
+
+ -- define the layer as a small network:
+ local pt = nn.ParallelTable()
+ pt:add(nn.Identity()):add(nn.LookupTable(outputsize, inputsize))
+ self.network = nn.Sequential():add(pt):add(nn.MM(false, true))
+ if bias then
+ self.bias = torch.Tensor(1, outputsize):zero()
+ self.gradBias = torch.Tensor(1, outputsize):zero()
+ end
+
+ -- set partition:
+ self.inputsize = inputsize
+ self.outputsize = outputsize
+ self.allcolumns = torch.range(1, self.outputsize)
+ self:resetPartition()
+end
+
+function PartialLinear:setPartition(indices)
+ self.partition = indices:type(self.allcolumns:type())
+end
+
+function PartialLinear:resetPartition()
+ self.partition = self.allcolumns
+end
+
+function PartialLinear:parameters()
+ return {self.network:get(1):get(2).weight, self.bias},
+ {self.network:get(1):get(2).gradWeight, self.gradBias}
+end -- should return only the relevant partition?
+
+function PartialLinear:updateOutput(input)
+ self.output:set(self.network:forward{input, self.partition})
+ if self.bias then
+ self.output:add(
+ self.bias:index(2, self.partition:long()):expandAs(self.output)
+ )
+ self.addBuffer = self.addBuffer or input.new()
+ if self.addBuffer:nElement() ~= input:size(1) then
+ self.addBuffer:resize(input:size(1)):fill(1)
+ end
+ end
+ return self.output
+end
+
+function PartialLinear:updateGradInput(input, gradOutput)
+ if self.gradInput then
+ self.network:updateGradInput({input, self.partition}, gradOutput)
+ self.gradInput:set(self.network.gradInput[1])
+ end
+ return self.gradInput
+end
+
+function PartialLinear:accGradParameters(input, gradOutput, scale)
+ local scale = scale or 1
+ self.network:accGradParameters({input, self.partition}, gradOutput, scale)
+ if self.bias then
+ self.buffer = self.buffer or input.new()
+ self.buffer:resize(gradOutput:size(2))
+ self.buffer:mv(gradOutput:t(), self.addBuffer):mul(scale)
+ self.gradBias:indexAdd(
+ 2, self.partition:long(), self.buffer:view(1, self.buffer:nElement())
+ )
+ end
+end
+
+function PartialLinear:accUpdateGradParameters(input, gradOutput, lr)
+ local gradWeight = self.network:get(1):get(2).gradWeight
+ local gradBias = self.gradBias
+ self.network:get(1):get(2).gradWeight = self.network:get(1):get(2).weight
+ self.gradBias = self.bias
+ self:accGradParameters(input, gradOutput, -lr)
+ self.network:get(1):get(2).gradWeight = gradWeight
+ self.gradBias = gradBias
+end
+
+function PartialLinear:zeroGradParameters()
+ self.network:zeroGradParameters()
+ self.gradBias:zero()
+end
+
+function PartialLinear:updateParameters(learningRate)
+ self.network:updateParameters(learningRate)
+ self.bias:add(-learningRate, self.gradBias)
+end
+
+function PartialLinear:sharedAccUpdateGradParameters(input, gradOutput, lr)
+ -- we do not need to accumulate parameters when sharing:
+ self:defaultAccUpdateGradParameters(input, gradOutput, lr)
+end
+
+function PartialLinear:__tostring__()
+ return torch.type(self) ..
+ string.format('(%d -> %d)', self.inputsize, self.outputsize) ..
+ (self.bias == nil and ' without bias' or '')
+end