summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/GPU.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/GPU.lua')
-rw-r--r--contrib/lua-torch/nn/GPU.lua273
1 files changed, 273 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/GPU.lua b/contrib/lua-torch/nn/GPU.lua
new file mode 100644
index 000000000..758618d8b
--- /dev/null
+++ b/contrib/lua-torch/nn/GPU.lua
@@ -0,0 +1,273 @@
+------------------------------------------------------------------------
+--[[ GPU ]]--
+-- Decorates a module such that its parameters are
+-- hosted on a specified GPU device.
+-- The operations are also executed on that device.
+-- Arguments input and gradOutput are converted to the specified device
+-- before being fed to the decorated module.
+-- Returned output is on the specified outdevice (defaults to device).
+-- Returned gradInput is allocated on the same device as the input.
+-- The unit test is located in cunn.
+------------------------------------------------------------------------
+local GPU, parent = torch.class("nn.GPU", "nn.Container")
+
+function GPU:__init(module, device, outdevice)
+ parent.__init(self)
+ assert(torch.type(device) == 'number')
+ self.device = device
+ self.outdevice = outdevice or device
+
+ assert(torch.isTypeOf(module, 'nn.Module'))
+ self.modules[1] = module
+
+ if module:type():find('torch%.Cuda.*Tensor') then
+ self:type(module:type())
+ end
+end
+
+function GPU.recursiveModuleDevice(obj, device)
+ if type(obj) == 'table' and not torch.isTypeOf(obj, 'nn.GPU') and not obj.__noGPU__ then
+ for k,v in pairs(obj) do
+ obj[k] = GPU.recursiveModuleDevice(v, device)
+ end
+ elseif torch.type(obj):match('torch.Cuda.*Tensor') then
+ if obj:getDevice() ~= device then
+ obj = obj:clone() -- this will reallocate it to device
+ local newdevice = obj:getDevice()
+ -- when nElement() == 0 newdevice is 0
+ assert(newdevice == device or newdevice == 0)
+ end
+ end
+ assert(obj ~= nil)
+ return obj
+end
+
+-- set the device of the decorated module
+function GPU:setDevice(device)
+ self.device = device or self.device
+
+ assert(self.modules[1])
+ self.modules[1] = cutorch.withDevice(self.device, function()
+ return self.recursiveModuleDevice(self.modules[1], self.device)
+ end)
+ return self
+end
+
+-- when proto is a device number, returns a dst that has device device for each element in src
+-- otherwise, if proto is a table/tensor, makes sure dst is a identical to src, yet on the same device as proto
+function GPU.recursiveSetDevice(dst, src, proto)
+ local device, prototable
+ if torch.isTensor(proto) then
+ device = proto:getDevice()
+ elseif torch.type(proto) == 'number' then
+ device = proto
+ elseif torch.type(proto) == 'table' then
+ prototable = true
+ else
+ error"Expecting number, table or tensor for arg 3 (proto)"
+ end
+ if torch.type(src) == 'table' then
+ dst = torch.type(dst) == 'table' and dst or {}
+ for k,v in ipairs(src) do
+ dst[k] = GPU.recursiveSetDevice(dst[k], v, prototable and proto[k] or device)
+ end
+ for k=#src+1,#dst do
+ dst[k] = nil
+ end
+ elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then
+ if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then
+ dst = src.new()
+ end
+ cutorch.withDevice(device, function() dst:resizeAs(src):copy(src) end)
+ else
+ dst = src
+ end
+ return dst
+end
+
+function GPU:updateOutput(input)
+ if self._type:find('torch%.Cuda.*Tensor') then
+ self._input = self.recursiveSetDevice(self._input, input, self.device)
+
+ local output = cutorch.withDevice(self.device, function()
+ return self.modules[1]:updateOutput(self._input)
+ end)
+
+ if self.device ~= self.outdevice then
+ self.output = self.recursiveSetDevice(self.output, output, self.outdevice)
+ else
+ self.output = output
+ end
+ else
+ self.output = self.modules[1]:updateOutput(input)
+ end
+
+ return self.output
+end
+
+function GPU:updateGradInput(input, gradOutput)
+ if self._type:find('torch%.Cuda.*Tensor') then
+ self._gradOutput = self.recursiveSetDevice(self._gradOutput, gradOutput, self.device)
+
+ local gradInput = cutorch.withDevice(self.device, function()
+ return self.modules[1]:updateGradInput(self._input, self._gradOutput)
+ end)
+
+ self.gradInput = self.recursiveSetDevice(self.gradInput, gradInput, input)
+ else
+ self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
+ end
+
+ return self.gradInput
+end
+
+function GPU:accGradParameters(input, gradOutput, scale)
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function()
+ self.modules[1]:accGradParameters(self._input, self._gradOutput, scale)
+ end)
+ else
+ self.modules[1]:accGradParameters(input, gradOutput, scale)
+ end
+end
+
+function GPU:apply(callback)
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.apply(self, callback) end)
+ else
+ parent.apply(self, callback)
+ end
+end
+
+function GPU:type(type, typecache)
+ if type and type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.type(self, type, typecache) end)
+ self:setDevice()
+ else
+ self.output = nil
+ self.gradInput = nil
+ self._input = nil
+ self._gradOutput = nil
+ parent.type(self, type, typecache)
+ end
+ return self
+end
+
+function GPU:clearState()
+ nn.utils.clear(self, 'output', 'gradInput')
+ self._input = nil
+ self._gradOutput = nil
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.clearState(self) end)
+ else
+ parent.clearState(self)
+ end
+end
+
+function GPU:zeroGradParameters()
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.zeroGradParameters(self) end)
+ else
+ parent.zeroGradParameters(self)
+ end
+end
+
+function GPU:updateParameters(lr)
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.updateParameters(self, lr) end)
+ else
+ parent.updateParameters(self, lr)
+ end
+end
+
+function GPU:training()
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.training(self) end)
+ else
+ parent.training(self)
+ end
+end
+
+function GPU:evaluate()
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.evaluate(self) end)
+ else
+ parent.evaluate(self)
+ end
+end
+
+function GPU:share(mlp, ...)
+ local args = {...}
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.share(self, mlp, unpack(args)) end)
+ else
+ parent.share(self, mlp, unpack(args))
+ end
+ return self
+end
+
+function GPU:reset(...)
+ local args = {...}
+ if self._type:find('torch%.Cuda.*Tensor') then
+ cutorch.withDevice(self.device, function() parent.reset(self, unpack(args)) end)
+ else
+ parent.reset(self, unpack(args))
+ end
+ return self
+end
+
+function GPU:clone(...)
+ local args = {...}
+ if self._type:find('torch%.Cuda.*Tensor') then
+ return cutorch.withDevice(self.device, function() parent.clone(self, unpack(args)) end)
+ else
+ return parent.clone(self, unpack(args))
+ end
+end
+
+function GPU:write(file)
+ -- Write all values in the object as a table.
+ local object = {}
+ for k, v in pairs(self) do
+ object[k] = v
+ end
+ local header = {self._type, self.device}
+ file:writeObject(header)
+ file:writeObject(object)
+end
+
+function GPU:read(file)
+ local header = file:readObject()
+ local object
+ if header[1] and header[1]:find('torch%.Cuda.*Tensor') then
+ local device = header[2]
+ if device > cutorch.getDeviceCount() then
+ print"Warning : model was saved with more devices than available on current host."
+ print"Attempting to load module onto device 1"
+ device = 1
+ end
+ object = cutorch.withDevice(device, function() return file:readObject() end)
+ else
+ object = file:readObject()
+ end
+
+ for k, v in pairs(object) do
+ self[k] = v
+ end
+end
+
+function GPU:__tostring__()
+ if self.modules[1].__tostring__ then
+ return torch.type(self) .. '(' .. self.device ..') @ ' .. self.modules[1]:__tostring__()
+ else
+ return torch.type(self) .. '(' .. self.device ..') @ ' .. torch.type(self.modules[1])
+ end
+end
+
+function GPU:accUpdateGradParameters(input, gradOutput, lr)
+ error("Not Implemented for "..torch.type(self))
+end
+
+function GPU:sharedAccUpdateGradParameters(input, gradOutput, lr)
+ error("Not Implemented for "..torch.type(self))
+end