Diffstat (limited to 'contrib/torch/nn/Jacobian.lua')
-rw-r--r-- | contrib/torch/nn/Jacobian.lua | 389 |
1 file changed, 389 insertions, 0 deletions
diff --git a/contrib/torch/nn/Jacobian.lua b/contrib/torch/nn/Jacobian.lua
new file mode 100644
index 000000000..4f728b18c
--- /dev/null
+++ b/contrib/torch/nn/Jacobian.lua
@@ -0,0 +1,389 @@
+nn.Jacobian = {}
+
+function nn.Jacobian.backward(module, input, param, dparam)
+   local doparam = 0
+   if param then
+      doparam = 1
+   end
+   param = param or input
+   -- output deriv
+   module:forward(input)
+   local dout = module.output.new():resizeAs(module.output)
+   -- 1D view
+   local sdout = module.output.new(dout:storage(),1,dout:nElement())
+   -- jacobian matrix to calculate
+   local jacobian = torch.Tensor(param:nElement(),dout:nElement()):zero()
+
+   for i=1,sdout:nElement() do
+      dout:zero()
+      sdout[i] = 1
+      module:zeroGradParameters()
+      local din = module:updateGradInput(input, dout)
+      module:accGradParameters(input, dout)
+      if doparam == 1 then
+         jacobian:select(2,i):copy(dparam)
+      else
+         jacobian:select(2,i):copy(din)
+      end
+   end
+   return jacobian
+end
+
+function nn.Jacobian.backwardUpdate(module, input, param)
+
+   -- output deriv
+   module:forward(input)
+   local dout = module.output.new():resizeAs(module.output)
+   -- 1D view
+   local sdout = module.output.new(dout:storage(),1,dout:nElement())
+   -- jacobian matrix to calculate
+   local jacobian = torch.Tensor(param:nElement(),dout:nElement()):zero()
+
+   -- original param
+   local params = module:parameters()
+   local origparams = {}
+   for j=1,#params do
+      table.insert(origparams, params[j]:clone())
+   end
+
+   for i=1,sdout:nElement() do
+      for j=1,#params do
+         params[j]:copy(origparams[j])
+      end
+      dout:zero()
+      sdout[i] = 1
+      module:updateGradInput(input, dout)
+      module:accUpdateGradParameters(input, dout, 1)
+      jacobian:select(2,i):copy(param)
+   end
+
+   for j=1,#params do
+      params[j]:copy(origparams[j])
+   end
+
+   return jacobian
+end
+
+function nn.Jacobian.forward(module, input, param, perturbation)
+   param = param or input
+   -- perturbation amount
+   perturbation = perturbation or 1e-6
+   -- 1D view of input
+   --local tst = param:storage()
+   local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
+   -- jacobian matrix to calculate
+   local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
+
+   local outa = torch.Tensor(jacobian:size(2))
+   local outb = torch.Tensor(jacobian:size(2))
+
+   for i=1,sin:nElement() do
+      local orig = sin[i]
+      sin[i] = orig - perturbation
+      outa:copy(module:forward(input))
+      sin[i] = orig + perturbation
+      outb:copy(module:forward(input))
+      sin[i] = orig
+
+      outb:add(-1,outa):div(2*perturbation)
+      jacobian:select(1,i):copy(outb)
+   end
+
+   return jacobian
+end
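As a quick illustration of how the two estimators above fit together (a minimal sketch, not part of the patch; `nn.Tanh`, the input size, and the expected magnitude are arbitrary example choices):

require 'nn'
local module = nn.Tanh()                  -- any differentiable module works here
local input  = torch.randn(10)
-- numeric Jacobian by central differences vs. analytic Jacobian by backprop
local jac_fprop = nn.Jacobian.forward(module, input)
local jac_bprop = nn.Jacobian.backward(module, input)
print((jac_fprop - jac_bprop):abs():max())  -- expect ~1e-8 or smaller on doubles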
+
+function nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName)
+   -- Compute the second derivatives (diagonal Hessian elements)
+   -- by backpropagation (using the code from hessian.lua).
+   --
+   -- This function computes the diagonal Hessian elements of the following function:
+   --
+   -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2,
+   --
+   -- where
+   -- x_1, ..., x_n are the input values and parameters of the given module,
+   -- y_1, ..., y_m are the output values of the given module.
+   --
+   -- All x_i and y_i values are scalars here. In other words,
+   -- x_1, ..., x_n denote the scalar elements of the module input tensor,
+   -- the scalar elements of module.weight,
+   -- and the scalar elements of module.bias;
+   -- y_1, ..., y_m are the scalar elements of the module output tensor.
+   --
+   -- The diagonal Hessian elements of F are computed with respect to
+   -- the module input values and parameters (x_1, ..., x_n).
+   --
+   -- The function F is chosen for its convenient properties:
+   --
+   -- dF / dy_i = y_i,
+   -- d^2F / dy_i^2 = 1.
+   --
+   -- In other words, the diagonal Hessian elements of F with respect
+   -- to the module OUTPUT values (y_1, ..., y_m) are equal to 1.
+   --
+   -- Because of that, computing the diagonal Hessian elements of F
+   -- with respect to the module INPUT values and PARAMETERS (x_1, ..., x_n)
+   -- can be done by calling updateDiagHessianInput() and accDiagHessianParameters()
+   -- using a tensor of ones as diagHessianOutput.
+
+   module:forward(input)
+   local diagHessianOutput = module.output.new():resizeAs(module.output):fill(1)
+
+   module.diagHessianWeight:zero()
+   module.diagHessianBias:zero()
+   module:updateDiagHessianInput(input, diagHessianOutput)
+   module:accDiagHessianParameters(input, diagHessianOutput)
+
+   return module[diagHessianParamName]
+end
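For context, a hedged sketch of how this routine is typically driven. It assumes `nn.hessian.enable()` has been called so that `initDiagHessianParameters()` and the diagonal-Hessian methods exist on the module; `nn.Linear` and the sizes are example choices only:

require 'nn'
nn.hessian.enable()                       -- installs the diagonal-Hessian methods
local module = nn.Linear(4, 3)
local input  = torch.randn(4)
module:initDiagHessianParameters()        -- allocates diagHessianWeight/diagHessianBias
local hw = nn.Jacobian.backwardDiagHessian(module, input, 'diagHessianWeight')
print(hw:size())                          -- one second-derivative estimate per weight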
+
+function nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName)
+   -- Compute the second derivatives (diagonal Hessian elements)
+   -- from the first derivatives for the given module
+   -- (without using the code from hessian.lua).
+   --
+   -- The given module is assumed to be linear with respect to its inputs and weights
+   -- (like nn.Linear, nn.SpatialConvolution, etc.)
+   --
+   -- This function computes the diagonal Hessian elements of the following function:
+   --
+   -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2.
+   --
+   -- (See the comment for nn.Jacobian.backwardDiagHessian() for explanation.)
+   --
+   -- The first derivatives of F with respect to
+   -- the module inputs and parameters (x_1, ..., x_n) are:
+   --
+   -- dF / dx_i = \sum_k (dF / dy_k) (dy_k / dx_i).
+   --
+   -- The second derivatives are:
+   --
+   -- d^2F / dx_i^2 = \sum_k [(d^2F / dy_k^2) (dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
+   --
+   -- The second derivatives of F with respect to the module outputs (y_1, ..., y_m)
+   -- are equal to 1, so:
+   --
+   -- d^2F / dx_i^2 = \sum_k [(dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
+   --
+   -- Assuming the linearity of module outputs (y_1, ..., y_m)
+   -- with respect to module inputs and parameters (x_1, ..., x_n),
+   -- we have (d^2y_k / dx_i^2) = 0,
+   -- and the expression finally becomes:
+   --
+   -- d^2F / dx_i^2 = \sum_k (dy_k / dx_i)^2.
+   --
+   -- The first derivatives (dy_k / dx_i) are computed by normal backpropagation,
+   -- using updateGradInput() and accGradParameters().
+
+   local gradParam = module[gradParamName]
+
+   local diagHessian = gradParam.new():resize(gradParam:nElement()):zero()
+
+   module:forward(input)
+   local gradOutput = module.output.new():resizeAs(module.output)
+   local gradOutput1D = gradOutput:view(gradOutput:nElement())
+
+   for i=1,gradOutput:nElement() do
+      gradOutput1D:zero()
+      gradOutput1D[i] = 1
+      module.gradWeight:zero()
+      if module.bias then
+         module.gradBias:zero()
+      end
+      module:updateGradInput(input, gradOutput)
+      module:accGradParameters(input, gradOutput)
+      diagHessian:addcmul(gradParam, gradParam)
+   end
+
+   return diagHessian
+end
+
+function nn.Jacobian.forwardUpdate(module, input, param, perturbation)
+   -- perturbation amount
+   perturbation = perturbation or 1e-6
+   -- 1D view of input
+   --local tst = param:storage()
+   local sin = param.new(param):resize(param:nElement())--param.new(tst,1,tst:size())
+   -- jacobian matrix to calculate
+   local jacobian = torch.Tensor():resize(param:nElement(),module:forward(input):nElement())
+
+   local outa = torch.Tensor(jacobian:size(2))
+   local outb = torch.Tensor(jacobian:size(2))
+
+   for i=1,sin:nElement() do
+      local orig = sin[i]
+      sin[i] = orig - perturbation
+      outa:copy(module:forward(input))
+      sin[i] = orig + perturbation
+      outb:copy(module:forward(input))
+      sin[i] = orig
+
+      outb:add(-1,outa):div(2*perturbation)
+      jacobian:select(1,i):copy(outb)
+      jacobian:select(1,i):mul(-1)
+      jacobian:select(1,i):add(sin[i])
+   end
+   return jacobian
+end
+
+function nn.Jacobian.testJacobian(module, input, minval, maxval, perturbation)
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
+   local jac_fprop = nn.Jacobian.forward(module, input, input, perturbation)
+   local jac_bprop = nn.Jacobian.backward(module, input)
+   local error = jac_fprop-jac_bprop
+   return error:abs():max()
+end
+
+function nn.Jacobian.testJacobianParameters(module, input, param, dparam, minval, maxval, perturbation)
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
+   param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
+   local jac_bprop = nn.Jacobian.backward(module, input, param, dparam)
+   local jac_fprop = nn.Jacobian.forward(module, input, param, perturbation)
+   local error = jac_fprop - jac_bprop
+   return error:abs():max()
+end
+
+function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval, maxval, perturbation)
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
+   param:copy(torch.rand(param:nElement()):mul(inrange):add(minval))
+   local params_bprop = nn.Jacobian.backwardUpdate(module, input, param)
+   local params_fprop = nn.Jacobian.forwardUpdate(module, input, param, perturbation)
+
+   local error = params_fprop - params_bprop
+   return error:abs():max()
+end
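A typical way to drive these checks from a unit test (a sketch under assumptions: `nn.Linear` as the module under test, and 1e-5 as an illustrative tolerance for double precision, neither taken from this file):

local module = nn.Linear(5, 2)
local input  = torch.Tensor(5)            -- filled with random values by the test itself
local err = nn.Jacobian.testJacobian(module, input)
assert(err < 1e-5, 'gradInput Jacobian mismatch: ' .. err)
err = nn.Jacobian.testJacobianParameters(module, input, module.weight, module.gradWeight)
assert(err < 1e-5, 'gradWeight Jacobian mismatch: ' .. err)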
+
+function nn.Jacobian.testDiagHessian(module, input, gradParamName, diagHessianParamName, minval, maxval)
+   -- Compute the diagonal Hessian elements for the same function in two different ways,
+   -- then compare the results and return the difference.
+
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
+   module:initDiagHessianParameters()
+   local h_bprop = nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName)
+   local h_linearmodule = nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName)
+   local error = h_bprop - h_linearmodule
+   return error:abs():max()
+end
+
+function nn.Jacobian.testDiagHessianInput(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradInput', 'diagHessianInput', minval, maxval)
+end
+
+function nn.Jacobian.testDiagHessianWeight(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradWeight', 'diagHessianWeight', minval, maxval)
+end
+
+function nn.Jacobian.testDiagHessianBias(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradBias', 'diagHessianBias', minval, maxval)
+end
+
+function nn.Jacobian.testIO(module,input, minval, maxval)
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   local inputclone = input:clone()
+
+   -- run module
+   module:forward(input)
+   local go = module.output:clone():copy(torch.rand(module.output:nElement()):mul(inrange):add(minval))
+   local goclone = go:clone()
+   module:zeroGradParameters()
+   module:updateGradInput(input,go)
+   module:accGradParameters(input,go)
+
+   local fo = module.output:clone()
+   local bo = module.gradInput:clone()
+
+   -- write module
+   local filename = os.tmpname()
+   local f = torch.DiskFile(filename, 'w'):binary()
+   -- call clearState and check that it returns itself
+   assert(module == module:clearState(),'clearState did not return self')
+   f:writeObject(module)
+   f:close()
+   -- read module
+   local m = torch.DiskFile(filename):binary():readObject()
+   m:forward(inputclone)
+   m:zeroGradParameters()
+   m:updateGradInput(inputclone,goclone)
+   m:accGradParameters(inputclone,goclone)
+   -- cleanup
+   os.remove(filename)
+
+   local fo2 = m.output:clone()
+   local bo2 = m.gradInput:clone()
+
+   local errf = fo - fo2
+   local errb = bo - bo2
+   return errf:abs():max(), errb:numel() == 0 and 0 or errb:abs():max()
+end
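Since the binary round trip should reproduce the module exactly, both returned errors are expected to be zero (a minimal sketch, not part of the patch; `nn.Linear` is an arbitrary example module):

local module = nn.Linear(3, 3)
local input  = torch.randn(3)
local ferr, berr = nn.Jacobian.testIO(module, input)
assert(ferr == 0 and berr == 0, 'serialization changed module behaviour')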
+
+function nn.Jacobian.testAllUpdate(module, input, weight, gradWeight)
+   local gradOutput
+   local lr = torch.uniform(0.1, 1)
+   local errors = {}
+
+   -- accGradParameters
+   local maccgp = module:clone()
+   local weightc = maccgp[weight]:clone()
+   maccgp:forward(input)
+   gradOutput = torch.rand(maccgp.output:size())
+   maccgp:zeroGradParameters()
+   maccgp:updateGradInput(input, gradOutput)
+   maccgp:accGradParameters(input, gradOutput)
+   maccgp:updateParameters(lr)
+   errors["accGradParameters"] = (weightc-maccgp[gradWeight]*lr-maccgp[weight]):norm()
+
+   -- accUpdateGradParameters
+   local maccugp = module:clone()
+   maccugp:forward(input)
+   maccugp:updateGradInput(input, gradOutput)
+   maccugp:accUpdateGradParameters(input, gradOutput, lr)
+   errors["accUpdateGradParameters"] = (maccugp[weight]-maccgp[weight]):norm()
+
+   -- shared, accGradParameters
+   local macsh1 = module:clone()
+   local macsh2 = module:clone()
+   macsh2:share(macsh1, weight)
+   macsh1:forward(input)
+   macsh2:forward(input)
+   macsh1:zeroGradParameters()
+   macsh2:zeroGradParameters()
+   macsh1:updateGradInput(input, gradOutput)
+   macsh2:updateGradInput(input, gradOutput)
+   macsh1:accGradParameters(input, gradOutput)
+   macsh2:accGradParameters(input, gradOutput)
+   macsh1:updateParameters(lr)
+   macsh2:updateParameters(lr)
+   local err = (weightc-maccgp[gradWeight]*(lr*2)-macsh1[weight]):norm()
+   err = err + (weightc-maccgp[gradWeight]*(lr*2)-macsh2[weight]):norm()
+   errors["accGradParameters [shared]"] = err
+
+   -- shared, accUpdateGradParameters
+   local macshu1 = module:clone()
+   local macshu2 = module:clone()
+   macshu2:share(macshu1, weight)
+   macshu1:forward(input)
+   macshu2:forward(input)
+   macshu1:updateGradInput(input, gradOutput)
+   macshu2:updateGradInput(input, gradOutput)
+   macshu1:accUpdateGradParameters(input, gradOutput, lr)
+   macshu2:accUpdateGradParameters(input, gradOutput, lr)
+   err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm()
+   err = err + (weightc-maccgp[gradWeight]*(lr*2)-macshu2[weight]):norm()
+   errors["accUpdateGradParameters [shared]"] = err
+
+   return errors
+end
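Finally, a sketch of how the update-path consistency check above might be invoked; 'weight'/'gradWeight' name the module fields to compare, and the tolerance is an illustrative guess rather than a value from the source:

local module = nn.Linear(4, 4)
local input  = torch.randn(10, 4)
local errs = nn.Jacobian.testAllUpdate(module, input, 'weight', 'gradWeight')
for name, err in pairs(errs) do
   assert(err < 1e-8, name .. ' update path disagrees: ' .. err)
end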