You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Copy.lua 1.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  -- Register the class "nn.Copy" as a subclass of "nn.Module".
  -- `Copy` is the new class table that the methods below are defined on;
  -- `parent` is the nn.Module class, used for super-calls
  -- (parent.__init, parent.type).
  local Copy, parent = torch.class("nn.Copy", "nn.Module")
  --- Construct a Copy module.
  -- @param intype    name of the tensor type used for gradInput
  --                  (default: torch.Tensor.__typename)
  -- @param outtype   name of the tensor type used for output
  --                  (default: torch.Tensor.__typename)
  -- @param forceCopy when truthy, keep the copying updateOutput /
  --                  updateGradInput even if intype == outtype
  -- @param dontCast  when truthy, Copy:type(t) ignores cast requests
  -- Raises an error if intype or outtype does not name a registered
  -- torch type.
  function Copy:__init(intype, outtype, forceCopy, dontCast)
     intype = intype or torch.Tensor.__typename
     outtype = outtype or torch.Tensor.__typename

     self.dontCast = dontCast

     parent.__init(self)

     -- Resolve both type names before using them. torch.getmetatable yields
     -- nil for an unregistered name, which would otherwise surface as an
     -- opaque "attempt to index a nil value" error on `.new`.
     local inMeta = torch.getmetatable(intype)
     local outMeta = torch.getmetatable(outtype)
     assert(inMeta, "nn.Copy: unknown input tensor type '" .. tostring(intype) .. "'")
     assert(outMeta, "nn.Copy: unknown output tensor type '" .. tostring(outtype) .. "'")

     -- gradInput takes the input type; output takes the output type.
     self.gradInput = inMeta.new()
     self.output = outMeta.new()

     if (not forceCopy) and intype == outtype then
        -- Same type and no copy forced: install instance-level overrides
        -- (shadowing the class methods) that alias instead of copying.
        -- self.output / self.gradInput are pointed at the incoming tensor
        -- via :set, and the argument itself is returned.
        self.updateOutput = function(self, input)
           self.output:set(input)
           return input
        end

        self.updateGradInput = function(self, input, gradOutput)
           self.gradInput:set(gradOutput)
           return gradOutput
        end
     end
  end
  --- Forward pass (copying variant).
  -- Resizes self.output to the size of `input` and copies input's values
  -- into it. `input` may be of a different tensor type than self.output.
  -- Used unless __init installed the aliasing override on the instance.
  -- @return self.output
  function Copy:updateOutput(input)
     local dst = self.output
     dst:resize(input:size())
     dst:copy(input)
     return dst
  end
  --- Backward pass (copying variant).
  -- Resizes self.gradInput to the size of `gradOutput` and copies the
  -- gradient values into it. `input` is accepted for the nn.Module
  -- signature but not read.
  -- Used unless __init installed the aliasing override on the instance.
  -- @return self.gradInput
  function Copy:updateGradInput(input, gradOutput)
     local dst = self.gradInput
     dst:resize(gradOutput:size())
     dst:copy(gradOutput)
     return dst
  end
  --- Cast this module's tensors to `type` (delegates to nn.Module:type).
  -- If the module was built with dontCast, a cast request (non-nil `type`)
  -- is ignored and the module itself is returned. A call without `type` is
  -- always forwarded to the parent implementation.
  function Copy:type(type, tensorCache)
     local skipCast = type and self.dontCast
     if not skipCast then
        return parent.type(self, type, tensorCache)
     end
     return self
  end