aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/ZeroGrad.lua
blob: 7c941ce1c658f4a0c6f262c5406f37b26498dd26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
-- nn.ZeroGrad: forwards its input unchanged but blocks gradient flow by
-- returning an all-zero gradInput. Registered as a subclass of nn.Module;
-- `parent` is unused in this file but kept per the nn class convention.
local ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')

--- Forward pass: pass the input through unchanged.
-- For a tensor input, `self.output` is made to share `input`'s storage via
-- `set` (no copy). Any other input (e.g. a table of tensors) is returned
-- as-is, so the forward pass accepts the same inputs the backward pass's
-- recursive helpers are used with.
-- @param input tensor (or other value) to forward
-- @return self.output, which aliases input
function ZeroGrad:updateOutput(input)
   if torch.isTensor(input) then
      -- `set` needs an output tensor of the same type as `input`; replace
      -- the buffer when it isn't one (e.g. the module was not converted
      -- with :float()/:cuda(), or the previous call received a table).
      if torch.type(self.output) ~= torch.type(input) then
         self.output = input.new()
      end
      self.output:set(input)
   else
      self.output = input
   end
   return self.output
end

--- Backward pass: ignore gradOutput and return an all-zero gradient.
-- The gradient is simply zeroed, which is useful when you don't want to
-- backpropagate through certain paths.
-- `self.gradInput` is resized to match `input` and then filled with 0.
-- NOTE(review): the nn.utils.recursive* helpers presumably also handle
-- tables of tensors — confirm against nn.utils.
-- @param input the input given to the forward pass
-- @param gradOutput ignored
-- @return self.gradInput, zeros shaped like input
function ZeroGrad:updateGradInput(input, gradOutput)
   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
   self.gradInput = nn.utils.recursiveFill(self.gradInput, 0)
   return self.gradInput
end