diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-16 16:39:35 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2017-07-16 16:39:35 +0100 |
commit | 57ad67a4b4560a936f9bc7efa6a1a3778a1372fa (patch) | |
tree | 179c5874aef6cc980d6fd3bee1fff1b9ec1c27fe /contrib/torch/torch7/Tester.lua | |
parent | 8d080cdc349ee281701ae185d2053314611875da (diff) | |
download | rspamd-57ad67a4b4560a936f9bc7efa6a1a3778a1372fa.tar.gz rspamd-57ad67a4b4560a936f9bc7efa6a1a3778a1372fa.zip |
[Feature] Import torch to Rspamd...
Diffstat (limited to 'contrib/torch/torch7/Tester.lua')
-rw-r--r-- | contrib/torch/torch7/Tester.lua | 879 |
1 files changed, 879 insertions, 0 deletions
diff --git a/contrib/torch/torch7/Tester.lua b/contrib/torch/torch7/Tester.lua new file mode 100644 index 000000000..6509413c2 --- /dev/null +++ b/contrib/torch/torch7/Tester.lua @@ -0,0 +1,879 @@ + +-- Lua 5.2 compatibility +local unpack = unpack or table.unpack + +local check = {} -- helper functions, defined at the bottom of the file + +local Tester = torch.class('torch.Tester') + +function Tester:__init() + self.errors = {} + self.tests = {} + self.warnings = {} + self._warningCount = {} + self.disabledTests = {} + self._currentTestName = '' + + -- To maintain backwards compatibility (at least for a short while), + -- disable exact dimension checking of tensors when :assertTensorEq is + -- called. Thus {{1}} == {1} when this flag is true. + -- + -- Note that other methods that suppose tensor checking (such as + -- :assertGeneralEq) ignore this flag, since previously they didn't + -- exist or support tensor equality checks at all, so there is no + -- old code that uses these functions and relies on the behaviour. + -- + -- Note also that if the dimension check fails with this flag is true, then + -- will show a warning. + self._assertTensorEqIgnoresDims = true +end + +function Tester:setEarlyAbort(earlyAbort) + self.earlyAbort = earlyAbort +end + +function Tester:setRethrowErrors(rethrow) + self.rethrow = rethrow +end + +function Tester:setSummaryOnly(summaryOnly) + self.summaryOnly = summaryOnly +end + +-- Add a success to the test. +function Tester:_success() + local name = self._currentTestName + self.assertionPass[name] = self.assertionPass[name] + 1 + return true +end + +function Tester:_addDebugInfo(message) + local ss = debug.traceback('tester', 3) or '' + ss = ss:match('.-\n([^\n]+\n[^\n]+)\n[^\n]+xpcall') or '' + local name = self._currentTestName + return (name ~= '' and name .. '\n' or '') .. message .. '\n' .. ss +end + +-- Add a failure to the test. +function Tester:_failure(message) + if self.rethrow then error(message, 2) end + local name = self._currentTestName + self.assertionFail[name] = self.assertionFail[name] + 1 + self.errors[#self.errors + 1] = self:_addDebugInfo(message) + return false +end + +-- Add a warning to the test +function Tester:_warning(message) + local name = self._currentTestName + self._warningCount[name] = (self._warningCount[name] or 0) + 1 + self.warnings[#self.warnings + 1] = self:_addDebugInfo(message) +end + +-- Call this during a test run with `condition = true` to log a success, or with +-- `condition = false` to log a failure (using `message`). +function Tester:_assert_sub(condition, message) + if condition then + return self:_success() + else + return self:_failure(message) + end +end + +local function getMessage(message, ...) + assert(next{...} == nil, "Unexpected arguments passed to test function") + if message then + assert(type(message) == 'string', 'message parameter must be a string') + if message ~= '' then + return message .. '\n' + end + end + return '' +end + +--[[ Historically, some test functions have accepted both a message and a +tolerance, and some just a message (e.g., assertTableEq). Now assertTableEq +accepts both a tolerance and a message, so allow the two arguments to be passed +in either order to maintain backwards compatibility (and more generally, +for convenience). (We still document the ordering as "tolerance, message" for +clarity.) This function also sanitizes them (ensures they are non-nil, etc). +]] +local function getToleranceAndMessage(defaultTolerance, ...) + local args = {...} + local message = nil + local tolerance = nil + for _, a in ipairs(args) do + if type(a) == 'string' then + if message then + error("Unexpected string argument; already have message", a) + end + message = a .. '\n' + elseif type(a) == 'number' then + if tolerance then + error("Unexpected number argument; already have tolerance", a) + end + tolerance = a + assert(tolerance >= 0, "tolerance cannot be negative") + else + error("Unrecognized argument; should be a tolerance or message", a) + end + end + message = message or '' + tolerance = tolerance or defaultTolerance + return tolerance, message +end + +function Tester:assert(condition, ...) + local message = getMessage(...) + if type(condition) ~= 'boolean' then + self:_warning(" :assert should only be used for boolean conditions. " + .. "To check for non-nil variables, do this explicitly: " + .. "Tester:assert(var ~= nil).") + end + return self:_assert_sub(condition, + string.format('%sBOOL violation condition=%s', + message, tostring(condition))) +end + +function Tester:assertGeneralEq(got, expected, ...) + return self:_eqOrNeq(got, expected, false, ...) +end + +function Tester:eq(got, expected, ...) + return self:assertGeneralEq(got, expected, ...) +end + +function Tester:assertGeneralNe(got, unexpected, ...) + return self:_eqOrNeq(got, unexpected, true, ...) +end + +function Tester:ne(got, unexpected, ...) + return self:assertGeneralNe(got, unexpected, ...) +end + +function Tester:_eqOrNeq(got, expected, negate, ...) + local tolerance, message = getToleranceAndMessage(0, ...) + local success, subMessage = check.areEq(got, expected, tolerance, negate) + subMessage = subMessage or '' + return self:_assert_sub(success, message .. subMessage) +end + +function Tester:assertlt(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a < b, + string.format('%sLT failed: %s >= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertgt(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a > b, + string.format('%sGT failed: %s <= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertle(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a <= b, + string.format('%sLE failed: %s > %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertge(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a >= b, + string.format('%sGE failed: %s < %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertalmosteq(a, b, ...) + local tolerance, message = getToleranceAndMessage(1e-16, ...) + local diff = math.abs(a - b) + return self:_assert_sub( + diff <= tolerance, + string.format( + '%sALMOST_EQ failed: %s ~= %s with tolerance=%s', + message, tostring(a), tostring(b), tostring(tolerance))) +end + +function Tester:asserteq(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a == b, + string.format('%sEQ failed: %s ~= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertne(a, b, ...) + local message = getMessage(...) + if type(a) == type(b) and type(a) == 'table' or type(a) == 'userdata' then + self:_warning(" :assertne should only be used to compare basic lua " + .. "objects (numbers, booleans, etc). Consider using " + .. "either :assertGeneralNe or :assert(a ~= b).") + end + return self:_assert_sub(a ~= b, + string.format('%sNE failed: %s == %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertTensorEq(ta, tb, ...) + return self:_assertTensorEqOrNeq(ta, tb, false, ...) +end + +function Tester:assertTensorNe(ta, tb, ...) + return self:_assertTensorEqOrNeq(ta, tb, true, ...) +end + +function Tester:_assertTensorEqOrNeq(ta, tb, negate, ...) + assert(torch.isTensor(ta), "First argument should be a Tensor") + assert(torch.isTensor(tb), "Second argument should be a Tensor") + + local tolerance, message = getToleranceAndMessage(0, ...) + local success, subMessage = + check.areTensorsEq(ta, tb, tolerance, negate, + self._assertTensorEqIgnoresDims) + subMessage = subMessage or '' + + if self._assertTensorEqIgnoresDims and (not negate) and success + and not ta:isSameSizeAs(tb) then + self:_warning("Tensors have the same content but different dimensions. " + .. "For backwards compatibility, they are considered equal, " + .. "but this may change in the future. Consider using :eq " + .. "to check for equality instead.") + end + + return self:_assert_sub(success, message .. subMessage) +end + +function Tester:assertTableEq(ta, tb, ...) + return self:_assertTableEqOrNeq(ta, tb, false, ...) +end + +function Tester:assertTableNe(ta, tb, ...) + return self:_assertTableEqOrNeq(ta, tb, true, ...) +end + +function Tester:_assertTableEqOrNeq(ta, tb, negate, ...) + assert(type(ta) == 'table', "First argument should be a Table") + assert(type(tb) == 'table', "Second argument should be a Table") + return self:_eqOrNeq(ta, tb, negate, ...) +end + +function Tester:assertError(f, ...) + return self:assertErrorObj(f, function() return true end, ...) +end + +function Tester:assertNoError(f, ...) + local message = getMessage(...) + local status, err = pcall(f) + return self:_assert_sub(status, + string.format('%sERROR violation: err=%s', message, + tostring(err))) +end + +function Tester:assertErrorMsg(f, errmsg, ...) + return self:assertErrorObj(f, function(err) return err == errmsg end, ...) +end + +function Tester:assertErrorPattern(f, errPattern, ...) + local function errcomp(err) + return string.find(err, errPattern) ~= nil + end + return self:assertErrorObj(f, errcomp, ...) +end + +function Tester:assertErrorObj(f, errcomp, ...) + local message = getMessage(...) + local status, err = pcall(f) + return self:_assert_sub((not status) and errcomp(err), + string.format('%sERROR violation: err=%s', message, + tostring(err))) +end + +function Tester:add(f, name) + if type(f) == "table" then + assert(name == nil, "Name parameter is forbidden for a table of tests, " + .. "since its use is ambiguous") + if f.__isTestSuite then + f = f.__tests + else + self:_warning("Should use TestSuite rather than plain lua table") + end + for i, v in pairs(f) do + -- We forbid nested tests because the "expected" behaviour when a named + -- test is run in the case that the named test is in fact a table of + -- tests is not supported. Similar issue with _setUp and _tearDown + -- functions inside nested tests. + assert(type(v) ~= 'table', "Nested sets of tests are not supported") + self:add(v, i) + end + return self + end + + assert(type(f) == 'function', + "Only tables of functions and functions supported") + + if name == '_setUp' then + assert(not self._setUp, "Only one set-up function allowed") + self._setUp = f + elseif name == '_tearDown' then + assert(not self._tearDown, "Only one tear-down function allowed") + self._tearDown = f + else + name = name or 'unknown' + if self.tests[name] ~= nil then + error('Test with name ' .. name .. ' already exists!') + end + self.tests[name] = f + end + return self +end + +function Tester:disable(testNames) + if type(testNames) == 'string' then + testNames = {testNames} + end + assert(type(testNames) == 'table', "Expecting name or list for disable") + for _, name in ipairs(testNames) do + assert(self.tests[name], "Unrecognized test '" .. name .. "'") + self.disabledTests[name] = true + end + return self +end + +function Tester:run(testNames) + local tests = self:_getTests(testNames) + self.assertionPass = {} + self.assertionFail = {} + self.haveWarning = {} + self.testError = {} + for name in pairs(tests) do + self.assertionPass[name] = 0 + self.assertionFail[name] = 0 + self.testError[name] = 0 + self._warningCount[name] = 0 + end + self:_run(tests) + self:_report(tests) + + -- Throws an error on test failure/error, so that test script returns + -- with nonzero return value. + for name in pairs(tests) do + assert(self.assertionFail[name] == 0, + 'An error was found while running tests!') + assert(self.testError[name] == 0, + 'An error was found while running tests!') + end + + return 0 +end + +local function pluralize(num, str) + local stem = num .. ' ' .. str + if num == 1 then + return stem + else + return stem .. 's' + end +end + +local NCOLS = 80 +local coloured +local enable_colors, c = pcall(require, 'sys.colors') +if arg and enable_colors then -- have we been invoked from the commandline? + coloured = function(str, colour) + return colour .. str .. c.none + end +else + c = {} + coloured = function(str) + return str + end +end + +function Tester:_run(tests) + local ntests = 0 + for _ in pairs(tests) do + ntests = ntests + 1 + end + + local ntestsAsString = string.format('%u', ntests) + local cfmt = string.format('%%%uu/%u ', ntestsAsString:len(), ntestsAsString) + local cfmtlen = ntestsAsString:len() * 2 + 2 + + local function bracket(str) + return '[' .. str .. ']' + end + + io.write('Running ' .. pluralize(ntests, 'test') .. '\n') + local i = 1 + for name, fn in pairs(tests) do + self._currentTestName = name + + -- TODO: compute max length of name and cut it down to size if needed + local strinit = coloured(string.format(cfmt, i), c.cyan) + .. self._currentTestName .. ' ' + .. string.rep('.', + NCOLS - 6 - 2 - + cfmtlen - self._currentTestName:len()) + .. ' ' + io.write(strinit .. bracket(coloured('WAIT', c.cyan))) + io.flush() + + local status, message, pass, skip + if self.disabledTests[name] then + skip = true + else + skip = false + if self._setUp then + self._setUp(name) + end + if self.rethrow then + status = true + local nerr = #self.errors + message = fn() + pass = nerr == #self.errors + else + status, message, pass = self:_pcall(fn) + end + if self._tearDown then + self._tearDown(name) + end + end + + io.write('\r') + io.write(strinit) + + if skip then + io.write(bracket(coloured('SKIP', c.yellow))) + elseif not status then + self.testError[name] = 1 + io.write(bracket(coloured('ERROR', c.magenta))) + elseif not pass then + io.write(bracket(coloured('FAIL', c.red))) + else + io.write(bracket(coloured('PASS', c.green))) + if self._warningCount[name] > 0 then + io.write('\n' .. string.rep(' ', NCOLS - 10)) + io.write(bracket(coloured('+warning', c.yellow))) + end + end + io.write('\n') + io.flush() + + if self.earlyAbort and (i < ntests) and (not status or not pass) + and (not skip) then + io.write('Aborting on first error, not all tests have been executed\n') + break + end + + i = i + 1 + + collectgarbage() + end +end + +function Tester:_pcall(f) + local nerr = #self.errors + local stat, result = xpcall(f, debug.traceback) + if not stat then + self.errors[#self.errors + 1] = + self._currentTestName .. '\n Function call failed\n' .. result .. '\n' + end + return stat, result, stat and (nerr == #self.errors) +end + +function Tester:_getTests(testNames) + if testNames == nil then + return self.tests + end + if type(testNames) == 'string' then + testNames = {testNames} + end + assert(type(testNames) == 'table', + "Only accept a name or table of test names (or nil for all tests)") + + local function getMatchingNames(pattern) + local matchingNames = {} + for name in pairs(self.tests) do + if string.match(name, pattern) then + table.insert(matchingNames, name) + end + end + return matchingNames + end + + local tests = {} + for _, pattern in ipairs(testNames) do + local matchingNames = getMatchingNames(pattern) + assert(#matchingNames > 0, "Couldn't find test '" .. pattern .. "'") + for _, name in ipairs(matchingNames) do + tests[name] = self.tests[name] + end + end + return tests +end + +function Tester:_report(tests) + local ntests = 0 + local nfailures = 0 + local nerrors = 0 + local nskipped = 0 + local nwarnings = 0 + self.countasserts = 0 + for name in pairs(tests) do + ntests = ntests + 1 + self.countasserts = self.countasserts + self.assertionFail[name] + + self.assertionPass[name] + if self.assertionFail[name] > 0 then + nfailures = nfailures + 1 + end + if self.testError[name] > 0 then + nerrors = nerrors + 1 + end + if self._warningCount[name] > 0 then + nwarnings = nwarnings + 1 + end + if self.disabledTests[name] then + nskipped = nskipped + 1 + end + end + if self._warningCount[''] then + nwarnings = nwarnings + self._warningCount[''] + end + + io.write('Completed ' .. pluralize(self.countasserts, 'assert')) + io.write(' in ' .. pluralize(ntests, 'test') .. ' with ') + io.write(coloured(pluralize(nfailures, 'failure'), + nfailures == 0 and c.green or c.red)) + io.write(' and ') + io.write(coloured(pluralize(nerrors, 'error'), + nerrors == 0 and c.green or c.magenta)) + if nwarnings > 0 then + io.write(' and ') + io.write(coloured(pluralize(nwarnings, 'warning'), c.yellow)) + end + if nskipped > 0 then + io.write(' and ') + io.write(coloured(nskipped .. ' disabled', c.yellow)) + end + io.write('\n') + + -- Prints off a message separated by ----- + local haveSection = false + local function addSection(text) + local function printDashes() + io.write(string.rep('-', NCOLS) .. '\n') + end + if not haveSection then + printDashes() + haveSection = true + end + io.write(text .. '\n') + printDashes() + end + + if not self.summaryOnly then + for _, v in ipairs(self.errors) do + addSection(v) + end + for _, v in ipairs(self.warnings) do + addSection(v) + end + end +end + + +--[[ Tests for tensor equality between two tensors of matching sizes and types. + +Tests whether the maximum element-wise difference between `ta` and `tb` is less +than or equal to `tolerance`. + +Arguments: +* `ta` (tensor) +* `tb` (tensor) +* `tolerance` (number) maximum elementwise difference between `ta` and `tb`. +* `negate` (boolean) if true, we invert success and failure. +* `storage` (boolean) if true, we print an error message referring to Storages + rather than Tensors. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areSameFormatTensorsEq(ta, tb, tolerance, negate, storage) + local function ensureHasAbs(t) + -- Byte, Char and Short Tensors don't have abs + return t.abs and t or t:double() + end + + ta = ensureHasAbs(ta) + tb = ensureHasAbs(tb) + + local diff = ta:clone():add(-1, tb):abs() + local err = diff:max() + local success = err <= tolerance + if negate then + success = not success + end + + local errMessage + if not success then + local prefix = storage and 'Storage' or 'Tensor' + local violation = negate and 'NE(==)' or 'EQ(==)' + errMessage = string.format('%s%s violation: max diff=%s, tolerance=%s', + prefix, + violation, + tostring(err), + tostring(tolerance)) + end + + return success, errMessage +end + +--[[ Tests for tensor equality. + +Tests whether the maximum element-wise difference between `ta` and `tb` is less +than or equal to `tolerance`. + +Arguments: +* `ta` (tensor) +* `tb` (tensor) +* `tolerance` (number) maximum elementwise difference between `ta` and `tb`. +* `negate` (boolean) if negate is true, we invert success and failure. +* `ignoreTensorDims` (boolean, default false) if true, then tensors of the same + size but different dimensions can still be considered equal, e.g., + {{1}} == {1}. For backwards compatibility. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areTensorsEq(ta, tb, tolerance, negate, ignoreTensorDims) + ignoreTensorDims = ignoreTensorDims or false + + if not ignoreTensorDims and ta:dim() ~= tb:dim() then + return negate, 'The tensors have different dimensions' + end + + if ta:type() ~= tb:type() then + return negate, 'The tensors have different types' + end + + -- If we are comparing two empty tensors, return true. + -- This is needed because some functions below cannot be applied to tensors + -- of dimension 0. + if ta:dim() == 0 and tb:dim() == 0 then + return not negate, 'Both tensors are empty' + end + + local sameSize + if ignoreTensorDims then + sameSize = ta:nElement() == tb:nElement() + else + sameSize = ta:isSameSizeAs(tb) + end + if not sameSize then + return negate, 'The tensors have different sizes' + end + + return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, false) +end + +local typesMatching = { + ['torch.ByteStorage'] = torch.ByteTensor, + ['torch.CharStorage'] = torch.CharTensor, + ['torch.ShortStorage'] = torch.ShortTensor, + ['torch.IntStorage'] = torch.IntTensor, + ['torch.LongStorage'] = torch.LongTensor, + ['torch.FloatStorage'] = torch.FloatTensor, + ['torch.DoubleStorage'] = torch.DoubleTensor, + ['torch.HalfStorage'] = torch.HalfTensor, +} + +--[[ Tests for storage equality. + +Tests whether the maximum element-wise difference between `sa` and `sb` is less +than or equal to `tolerance`. + +Arguments: +* `sa` (storage) +* `sb` (storage) +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areStoragesEq(sa, sb, tolerance, negate) + if sa:size() ~= sb:size() then + return negate, 'The storages have different sizes' + end + + local typeOfsa = torch.type(sa) + local typeOfsb = torch.type(sb) + + if typeOfsa ~= typeOfsb then + return negate, 'The storages have different types' + end + + local ta = typesMatching[typeOfsa](sa) + local tb = typesMatching[typeOfsb](sb) + + return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, true) +end + +--[[ Tests for general (deep) equality. + +The types of `got` and `expected` must match. +Tables are compared recursively. Keys and types of the associated values must +match, recursively. Numbers are compared with the given tolerance. +Torch tensors and storages are compared with the given tolerance on their +elementwise difference. Other types are compared for strict equality with the +regular Lua == operator. + +Arguments: +* `got` +* `expected` +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areEq(got, expected, tolerance, negate) + local errMessage + if type(got) ~= type(expected) then + if not negate then + errMessage = 'EQ failed: values have different types (first: ' + .. type(got) .. ', second: ' .. type(expected) .. ')' + end + return negate, errMessage + elseif type(got) == 'number' then + local diff = math.abs(got - expected) + local ok = (diff <= tolerance) + if negate then + ok = not ok + end + if not ok then + if negate then + errMessage = string.format("NE failed: %s == %s", + tostring(got), tostring(expected)) + else + errMessage = string.format("EQ failed: %s ~= %s", + tostring(got), tostring(expected)) + end + if tolerance > 0 then + errMessage = errMessage .. " with tolerance=" .. tostring(tolerance) + end + end + return ok, errMessage + elseif type(expected) == "table" then + return check.areTablesEq(got, expected, tolerance, negate) + elseif torch.isTensor(got) then + return check.areTensorsEq(got, expected, tolerance, negate) + elseif torch.isStorage(got) then + return check.areStoragesEq(got, expected, tolerance, negate) + else + -- Below: we have the same type which is either userdata or a lua type + -- which is not a number. + local ok = (got == expected) + if negate then + ok = not ok + end + if not ok then + if negate then + errMessage = string.format("NE failed: %s (%s) == %s (%s)", + tostring(got), type(got), + tostring(expected), type(expected)) + else + errMessage = string.format("EQ failed: %s (%s) ~= %s (%s)", + tostring(got), type(got), + tostring(expected), type(expected)) + end + end + return ok, errMessage + end +end + +--[[ Tests for (deep) table equality. + +Tables are compared recursively. Keys and types of the associated values must +match, recursively. Numbers are compared with the given tolerance. +Torch tensors and storages are compared with the given tolerance on their +elementwise difference. Other types are compared for strict equality with the +regular Lua == operator. + +Arguments: +* `t1` (table) +* `t2` (table) +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areTablesEq(t1, t2, tolerance, negate) + -- Implementation detail: Instead of doing a depth-first table comparison + -- check (for example, using recursion), let's do a breadth-first search + -- using a queue. Why? Because if we have two tables that are quite deep + -- (e.g., a gModule from nngraph), then if they are different then it's + -- more useful to the user to show how they differ at as-shallow-a-depth + -- as possible. + local queue = {} + queue._head = 1 + queue._tail = 1 + function queue.isEmpty() + return queue._tail == queue._head + end + function queue.pop() + queue._head = queue._head + 1 + return queue[queue._head - 1] + end + function queue.push(value) + queue[queue._tail] = value + queue._tail = queue._tail + 1 + end + + queue.push({t1, t2}) + while not queue.isEmpty() do + local location + t1, t2, location = unpack(queue.pop()) + + local function toSublocation(key) + local keyAsString = tostring(key) + return (location and location .. "." .. keyAsString) or keyAsString + end + + for key, value1 in pairs(t1) do + local sublocation = toSublocation(key) + if t2[key] == nil then + return negate, string.format( + "Entry %s missing in second table (is %s in first)", + sublocation, tostring(value1)) + end + local value2 = t2[key] + if type(value1) == 'table' and type(value2) == 'table' then + queue.push({value1, value2, sublocation}) + else + local ok, message = check.areEq(value1, value2, tolerance, false) + if not ok then + message = 'At table location ' .. sublocation .. ': ' .. message + return negate, message + end + end + end + + for key, value2 in pairs(t2) do + local sublocation = toSublocation(key) + if t1[key] == nil then + return negate, string.format( + "Entry %s missing in first table (is %s in second)", + sublocation, tostring(value2)) + end + end + end + return not negate, 'The tables are equal' +end |