Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. local CReLU, parent = torch.class('nn.CReLU', 'nn.Sequential')
  2. -- Implements the CReLU activation function as described by
  3. -- W. Shang et al. in "Understanding and Improving Convolutional Neural Networks
  4. -- via Concatenated Rectified Linear Units"
  5. function CReLU:__init(nInputDims, inplace)
  6. parent.__init(self)
  7. self.nInputDims = nInputDims
  8. self.inplace = inplace or false
  9. local concatTable = nn.ConcatTable()
  10. concatTable:add(nn.Identity())
  11. concatTable:add(nn.MulConstant(-1))
  12. self:add(concatTable)
  13. self:add(nn.JoinTable(2))
  14. self:add(nn.ReLU(self.inplace))
  15. end
  16. function CReLU:updateOutput(input)
  17. local input_
  18. local batched = input:dim() == (self.nInputDims + 1)
  19. if not batched then
  20. input_ = input:view(1, -1)
  21. else
  22. input_ = input:view(input:size(1), -1)
  23. end
  24. parent.updateOutput(self, input_)
  25. local osize = input:size()
  26. if not batched then
  27. osize[1] = osize[1] * 2
  28. else
  29. osize[2] = osize[2] * 2
  30. end
  31. self.output:resize(osize)
  32. return self.output
  33. end
  34. function CReLU:backward(input, gradOutput)
  35. return self:updateGradInput(input, gradOutput)
  36. end
  37. function CReLU:updateGradInput(input, gradOutput)
  38. local batched = input:dim() == (self.nInputDims + 1)
  39. if not batched then
  40. parent.updateGradInput(self, input:view(1, -1), gradOutput:view(1, -1))
  41. else
  42. parent.updateGradInput(self, input:view(input:size(1), -1),
  43. gradOutput:view(input:size(1), -1))
  44. end
  45. self.gradInput:resizeAs(input)
  46. return self.gradInput
  47. end
  48. function CReLU:__tostring__()
  49. return "CReLU()"
  50. end