-- CMul.lua : nn.CMul, a layer that multiplies its input element-wise by a
-- learned weight tensor (with broadcasting over batch / mismatched dims).
-- Declare nn.CMul as a subclass of nn.Module; `parent` gives access to the
-- base-class implementations (used by __init, type and clearState below).
local CMul, parent = torch.class('nn.CMul', 'nn.Module')

-- Constructor. Accepts either a single torch.LongStorage describing the
-- weight size, or a list of dimension sizes (e.g. CMul(4, 5)).
-- Allocates the weight and its gradient buffer, then randomizes the weight.
function CMul:__init(...)
   parent.__init(self)

   local args = {...}

   self.size = torch.LongStorage()
   if #args == 1 and torch.type(args[1]) == 'torch.LongStorage' then
      -- caller passed the size as a LongStorage directly
      self.size:resize(#args[1]):copy(args[1])
   else
      -- caller passed individual dimension sizes
      self.size:resize(#args)
      for i, dim in ipairs(args) do
         self.size[i] = dim
      end
   end

   self.weight = torch.Tensor(self.size)
   self.gradWeight = torch.Tensor(self.size)

   self.output:resize(self.size)

   self:reset()
end

-- Re-initialize the weight uniformly in [-bound, bound].
-- If stdv is given, bound = stdv * sqrt(3) (so the samples have that stdv);
-- otherwise bound = 1/sqrt(#weight elements), the usual fan-in default.
function CMul:reset(stdv)
   local bound
   if stdv then
      bound = stdv * math.sqrt(3)
   else
      bound = 1 / math.sqrt(self.weight:nElement())
   end
   self.weight:uniform(-bound, bound)
end

-- Forward pass: output = input .* weight, broadcasting the weight when its
-- shape does not match the input element-for-element. Returns self.output.
function CMul:updateOutput(input)
   -- lazy-initialize scratch buffers with the same tensor type as the input
   self._output = self._output or input.new()
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.output:resizeAs(input):copy(input)
   if input:nElement() == self.weight:nElement() then
      -- exact element count match: multiply through flat aliasing views
      self._output:view(self.output, -1)
      self._weight:view(self.weight, -1)

      self._output:cmul(self._weight)
   else
      if self.weight:dim() == input:dim() then
         -- same rank, different sizes: expandAs below broadcasts the
         -- mismatched dimensions
         self._output:set(self.output)
         self._weight:set(self.weight)
      else
         -- batched input: treat dim 1 as the batch, flatten the rest
         local batchSize = input:size(1)
         self._output:view(self.output, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      -- zero-copy broadcast of the weight to the output's shape
      self._expand:expandAs(self._weight, self._output)

      if torch.type(input) == 'torch.CudaTensor' then
         -- NOTE(review): the expansion is materialized into a real tensor
         -- before cmul on CUDA — presumably cmul on stride-0 (expanded)
         -- CUDA tensors is unsafe; confirm against the cutorch backend.
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._output:cmul(self._repeat)
      else
         self._output:cmul(self._expand)
      end
   end

   return self.output
end

-- Backward pass w.r.t. the input: gradInput = weight .* gradOutput, with the
-- same broadcasting scheme as updateOutput. Returns self.gradInput (or nil
-- when gradient computation is disabled via self.gradInput = nil).
function CMul:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end

   self._gradOutput = self._gradOutput or input.new()
   self._gradInput = self._gradInput or input.new()
   -- Fix: these buffers were previously created only by updateOutput, so a
   -- backward call after clearState() (or without a prior forward) indexed a
   -- nil field. Lazily create them here as well.
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.gradInput:resizeAs(input):zero()
   if self.weight:nElement() == gradOutput:nElement() then
      -- element-for-element match: no broadcasting needed
      self.gradInput:addcmul(1, self.weight, gradOutput)
   else
      if self.weight:dim() == input:dim() then
         -- same rank: contiguous aliasing views, expandAs handles broadcast
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         nn.utils.contiguousView(self._gradInput, self.gradInput, self.gradInput:size())
         self._weight:set(self.weight)
      else
         -- batched case: flatten all non-batch dimensions
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         nn.utils.contiguousView(self._gradInput, self.gradInput, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      self._expand:expandAs(self._weight, self._gradOutput)

      if torch.type(input) == 'torch.CudaTensor' then
         -- NOTE(review): materialize the expansion before addcmul — stride-0
         -- tensors presumably misbehave on the CUDA backend; confirm.
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._gradInput:addcmul(1, self._repeat, self._gradOutput)
      else
         self._gradInput:addcmul(1, self._expand, self._gradOutput)
      end
   end

   return self.gradInput
end

-- Accumulate the weight gradient: gradWeight += scale * sum(input .* gradOutput),
-- where the sum reduces every dimension along which the weight was broadcast.
-- scale defaults to 1.
function CMul:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   self._input = self._input or input.new()
   self._gradWeight = self._gradWeight or input.new()
   self._sum = self._sum or input.new()
   -- Fix: these buffers were previously created only by the other passes,
   -- so calling this method after clearState() indexed a nil field.
   self._gradOutput = self._gradOutput or input.new()
   self._repeat = self._repeat or input.new()

   if self.weight:nElement() == gradOutput:nElement() then
      -- no broadcasting happened in forward: accumulate directly
      self.gradWeight:addcmul(scale, input, gradOutput)
   else
      if self.weight:dim() == input:dim() then
         nn.utils.contiguousView(self._input, input, input:size())
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         self._gradWeight:set(self.gradWeight)

         self._repeat:cmul(self._input, self._gradOutput)
         -- sum over each broadcast dimension, ping-ponging between the two
         -- scratch buffers so no extra allocation is needed
         local sumInto = self._sum
         local sumFrom = self._repeat
         for i=1,self.weight:dim() do
            if self.weight:size(i) ~= input:size(i) then
               sumInto:sum(sumFrom, i)
               sumInto = sumFrom
               sumFrom = sumFrom == self._repeat and self._sum or self._repeat
            end
         end
         -- after the loop, sumFrom holds the fully-reduced product
         self._gradWeight:add(scale, sumFrom)
      else
         -- batched case: reduce over the batch dimension only
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._input, input, batchSize, -1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         self._gradWeight:view(self.gradWeight, 1, -1)

         self._repeat:cmul(self._input, self._gradOutput)
         self._sum:sum(self._repeat, 1)
         self._gradWeight:add(scale, self._sum)
      end

   end
end

-- Convert the module to another tensor type. The scratch buffers are dropped
-- first (they will be lazily re-created with the new type on the next pass),
-- then the real parameters are converted by the base class.
function CMul:type(type, tensorCache)
   if type then self:clearState() end
   return parent.type(self, type, tensorCache)
end

-- Release all internal scratch buffers (they are re-created lazily), then
-- delegate to the base class to clear output/gradInput. Returns self.
function CMul:clearState()
   local buffers = {
      '_input', '_output', '_weight', '_gradWeight',
      '_expand', '_repeat', '_sum',
   }
   nn.utils.clear(self, buffers)
   return parent.clearState(self)
end
|