You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Decorator.lua 1.5KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  -- Declare the class "nn.Decorator" with "nn.Container" as its parent class.
  -- torch.class returns two values here: the new class table (Decorator) and
  -- a second value bound to `parent`.
  local Decorator, parent = torch.class('nn.Decorator', 'nn.Container')
  --- Construct a decorator around a single module.
  -- @param module the module to wrap; it is stored at self.modules[1]
  function Decorator:__init(module)
     parent.__init(self)
     -- Registering the wrapped module in self.modules lets this object be
     -- handled like a Container.
     local children = self.modules
     children[1] = module
  end
  --- Forward pass: delegate to the wrapped module.
  -- @param input passed through unchanged to the wrapped module's updateOutput
  -- @return the value stored in self.output (the wrapped module's result)
  function Decorator:updateOutput(input)
     local wrapped = self.modules[1]
     self.output = wrapped:updateOutput(input)
     return self.output
  end
  --- Backward pass w.r.t. the input: delegate to the wrapped module.
  -- @param input passed through unchanged
  -- @param gradOutput passed through unchanged
  -- @return the value stored in self.gradInput (the wrapped module's result)
  function Decorator:updateGradInput(input, gradOutput)
     local wrapped = self.modules[1]
     self.gradInput = wrapped:updateGradInput(input, gradOutput)
     return self.gradInput
  end
  --- Delegate accGradParameters to the wrapped module.
  -- All three arguments are forwarded unchanged; nothing is returned.
  -- @param input forwarded as-is
  -- @param gradOutput forwarded as-is
  -- @param scale forwarded as-is
  function Decorator:accGradParameters(input, gradOutput, scale)
     local wrapped = self.modules[1]
     wrapped:accGradParameters(input, gradOutput, scale)
  end
  --- Delegate accUpdateGradParameters to the wrapped module.
  -- All three arguments are forwarded unchanged; nothing is returned.
  -- @param input forwarded as-is
  -- @param gradOutput forwarded as-is
  -- @param lr forwarded as-is (presumably a learning rate -- confirm against nn.Module)
  function Decorator:accUpdateGradParameters(input, gradOutput, lr)
     local wrapped = self.modules[1]
     wrapped:accUpdateGradParameters(input, gradOutput, lr)
  end
  --- Delegate sharedAccUpdateGradParameters to the wrapped module.
  -- All three arguments are forwarded unchanged; nothing is returned.
  -- @param input forwarded as-is
  -- @param gradOutput forwarded as-is
  -- @param lr forwarded as-is (presumably a learning rate -- confirm against nn.Module)
  function Decorator:sharedAccUpdateGradParameters(input, gradOutput, lr)
     local wrapped = self.modules[1]
     wrapped:sharedAccUpdateGradParameters(input, gradOutput, lr)
  end
  --- Describe this decorator as "<own type> @ <wrapped module>".
  -- The wrapped module's own __tostring__ is used when it has one; otherwise
  -- its torch type name is used instead.
  -- @return the description string
  function Decorator:__tostring__()
     local wrapped = self.modules[1]
     if not wrapped.__tostring__ then
        -- no custom description available: fall back to the type name
        return torch.type(self) .. ' @ ' .. torch.type(wrapped)
     end
     return torch.type(self) .. ' @ ' .. wrapped:__tostring__()
  end
  --- Copy nn.Decorator's forwarding methods onto another class table.
  -- Useful for multiple-inheritance: the given class receives the Decorator
  -- implementations of the methods listed below.
  -- @param class the class table to receive the methods (modified in place)
  function Decorator.decorate(class)
     -- assigned in this order, one by one
     local forwarded = {
        'updateOutput',
        'updateGradInput',
        'accGradParameters',
        'accUpdateGradParameters',
        'sharedAccUpdateGradParameters',
        '__tostring__',
     }
     for _, methodName in ipairs(forwarded) do
        class[methodName] = nn.Decorator[methodName]
     end
  end