You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ConcatTable.lua 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  -- Register the class 'nn.ConcatTable' with torch, deriving from 'nn.Container'.
  -- `parent` is the second value returned by torch.class and is used below to
  -- reach the parent initializer.
  local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container')

  -- Constructor: run the parent initializer first, then start with an empty
  -- list of child modules and an empty output table.
  function ConcatTable:__init()
     parent.__init(self)
     self.modules, self.output = {}, {}
  end
  -- Forward pass: hand the same `input` to every child module and collect the
  -- results positionally in self.output (entry i comes from self.modules[i]).
  --
  -- Parameters:
  --   input  value forwarded unchanged to each child's 'updateOutput'.
  -- Returns:
  --   self.output, a table with exactly one entry per child module.
  --
  -- Calls go through self:rethrowErrors (defined outside this file section;
  -- presumably it invokes the named method on the child and re-raises failures
  -- with context -- confirm against nn.Container).
  function ConcatTable:updateOutput(input)
     local moduleCount = #self.modules
     for i = 1, moduleCount do
        self.output[i] = self:rethrowErrors(self.modules[i], i, 'updateOutput', input)
     end
     -- self.output is reused between calls. If the module list has shrunk since
     -- the previous forward pass, drop the leftover entries so that consumers
     -- iterating over the output table never see results from removed modules.
     for i = moduleCount + 1, #self.output do
        self.output[i] = nil
     end
     return self.output
  end
  -- Reshape `dst` so that its array part mirrors the (possibly nested) array
  -- structure of `src`, delegating every non-table element to `leafFn`.
  --
  -- Parameters:
  --   dst     table that is updated in place.
  --   src     table whose array part (as walked by ipairs) drives the update.
  --   leafFn  function(dst, key, value) called for each element of `src` whose
  --           torch.type is not "table"; it decides what is stored at dst[key].
  -- Returns:
  --   dst, after the update.
  --
  -- Elements of `src` that are tables are handled recursively, reusing dst[key]
  -- when it is already set and starting from a fresh table otherwise.
  local function retable(dst, src, leafFn)
     for key, value in ipairs(src) do
        if torch.type(value) ~= "table" then
           leafFn(dst, key, value)
        else
           local nested = dst[key] or {}
           dst[key] = retable(nested, value, leafFn)
        end
     end
     -- Trim trailing entries of dst that have no counterpart in src.
     for extra = #dst, #src + 1, -1 do
        dst[extra] = nil
     end
     return dst
  end
  -- Shared implementation behind ConcatTable:updateGradInput and
  -- ConcatTable:backward: call `method` on every child module with the same
  -- `input` and the child's own slice gradOutput[i], and sum the results into
  -- self.gradInput.
  --
  -- Parameters:
  --   self        the ConcatTable instance.
  --   method      name of the child method to run (passed to rethrowErrors).
  --   input       the forward input; either a table or a non-table value.
  --   gradOutput  indexable by child position; gradOutput[i] goes to child i.
  --   scale       forwarded unchanged to the child call (may be nil).
  -- Returns:
  --   self.gradInput.
  -- Raises (table input only):
  --   an error when a child returns a non-table, or a table whose length
  --   differs from #input.
  --
  -- self:rethrowErrors is defined outside this file section; presumably it
  -- invokes the named method on the child and returns its result -- confirm.
  local function backward(self, method, input, gradOutput, scale)
     local inputIsTable = torch.type(input) == 'table'
     local hadTableGrad = torch.type(self.gradInput) == 'table'

     if not inputIsTable then
        -- Non-table input: keep the existing buffer unless it was a table (or
        -- is missing), in which case start from a clone of the input.
        local reusable = (not hadTableGrad) and self.gradInput
        self.gradInput = reusable or input:clone()
        for idx, child in ipairs(self.modules) do
           local childGrad = self:rethrowErrors(child, idx, method, input, gradOutput[idx], scale)
           if idx ~= 1 then
              self.gradInput:add(childGrad)
           else
              -- The first child overwrites the buffer; later ones accumulate.
              self.gradInput:resize(childGrad:size()):copy(childGrad)
           end
        end
        return self.gradInput
     end

     -- Leaf handler for the first child: make dst[k] an element-wise copy of v,
     -- creating the entry when it does not exist yet.
     local function overwriteLeaf(t, k, v)
        t[k] = t[k] or v:clone()
        t[k]:resize(v:size())
        t[k]:copy(v)
     end

     -- Leaf handler for the remaining children: add v onto dst[k], or seed the
     -- entry with a clone when it is absent.
     local function accumulateLeaf(t, k, v)
        if not t[k] then
           t[k] = v:clone()
        else
           t[k]:add(v)
        end
     end

     for idx, child in ipairs(self.modules) do
        local childGrad = self:rethrowErrors(child, idx, method, input, gradOutput[idx], scale)
        if torch.type(childGrad) ~= 'table' then
           error"currentGradInput is not a table!"
        end
        if #input ~= #childGrad then
           error("table size mismatch: "..#input.." ~= "..#childGrad)
        end
        if idx == 1 then
           -- Reuse the previous gradInput table when there is one.
           self.gradInput = hadTableGrad and self.gradInput or {}
           retable(self.gradInput, childGrad, overwriteLeaf)
        else
           retable(self.gradInput, childGrad, accumulateLeaf)
        end
     end
     return self.gradInput
  end
  -- Compute the gradient with respect to `input` by running 'updateGradInput'
  -- on every child through the file-local `backward` helper (no scale is
  -- supplied, so the helper receives nil for it).
  -- Returns the value produced by that helper.
  function ConcatTable:updateGradInput(input, gradOutput)
     local gradInput = backward(self, 'updateGradInput', input, gradOutput)
     return gradInput
  end
  -- Run 'backward' on every child through the file-local `backward` helper,
  -- forwarding `scale` unchanged (it may be nil).
  -- Returns the value produced by that helper.
  function ConcatTable:backward(input, gradOutput, scale)
     local gradInput = backward(self, 'backward', input, gradOutput, scale)
     return gradInput
  end
  -- Call 'accGradParameters' on every child module, giving each the shared
  -- `input`, its own slice gradOutput[i], and `scale`.
  --
  -- Parameters:
  --   input       forwarded unchanged to every child.
  --   gradOutput  indexable by child position.
  --   scale       optional; replaced by 1 when nil or false.
  -- Returns nothing.
  function ConcatTable:accGradParameters(input, gradOutput, scale)
     if not scale then
        scale = 1
     end
     for idx, child in ipairs(self.modules) do
        local childGradOutput = gradOutput[idx]
        self:rethrowErrors(child, idx, 'accGradParameters', input, childGradOutput, scale)
     end
  end
  -- Call 'accUpdateGradParameters' on every child module, giving each the
  -- shared `input`, its own slice gradOutput[i], and `lr` exactly as received
  -- (no default is applied here).
  -- Returns nothing.
  function ConcatTable:accUpdateGradParameters(input, gradOutput, lr)
     for idx, child in ipairs(self.modules) do
        local childGradOutput = gradOutput[idx]
        self:rethrowErrors(child, idx, 'accUpdateGradParameters', input, childGradOutput, lr)
     end
  end
  -- Build a multi-line description of the container: the class name, an
  -- 'input' line, one "(i): <child>" line per child module, an 'output' line
  -- and a closing brace. Newlines inside a child's own description are
  -- followed by a continuation prefix so nested text stays under its branch.
  -- Returns the assembled string.
  --
  -- NOTE(review): every padding literal below is a single space wide as it
  -- stands in this file; confirm that this is the intended column alignment.
  function ConcatTable:__tostring__()
     local indent = ' '
     local newline = '\n'
     local branchArrow = ' |`-> '
     local finalArrow = ' `-> '
     local branchCont = ' | '
     local finalCont = ' '
     local outputArrow = ' ... -> '

     local count = #self.modules
     local pieces = { torch.type(self), ' {', newline, indent, 'input' }
     for i = 1, count do
        -- The last child uses the closing arrow and a blank continuation.
        local arrow, cont
        if i < count then
           arrow, cont = branchArrow, branchCont
        else
           arrow, cont = finalArrow, finalCont
        end
        -- Only the first result of gsub (the rewritten string) is kept.
        local rendered = tostring(self.modules[i]):gsub(newline, newline .. indent .. cont)
        pieces[#pieces + 1] = newline .. indent .. arrow .. '(' .. i .. '): ' .. rendered
     end
     pieces[#pieces + 1] = newline .. indent .. outputArrow .. 'output'
     pieces[#pieces + 1] = newline .. '}'
     return table.concat(pieces)
  end