-- Unit tests covering the QR decomposition routines exposed by torch:
-- torch.qr(), torch.geqrf() and torch.orgqr().
local torch = require 'torch'

local tester = torch.Tester()
local tests = torch.TestSuite()

-- Raises `message` unless torch.pointer() reports the same address for `a`
-- and `b`, i.e. the in-place API handed back the output tensor it was given.
local function assertSameTensor(a, b, message)
  assert(torch.pointer(a) == torch.pointer(b), message)
end

-- Builds a QR routine that calls torch.qr() with output tensors created by
-- `tensorFunc`. torch.qr() only ever sees a clone of the caller's matrix.
local function qrInPlace(tensorFunc)
  return function(x)
    local qOut = tensorFunc()
    local rOut = tensorFunc()
    torch.qr(qOut, rOut, x:clone())
    return qOut, rOut
  end
end

-- Builds a QR routine that lets torch.qr() allocate its own outputs.
-- `tensorFunc` is unused here; it is accepted so that every builder in this
-- file can be called the same way.
local function qrReturned(tensorFunc)
  return function(x)
    local input = x:clone()
    return torch.qr(input)
  end
end

-- Wraps torch.geqrf() in its "output tensors supplied" form and checks that
-- the tensors it returns are exactly the ones passed in.
local function geqrfInPlace(tensorFunc)
  return function(x)
    local result, tau = tensorFunc(), tensorFunc()
    local returnedResult, returnedTau = torch.geqrf(result, tau, x)
    assertSameTensor(result, returnedResult, 'expected result, result_ same tensor')
    assertSameTensor(tau, returnedTau, 'expected tau, tau_ same tensor')
    return returnedResult, returnedTau
  end
end

-- Wraps torch.orgqr() in its "output tensor supplied" form and checks that
-- the tensor it returns is exactly the one passed in.
local function orgqrInPlace(tensorFunc)
  return function(result, tau)
    local q = tensorFunc()
    assertSameTensor(q, torch.orgqr(q, result, tau), 'expected q, q_ same tensor')
    return q
  end
end

-- Test a custom QR routine that calls the LAPACK functions manually.
-- `geqrfFunc(x)` must return the packed (result, tau) pair of a geqrf call and
-- `orgqrFunc(result, tau)` the matching Q. The routine returned here yields the
-- reduced factors: Q with k = min(m, n) columns and R as the k x n upper
-- triangle of the packed result.
local function qrManual(geqrfFunc, orgqrFunc)
  return function(x)
    local m, n = x:size(1), x:size(2)
    local k = math.min(m, n)
    local result, tau = geqrfFunc(x)
    assert(result:size(1) == m)
    assert(result:size(2) == n)
    assert(tau:size(1) == k)
    -- R lives in the upper triangle of the first k rows of the packed result;
    -- it is extracted before orgqr is invoked on that same tensor.
    local r = torch.triu(result:narrow(1, 1, k))
    local q = orgqrFunc(result, tau)
    return q:narrow(2, 1, k), r
  end
end

-- Check that multiplying by Q through torch.ormqr() agrees with an explicit
-- product against the Q returned by torch.qr(), for all four combinations of
-- side ('L'/'R') and transposition.
--
-- `testOpts.precision` is used as the tolerance when present (consistent with
-- checkQR); otherwise the historical default of 1e-5 applies.
local function checkQM(testOpts, mat1, mat2)
  local requiredPrecision = (testOpts and testOpts.precision) or 1e-5
  local q = torch.qr(mat1)
  local m, tau = torch.geqrf(mat1)
  tester:assertTensorEq(torch.mm(q, mat2), torch.ormqr(m, tau, mat2), requiredPrecision)
  tester:assertTensorEq(torch.mm(mat2, q), torch.ormqr(m, tau, mat2, 'R'), requiredPrecision)
  tester:assertTensorEq(torch.mm(q:t(), mat2), torch.ormqr(m, tau, mat2, 'L', 'T'), requiredPrecision)
  tester:assertTensorEq(torch.mm(mat2, q:t()), torch.ormqr(m, tau, mat2, 'R', 'T'), requiredPrecision)
end

-- Check that the given `q`, `r` matrices are a valid QR decomposition of `a`.
-- When `q` is nil, the decomposition is computed with `testOpts.qr`.
-- Validates the reduced shapes (Q is m x k, R is k x n with k = min(m, n)),
-- orthonormal columns of Q, upper-triangular R, and QR == A, all within
-- `testOpts.precision`.
local function checkQR(testOpts, a, q, r)
  local qrFunc = testOpts.qr
  if not q then
    q, r = qrFunc(a)
  end
  local k = math.min(a:size(1), a:size(2))
  tester:asserteq(q:size(1), a:size(1), "Bad size for q first dimension.")
  tester:asserteq(q:size(2), k, "Bad size for q second dimension.")
  tester:asserteq(r:size(1), k, "Bad size for r first dimension.")
  tester:asserteq(r:size(2), a:size(2), "Bad size for r second dimension.")
  tester:assertTensorEq(q:t() * q,
                        torch.eye(q:size(2)):typeAs(testOpts.tensorFunc()),
                        testOpts.precision,
                        "Q was not orthogonal")
  tester:assertTensorEq(r, r:triu(), testOpts.precision,
                        "R was not upper triangular")
  tester:assertTensorEq(q * r, a, testOpts.precision, "QR = A")
end

-- Do a QR decomposition of `a` and check that the result is valid and matches
-- the given expected `q` and `r`.
local function checkQRWithExpected(testOpts, a, expected_q, expected_r)
  local qrFunc = testOpts.qr

  -- A QR decomposition is unique only up to the signs of the columns of Q
  -- (with the matching rows of R flipped along with them), so both the
  -- computed and the expected factors are put into a canonical form before
  -- comparing: each row of R is flipped so that its first entry whose
  -- magnitude exceeds the tolerance is positive. For a nonsingular leading
  -- block that entry is the diagonal (the classic sign(diag(R)) rule); when a
  -- diagonal entry is mere rounding noise, its sign is arbitrary and would
  -- make the comparison depend on round-off, so the next significant entry of
  -- the row decides instead. A row with no significant entry is left as is.
  local function canonicalize(q, r)
    local tolerance = testOpts.precision * math.max(1, torch.abs(r):max())
    local k = r:size(1)
    local signs = torch.ones(k):typeAs(r)
    for i = 1, k do
      for j = i, r:size(2) do
        local value = r[i][j]
        if math.abs(value) > tolerance then
          if value < 0 then
            signs[i] = -1
          end
          break
        end
      end
    end
    local d = torch.diag(signs)
    return q * d, d * r
  end

  local q, r = qrFunc(a)
  local q_canon, r_canon = canonicalize(q, r)
  local expected_q_canon, expected_r_canon = canonicalize(expected_q, expected_r)
  tester:assertTensorEq(q_canon, expected_q_canon, testOpts.precision,
                        "Q did not match expected")
  tester:assertTensorEq(r_canon, expected_r_canon, testOpts.precision,
                        "R did not match expected")
  checkQR(testOpts, a, q, r)
end

-- Generate a separate test based on `func` for each combination of tensor
-- type (double or float) and QR function: torch.qr() with result tensors
-- given, torch.qr() returning its results, and manually calling geqrf and
-- orgqr from Lua (both with result tensors given and returned).
--
-- The tests are added to the given `tests` table, with names generated by
-- appending a unique string for the specific combination to `name`. Each test
-- runs with the RNG seeded to 1 and restores the previous RNG state after
-- `func` returns.
--
-- If opts.doubleTensorOnly is true, then the FloatTensor versions of the test
-- will be skipped.
local function addTestVariations(tests, name, func, opts)
  opts = opts or {}
  local tensorTypes = {
    [torch.DoubleTensor] = 1e-12,
    [torch.FloatTensor] = 1e-5,
  }
  for tensorFunc, requiredPrecision in pairs(tensorTypes) do
    local qrFuncs = {
      ['inPlace'] = qrInPlace(tensorFunc),
      ['returned'] = qrReturned(tensorFunc),
      ['manualInPlace'] = qrManual(geqrfInPlace(tensorFunc), orgqrInPlace(tensorFunc)),
      ['manualReturned'] = qrManual(torch.geqrf, torch.orgqr)
    }
    for qrName, qrFunc in pairs(qrFuncs) do
      local testOpts = {
        tensorFunc=tensorFunc,
        precision=requiredPrecision,
        qr=qrFunc,
      }
      local tensorType = tensorFunc():type()
      local fullName = name .. "_" .. qrName .. "_" .. tensorType
      assert(not tests[fullName])
      if tensorType == 'torch.DoubleTensor' or not opts.doubleTensorOnly then
        tests[fullName] = function()
          local state = torch.getRNGState()
          torch.manualSeed(1)
          func(testOpts)
          torch.setRNGState(state)
        end
      end
    end
  end
end

-- Decomposing a specific square matrix.
addTestVariations(tests, 'qrSquare', function(testOpts)
  local tensorFunc = testOpts.tensorFunc
  -- The matrix is chosen to be full-rank. (With {7, 8, 9} as the last row it
  -- would be singular, R[3][3] would be pure round-off and the sign of Q's
  -- last column would not be well defined.)
  local a = tensorFunc{{1, 2, 3}, {4, 5, 6}, {7, 8, 10}}
  -- Exact factors: the columns of Q are (-1, -4, -7)/sqrt(66),
  -- (3, 1, -1)/sqrt(11) and (1, -2, 1)/sqrt(6); R = Q^T A.
  local s66, s11, s6 = math.sqrt(66), math.sqrt(11), math.sqrt(6)
  local expected_q = tensorFunc{
    {-1 / s66,  3 / s11,  1 / s6},
    {-4 / s66,  1 / s11, -2 / s6},
    {-7 / s66, -1 / s11,  1 / s6},
  }
  local expected_r = tensorFunc{
    {-s66, -78 / s66, -97 / s66},
    {   0,   3 / s11,   5 / s11},
    {   0,         0,    1 / s6},
  }
  checkQRWithExpected(testOpts, a, expected_q, expected_r)
end, {doubleTensorOnly=true})

-- Decomposing a specific (wide) rectangular matrix.
addTestVariations(tests, 'qrRectFat', function(testOpts)
  -- The matrix is chosen to be full-rank. Its leading 3x3 block is singular,
  -- though, so the expected R[3][3] below is only round-off.
  local a = testOpts.tensorFunc{
    {1,  2,  3,  4},
    {5,  6,  7,  8},
    {9, 10, 11, 13}
  }
  local expected_q = testOpts.tensorFunc{
    {-0.0966736489045663,  0.907737593658436 ,  0.4082482904638653},
    {-0.4833682445228317,  0.3157348151855452, -0.8164965809277254},
    {-0.870062840141097 , -0.2762679632873518,  0.4082482904638621}
  }
  local expected_r = testOpts.tensorFunc{
    { -1.0344080432788603e+01, -1.1794185166357092e+01,
      -1.3244289899925587e+01, -1.5564457473635180e+01},
    {  0.0000000000000000e+00,  9.4720444555662542e-01,
       1.8944088911132546e+00,  2.5653453733825331e+00},
    {  0.0000000000000000e+00,  0.0000000000000000e+00,
       1.5543122344752192e-15,  4.0824829046386757e-01}
  }
  checkQRWithExpected(testOpts, a, expected_q, expected_r)
end, {doubleTensorOnly=true})

-- Decomposing a specific (thin) rectangular matrix.
addTestVariations(tests, 'qrRectThin', function(testOpts)
  -- The matrix is chosen to be full-rank.
  local a = testOpts.tensorFunc{
    { 1,  2,  3},
    { 4,  5,  6},
    { 7,  8,  9},
    {10, 11, 13},
  }
  local expected_q = testOpts.tensorFunc{
    {-0.0776150525706334, -0.833052161400748 ,  0.3651483716701106},
    {-0.3104602102825332, -0.4512365874254053, -0.1825741858350556},
    {-0.5433053679944331, -0.0694210134500621, -0.7302967433402217},
    {-0.7761505257063329,  0.3123945605252804,  0.5477225575051663}
  }
  local expected_r = testOpts.tensorFunc{
    {-12.8840987267251261, -14.5916298832790581, -17.0753115655393231},
    {  0,                   -1.0413152017509357,  -1.770235842976589 },
    {  0,                    0,                    0.5477225575051664}
  }
  checkQRWithExpected(testOpts, a, expected_q, expected_r)
end, {doubleTensorOnly=true})

-- Decomposing a sequence of medium-sized random matrices.
addTestVariations(tests, 'randomMediumQR', function(testOpts)
  -- Every combination of power-of-two sizes from 1 to 1024 in each dimension.
  -- `^` is used instead of math.pow, which is deprecated since Lua 5.2.
  for i = 0, 10 do
    for j = 0, 10 do
      local m = 2 ^ i
      local n = 2 ^ j
      local x = torch.rand(m, n)
      checkQR(testOpts, x:typeAs(testOpts.tensorFunc()))
    end
  end
end)

-- Decomposing a sequence of small random matrices.
addTestVariations(tests, 'randomSmallQR', function(testOpts)
  for m = 1, 40 do
    for n = 1, 40 do
      checkQR(testOpts, torch.rand(m, n):typeAs(testOpts.tensorFunc()))
    end
  end
end)

-- Decomposing a sequence of small matrices that are not contiguous in memory.
addTestVariations(tests, 'randomNonContiguous', function(testOpts)
  for m = 2, 40 do
    for n = 2, 40 do
      -- Convert to the tensor type under test BEFORE transposing: converting
      -- to a different type (e.g. DoubleTensor -> FloatTensor) copies into a
      -- fresh tensor, which would silently make the input contiguous again.
      local x = torch.rand(m, n):typeAs(testOpts.tensorFunc()):t()
      tester:assert(not x:isContiguous(), "x should not be contiguous")
      checkQR(testOpts, x)
    end
  end
end)

function tests.testQM()
  -- Seed the RNG like the addTestVariations tests do, so any failure is
  -- reproducible, and restore the caller's RNG state afterwards.
  local state = torch.getRNGState()
  torch.manualSeed(1)
  checkQM({}, torch.randn(10, 10), torch.randn(10, 10))
  -- NOTE(review): presumably disabled because for a 20x10 mat1 torch.qr()
  -- returns the reduced 20x10 Q, so torch.mm(q, mat2) with a 20x20 mat2 has
  -- mismatched dimensions; confirm before re-enabling.
  -- checkQM({}, torch.randn(20, 10), torch.randn(20, 20))
  torch.setRNGState(state)
end

tester:add(tests)
tester:run()