Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. local Add, parent = torch.class('nn.Add', 'nn.Module')
  2. function Add:__init(inputSize,scalar)
  3. parent.__init(self)
  4. local size = inputSize
  5. if scalar then size=1 end
  6. self.scalar = scalar
  7. self.bias = torch.Tensor(size)
  8. self.gradBias = torch.Tensor(size)
  9. self._ones = torch.Tensor{1}
  10. self:reset()
  11. end
  12. function Add:reset(stdv)
  13. if stdv then
  14. stdv = stdv * math.sqrt(3)
  15. else
  16. stdv = 1./math.sqrt(self.bias:size(1))
  17. end
  18. self.bias:uniform(-stdv, stdv)
  19. end
  20. function Add:updateOutput(input)
  21. self.output:resizeAs(input):copy(input)
  22. if self.scalar then
  23. self.output:add(self.bias[1]);
  24. else
  25. if input:isSameSizeAs(self.bias) then
  26. self.output:add(self.bias)
  27. else
  28. local batchSize = input:size(1)
  29. if self._ones:size(1) ~= batchSize then
  30. self._ones:resize(batchSize):fill(1)
  31. end
  32. local bias = self.bias:view(-1)
  33. local output = self.output:view(batchSize, -1)
  34. output:addr(1, self._ones, bias)
  35. end
  36. end
  37. return self.output
  38. end
  39. function Add:updateGradInput(input, gradOutput)
  40. if self.gradInput then
  41. self.gradInput:resizeAs(gradOutput):copy(gradOutput)
  42. return self.gradInput
  43. end
  44. end
  45. function Add:accGradParameters(input, gradOutput, scale)
  46. scale = scale or 1
  47. if self.gradBias:size(1) == 1 then
  48. self.gradBias[1] = self.gradBias[1] + scale*gradOutput:sum();
  49. else
  50. if input:isSameSizeAs(self.bias) then
  51. self.gradBias:add(scale, gradOutput)
  52. else
  53. local gradOutput = gradOutput:view(input:size(1), -1)
  54. self.gradBias:view(-1):addmv(scale, gradOutput:t(), self._ones)
  55. end
  56. end
  57. end