blob: 5e6ccb62475cc84178b9d732bb5c56ab6ac839f6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
-- Register the class under the name 'nn.Identity', with 'nn.Module' given as
-- the parent class name. `Identity` is the table the methods below are
-- defined on; the second return value is bound to `_` and is not used in the
-- code visible here.
local Identity, _ = torch.class('nn.Identity', 'nn.Module')
--- Forward pass of the identity module.
-- Stores `input` in `self.output` by plain assignment (no copy is made here)
-- and returns what is then held in `self.output`.
-- @param input value to pass through unchanged
-- @return the value stored in `self.output`
function Identity:updateOutput(input)
   local module = self
   module.output = input
   return module.output
end
--- Backward pass of the identity module.
-- Stores `gradOutput` in `self.gradInput` by plain assignment (no copy is
-- made here) and returns what is then held in `self.gradInput`.
-- @param input forward-pass input; not read by this method
-- @param gradOutput gradient to pass through unchanged
-- @return the value stored in `self.gradInput`
function Identity:updateGradInput(input, gradOutput)
   local module = self
   module.gradInput = gradOutput
   return module.gradInput
end
--- Drop the values cached in `self.output` and `self.gradInput`.
-- Each field that currently holds a truthy value is replaced:
--   * a tensor  -> the result of calling that tensor's own `new()`
--   * a table   -> a fresh empty table
--   * otherwise -> nil
-- Fields that are nil or false are left as they are.
-- @return self
function Identity:clearState()
   -- Assign brand-new objects rather than calling set on the existing ones:
   -- set might reset tensors that are still referenced from elsewhere.
   local fields = {'output', 'gradInput'}
   for i = 1, #fields do
      local name = fields[i]
      local current = self[name]
      if current then
         if torch.isTensor(current) then
            self[name] = current.new()
         elseif type(current) == 'table' then
            self[name] = {}
         else
            self[name] = nil
         end
      end
   end
   return self
end
|