You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CAddTensorTable.lua 1.3KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module')

  -- nn.CAddTensorTable: adds one tensor to every tensor of a table.
  -- Forward: given {x, {t1, ..., tn}}, produces {t1 + x, ..., tn + x}.
  function CAddTensorTable:__init()
    parent.__init(self)
    -- Both the forward result and the input gradient are tables, so start
    -- both buffers as empty tables. NOTE(review): the parent constructor
    -- presumably installs a Tensor as self.output; this overrides it so the
    -- output type is consistent before the first forward pass.
    self.output = {}
    self.gradInput = {}
  end
  -- Forward pass.
  -- input is a table with 2 entries: input[1] is the tensor to be added and
  -- input[2] is the table of tensors to which it is added.
  -- Returns self.output, a table where output[i] = input[2][i] + input[1].
  -- Each output[i] is resized like input[1], and input[2][i] is copied into it.
  function CAddTensorTable:updateOutput(input)
    local addend, tensors = input[1], input[2]

    -- Reuse the output tensors from the previous call instead of allocating
    -- new ones on every forward pass (same buffer strategy as updateGradInput).
    if type(self.output) ~= 'table' then
      self.output = {}
    end
    local output = self.output

    for i = 1, #tensors do
      output[i] = output[i] or addend.new()
      output[i]:resizeAs(addend)
      output[i]:copy(tensors[i])
      output[i]:add(addend)
    end

    -- Drop stale entries left over from a previous, longer input table.
    for i = #output, #tensors + 1, -1 do
      output[i] = nil
    end

    return self.output
  end
  -- Backward pass.
  -- gradOutput is a table with one gradient per entry of self.output.
  -- Returns self.gradInput = {gradAddend, {gradT1, ..., gradTn}} where
  --   gradAddend = sum_i gradOutput[i]  (input[1] was added to every entry)
  --   gradTi     = gradOutput[i], laid out with the shape of input[2][i].
  function CAddTensorTable:updateGradInput(input, gradOutput)
    local addend, tensors = input[1], input[2]

    -- Accumulate from zero: an empty input[2] then yields a zero gradient
    -- for input[1] instead of failing on copy(gradOutput[1]) with a nil value.
    self.gradInput[1] = self.gradInput[1] or addend.new()
    local gradAddend = self.gradInput[1]
    gradAddend:resizeAs(addend)
    gradAddend:zero()
    for i = 1, #tensors do
      gradAddend:add(gradOutput[i])
    end

    self.gradInput[2] = self.gradInput[2] or {}
    local gradTensors = self.gradInput[2]
    for i = 1, #tensors do
      gradTensors[i] = gradTensors[i] or tensors[i].new()
      -- gradOutput[i] is shaped like input[1] (see the forward pass). The
      -- gradient of input[2][i] must match input[2][i]'s own shape.
      gradTensors[i]:resizeAs(tensors[i])
      gradTensors[i]:copy(gradOutput[i])
    end

    -- Drop stale entries left over from a previous, longer input table.
    for i = #gradTensors, #tensors + 1, -1 do
      gradTensors[i] = nil
    end

    return self.gradInput
  end