You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

AddConstant.lua 1.5KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module')
  2. function AddConstant:__init(constant_scalar,ip)
  3. parent.__init(self)
  4. self.constant_scalar = constant_scalar
  5. -- default for inplace is false
  6. self.inplace = ip or false
  7. if (ip and type(ip) ~= 'boolean') then
  8. error('in-place flag must be boolean')
  9. end
  10. end
  11. function AddConstant:updateOutput(input)
  12. assert(type(self.constant_scalar) == 'number' or
  13. (torch.isTensor(self.constant_scalar) and input:nDimension() <= 2 and
  14. input:size(input:nDimension()) == self.constant_scalar:size(1)),
  15. 'input is not scalar or doesn\'t match with the dimension of constant!')
  16. local tmp
  17. if torch.isTensor(self.constant_scalar) and input:nDimension() == 2 then
  18. local nOutput = self.constant_scalar:size(1)
  19. tmp = self.constant_scalar.new()
  20. tmp:resize(1,nOutput)
  21. tmp:copy(self.constant_scalar)
  22. tmp = tmp:expand(input:size(1),nOutput)
  23. else
  24. tmp = self.constant_scalar
  25. end
  26. if self.inplace then
  27. input:add(tmp)
  28. self.output:set(input)
  29. else
  30. self.output:resizeAs(input)
  31. self.output:copy(input)
  32. self.output:add(tmp)
  33. end
  34. return self.output
  35. end
  36. function AddConstant:updateGradInput(input, gradOutput)
  37. if self.inplace then
  38. self.gradInput:set(gradOutput)
  39. -- restore previous input value
  40. input:add(-self.constant_scalar)
  41. else
  42. self.gradInput:resizeAs(gradOutput)
  43. self.gradInput:copy(gradOutput)
  44. end
  45. return self.gradInput
  46. end