aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/HardShrink.lua
blob: 85ff5909cb1c7d5ff5f2ca9eeb7de92ba389a71a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
--- nn.HardShrink: a thin wrapper around the backend's THNN HardShrink kernels.
-- Both passes delegate to `input.THNN.HardShrink_*`, forwarding the stored
-- threshold `self.lambda`. The exact shrinkage rule lives in the THNN kernel,
-- not in this file (presumably zeroing values within [-lambda, lambda];
-- confirm against THNN).
local HardShrink, parent = torch.class('nn.HardShrink', 'nn.Module')

-- Threshold used when the constructor gets nil (or false).
local DEFAULT_LAMBDA = 0.5

--- Construct the module.
-- @param lam optional threshold; falls back to DEFAULT_LAMBDA (0.5) when
--   nil/false. Note that 0 is truthy in Lua and is therefore kept as-is.
function HardShrink:__init(lam)
   parent.__init(self)
   local threshold = lam or DEFAULT_LAMBDA
   self.lambda = threshold
end

--- Forward pass.
-- Looks up the kernel on the input's THNN backend and hands it raw handles
-- for the input and for `self.output`, plus the threshold.
-- @param input tensor whose `THNN` table provides `HardShrink_updateOutput`
-- @return self.output (the buffer whose handle was given to the kernel)
function HardShrink:updateOutput(input)
   local kernel = input.THNN.HardShrink_updateOutput
   local inputHandle = input:cdata()
   local outputHandle = self.output:cdata()
   kernel(inputHandle, outputHandle, self.lambda)
   return self.output
end

--- Backward pass.
-- Passes handles for the input, the incoming gradient and `self.gradInput`
-- to the backend kernel, together with the same threshold.
-- @param input tensor whose `THNN` table provides `HardShrink_updateGradInput`
-- @param gradOutput gradient with respect to this module's output
-- @return self.gradInput (the buffer whose handle was given to the kernel)
function HardShrink:updateGradInput(input, gradOutput)
   local kernel = input.THNN.HardShrink_updateGradInput
   local inputHandle = input:cdata()
   local gradOutputHandle = gradOutput:cdata()
   local gradInputHandle = self.gradInput:cdata()
   kernel(inputHandle, gradOutputHandle, gradInputHandle, self.lambda)
   return self.gradInput
end