blob: 7c941ce1c658f4a0c6f262c5406f37b26498dd26 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
--- nn.ZeroGrad: identity in the forward pass, zero gradient in the backward pass.
-- Useful when you don't want to backpropagate through certain paths.
local ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')

--- Forward pass: the output is the input itself.
-- A tensor input is shared via `Tensor:set` (same storage, no copy).
-- A table input is passed through by reference, so tables work here as they
-- already do in updateGradInput.
-- @param input tensor or (nested) table of tensors
-- @return self.output
function ZeroGrad:updateOutput(input)
   if torch.isTensor(input) then
      -- A previous table input may have replaced self.output with a table;
      -- `set` needs a tensor receiver, so allocate one of the input's type.
      if not torch.isTensor(self.output) then
         self.output = input.new()
      end
      self.output:set(input)
   else
      self.output = input
   end
   return self.output
end

--- Backward pass: the gradient is simply zeroed.
-- gradOutput is ignored. gradInput is resized to match `input`
-- (recursively for tables) and then filled with 0.
-- @param input the input given to the forward pass
-- @param gradOutput ignored
-- @return self.gradInput, zeros with the same structure/size as `input`
function ZeroGrad:updateGradInput(input, gradOutput)
   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
   self.gradInput = nn.utils.recursiveFill(self.gradInput, 0)
   return self.gradInput
end
|