diff options
Diffstat (limited to 'contrib/lua-torch/torch7/init.lua')
-rw-r--r-- | contrib/lua-torch/torch7/init.lua | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/contrib/lua-torch/torch7/init.lua b/contrib/lua-torch/torch7/init.lua new file mode 100644 index 000000000..0f3cfbb1e --- /dev/null +++ b/contrib/lua-torch/torch7/init.lua @@ -0,0 +1,192 @@ +-- We are using paths.require to appease mkl + +-- Make this work with LuaJIT in Lua 5.2 compatibility mode, which +-- renames string.gfind (already deprecated in 5.1) +if not string.gfind then + string.gfind = string.gmatch +end +if not table.unpack then + table.unpack = unpack +end + +require "paths" +paths.require "libtorch" + +-- Keep track of all thread local variables torch. +-- if a Lua VM is passed to another thread thread local +-- variables need to be updated. +function torch.updatethreadlocals() + torch.updateerrorhandlers() + local tracking = torch._heaptracking + if tracking == nil then tracking = false end + torch.setheaptracking(tracking) +end + +--- package stuff +function torch.packageLuaPath(name) + if not name then + local ret = string.match(torch.packageLuaPath('torch'), '(.*)/') + if not ret then --windows? + ret = string.match(torch.packageLuaPath('torch'), '(.*)\\') + end + return ret + end + for path in string.gmatch(package.path, "[^;]+") do + path = string.gsub(path, "%?", name) + local f = io.open(path) + if f then + f:close() + local ret = string.match(path, "(.*)/") + if not ret then --windows? + ret = string.match(path, "(.*)\\") + end + return ret + end + end +end + +local function include(file, depth) + paths.dofile(file, 3 + (depth or 0)) +end +rawset(_G, 'include', include) + +function torch.include(package, file) + dofile(torch.packageLuaPath(package) .. '/' .. file) +end + +function torch.class(...) + local tname, parenttname, module + if select('#', ...) == 3 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'string' + and type(select(3, ...)) == 'table' + then + tname = select(1, ...) + parenttname = select(2, ...) + module = select(3, ...) + elseif select('#', ...) == 2 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'string' + then + tname = select(1, ...) + parenttname = select(2, ...) + elseif select('#', ...) == 2 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'table' + then + tname = select(1, ...) + module = select(2, ...) + elseif select('#', ...) == 1 + and type(select(1, ...)) == 'string' + then + tname = select(1, ...) + else + error('<class name> [<parent class name>] [<module table>] expected') + end + + local function constructor(...) + local self = {} + torch.setmetatable(self, tname) + if self.__init then + self:__init(...) + end + return self + end + + local function factory() + local self = {} + torch.setmetatable(self, tname) + return self + end + + local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory, module) + local mpt + if parenttname then + mpt = torch.getmetatable(parenttname) + end + return mt, mpt +end + +function torch.setdefaulttensortype(typename) + assert(type(typename) == 'string', 'string expected') + if torch.getconstructortable(typename) then + torch.Tensor = torch.getconstructortable(typename) + torch.Storage = torch.getconstructortable(torch.typename(torch.Tensor(1):storage())) + else + error(string.format("<%s> is not a string describing a torch object", typename)) + end +end + +function torch.type(obj) + local class = torch.typename(obj) + if not class then + class = type(obj) + end + return class +end + +--[[ See if a given object is an instance of the provided torch class. ]] +function torch.isTypeOf(obj, typeSpec) + -- typeSpec can be provided as either a string, pattern, or the constructor. + -- If the constructor is used, we look in the __typename field of the + -- metatable to find a string to compare to. + if type(typeSpec) ~= 'string' then + typeSpec = getmetatable(typeSpec).__typename + assert(type(typeSpec) == 'string', + "type must be provided as [regexp] string, or factory") + end + + local mt = getmetatable(obj) + while mt do + if type(mt) == 'table' and mt.__typename then + local match = mt.__typename:match(typeSpec) + -- Require full match for non-pattern specs + if match and (match ~= typeSpec or match == mt.__typename) then + return true + end + end + mt = getmetatable(mt) + end + return false +end + +torch.setdefaulttensortype('torch.DoubleTensor') + +require('torch.Tensor') +require('torch.File') +require('torch.CmdLine') +require('torch.FFInterface') +require('torch.Tester') +require('torch.TestSuite') +require('torch.test') + +function torch.totable(obj) + if torch.isTensor(obj) or torch.isStorage(obj) then + return obj:totable() + else + error("obj must be a Storage or a Tensor") + end +end + +function torch.isTensor(obj) + local typename = torch.typename(obj) + if typename and typename:find('torch.*Tensor') then + return true + end + return false +end + +function torch.isStorage(obj) + local typename = torch.typename(obj) + if typename and typename:find('torch.*Storage') then + return true + end + return false +end +-- alias for convenience +torch.Tensor.isTensor = torch.isTensor + +-- remove this line to disable automatic heap-tracking for garbage collection +torch.setheaptracking(true) + +return torch |