aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/torch7/Tester.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-07-16 16:39:35 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-07-16 16:39:35 +0100
commit57ad67a4b4560a936f9bc7efa6a1a3778a1372fa (patch)
tree179c5874aef6cc980d6fd3bee1fff1b9ec1c27fe /contrib/torch/torch7/Tester.lua
parent8d080cdc349ee281701ae185d2053314611875da (diff)
downloadrspamd-57ad67a4b4560a936f9bc7efa6a1a3778a1372fa.tar.gz
rspamd-57ad67a4b4560a936f9bc7efa6a1a3778a1372fa.zip
[Feature] Import torch to Rspamd...
Diffstat (limited to 'contrib/torch/torch7/Tester.lua')
-rw-r--r--contrib/torch/torch7/Tester.lua879
1 files changed, 879 insertions, 0 deletions
diff --git a/contrib/torch/torch7/Tester.lua b/contrib/torch/torch7/Tester.lua
new file mode 100644
index 000000000..6509413c2
--- /dev/null
+++ b/contrib/torch/torch7/Tester.lua
@@ -0,0 +1,879 @@
+
+-- Lua 5.2 compatibility
+local unpack = unpack or table.unpack
+
+local check = {} -- helper functions, defined at the bottom of the file
+
+local Tester = torch.class('torch.Tester')
+
+function Tester:__init()
+ self.errors = {}
+ self.tests = {}
+ self.warnings = {}
+ self._warningCount = {}
+ self.disabledTests = {}
+ self._currentTestName = ''
+
+ -- To maintain backwards compatibility (at least for a short while),
+ -- disable exact dimension checking of tensors when :assertTensorEq is
+ -- called. Thus {{1}} == {1} when this flag is true.
+ --
+ -- Note that other methods that suppose tensor checking (such as
+ -- :assertGeneralEq) ignore this flag, since previously they didn't
+ -- exist or support tensor equality checks at all, so there is no
+ -- old code that uses these functions and relies on the behaviour.
+ --
+ -- Note also that if the dimension check fails with this flag is true, then
+ -- will show a warning.
+ self._assertTensorEqIgnoresDims = true
+end
+
+function Tester:setEarlyAbort(earlyAbort)
+ self.earlyAbort = earlyAbort
+end
+
+function Tester:setRethrowErrors(rethrow)
+ self.rethrow = rethrow
+end
+
+function Tester:setSummaryOnly(summaryOnly)
+ self.summaryOnly = summaryOnly
+end
+
+-- Add a success to the test.
+function Tester:_success()
+ local name = self._currentTestName
+ self.assertionPass[name] = self.assertionPass[name] + 1
+ return true
+end
+
+function Tester:_addDebugInfo(message)
+ local ss = debug.traceback('tester', 3) or ''
+ ss = ss:match('.-\n([^\n]+\n[^\n]+)\n[^\n]+xpcall') or ''
+ local name = self._currentTestName
+ return (name ~= '' and name .. '\n' or '') .. message .. '\n' .. ss
+end
+
+-- Add a failure to the test.
+function Tester:_failure(message)
+ if self.rethrow then error(message, 2) end
+ local name = self._currentTestName
+ self.assertionFail[name] = self.assertionFail[name] + 1
+ self.errors[#self.errors + 1] = self:_addDebugInfo(message)
+ return false
+end
+
+-- Add a warning to the test
+function Tester:_warning(message)
+ local name = self._currentTestName
+ self._warningCount[name] = (self._warningCount[name] or 0) + 1
+ self.warnings[#self.warnings + 1] = self:_addDebugInfo(message)
+end
+
+-- Call this during a test run with `condition = true` to log a success, or with
+-- `condition = false` to log a failure (using `message`).
+function Tester:_assert_sub(condition, message)
+ if condition then
+ return self:_success()
+ else
+ return self:_failure(message)
+ end
+end
+
+local function getMessage(message, ...)
+ assert(next{...} == nil, "Unexpected arguments passed to test function")
+ if message then
+ assert(type(message) == 'string', 'message parameter must be a string')
+ if message ~= '' then
+ return message .. '\n'
+ end
+ end
+ return ''
+end
+
+--[[ Historically, some test functions have accepted both a message and a
+tolerance, and some just a message (e.g., assertTableEq). Now assertTableEq
+accepts both a tolerance and a message, so allow the two arguments to be passed
+in either order to maintain backwards compatibility (and more generally,
+for convenience). (We still document the ordering as "tolerance, message" for
+clarity.) This function also sanitizes them (ensures they are non-nil, etc).
+]]
+local function getToleranceAndMessage(defaultTolerance, ...)
+ local args = {...}
+ local message = nil
+ local tolerance = nil
+ for _, a in ipairs(args) do
+ if type(a) == 'string' then
+ if message then
+ error("Unexpected string argument; already have message", a)
+ end
+ message = a .. '\n'
+ elseif type(a) == 'number' then
+ if tolerance then
+ error("Unexpected number argument; already have tolerance", a)
+ end
+ tolerance = a
+ assert(tolerance >= 0, "tolerance cannot be negative")
+ else
+ error("Unrecognized argument; should be a tolerance or message", a)
+ end
+ end
+ message = message or ''
+ tolerance = tolerance or defaultTolerance
+ return tolerance, message
+end
+
+function Tester:assert(condition, ...)
+ local message = getMessage(...)
+ if type(condition) ~= 'boolean' then
+ self:_warning(" :assert should only be used for boolean conditions. "
+ .. "To check for non-nil variables, do this explicitly: "
+ .. "Tester:assert(var ~= nil).")
+ end
+ return self:_assert_sub(condition,
+ string.format('%sBOOL violation condition=%s',
+ message, tostring(condition)))
+end
+
+function Tester:assertGeneralEq(got, expected, ...)
+ return self:_eqOrNeq(got, expected, false, ...)
+end
+
+function Tester:eq(got, expected, ...)
+ return self:assertGeneralEq(got, expected, ...)
+end
+
+function Tester:assertGeneralNe(got, unexpected, ...)
+ return self:_eqOrNeq(got, unexpected, true, ...)
+end
+
+function Tester:ne(got, unexpected, ...)
+ return self:assertGeneralNe(got, unexpected, ...)
+end
+
+function Tester:_eqOrNeq(got, expected, negate, ...)
+ local tolerance, message = getToleranceAndMessage(0, ...)
+ local success, subMessage = check.areEq(got, expected, tolerance, negate)
+ subMessage = subMessage or ''
+ return self:_assert_sub(success, message .. subMessage)
+end
+
+function Tester:assertlt(a, b, ...)
+ local message = getMessage(...)
+ return self:_assert_sub(a < b,
+ string.format('%sLT failed: %s >= %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertgt(a, b, ...)
+ local message = getMessage(...)
+ return self:_assert_sub(a > b,
+ string.format('%sGT failed: %s <= %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertle(a, b, ...)
+ local message = getMessage(...)
+ return self:_assert_sub(a <= b,
+ string.format('%sLE failed: %s > %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertge(a, b, ...)
+ local message = getMessage(...)
+ return self:_assert_sub(a >= b,
+ string.format('%sGE failed: %s < %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertalmosteq(a, b, ...)
+ local tolerance, message = getToleranceAndMessage(1e-16, ...)
+ local diff = math.abs(a - b)
+ return self:_assert_sub(
+ diff <= tolerance,
+ string.format(
+ '%sALMOST_EQ failed: %s ~= %s with tolerance=%s',
+ message, tostring(a), tostring(b), tostring(tolerance)))
+end
+
+function Tester:asserteq(a, b, ...)
+ local message = getMessage(...)
+ return self:_assert_sub(a == b,
+ string.format('%sEQ failed: %s ~= %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertne(a, b, ...)
+ local message = getMessage(...)
+ if type(a) == type(b) and type(a) == 'table' or type(a) == 'userdata' then
+ self:_warning(" :assertne should only be used to compare basic lua "
+ .. "objects (numbers, booleans, etc). Consider using "
+ .. "either :assertGeneralNe or :assert(a ~= b).")
+ end
+ return self:_assert_sub(a ~= b,
+ string.format('%sNE failed: %s == %s',
+ message, tostring(a), tostring(b)))
+end
+
+function Tester:assertTensorEq(ta, tb, ...)
+ return self:_assertTensorEqOrNeq(ta, tb, false, ...)
+end
+
+function Tester:assertTensorNe(ta, tb, ...)
+ return self:_assertTensorEqOrNeq(ta, tb, true, ...)
+end
+
+function Tester:_assertTensorEqOrNeq(ta, tb, negate, ...)
+ assert(torch.isTensor(ta), "First argument should be a Tensor")
+ assert(torch.isTensor(tb), "Second argument should be a Tensor")
+
+ local tolerance, message = getToleranceAndMessage(0, ...)
+ local success, subMessage =
+ check.areTensorsEq(ta, tb, tolerance, negate,
+ self._assertTensorEqIgnoresDims)
+ subMessage = subMessage or ''
+
+ if self._assertTensorEqIgnoresDims and (not negate) and success
+ and not ta:isSameSizeAs(tb) then
+ self:_warning("Tensors have the same content but different dimensions. "
+ .. "For backwards compatibility, they are considered equal, "
+ .. "but this may change in the future. Consider using :eq "
+ .. "to check for equality instead.")
+ end
+
+ return self:_assert_sub(success, message .. subMessage)
+end
+
+function Tester:assertTableEq(ta, tb, ...)
+ return self:_assertTableEqOrNeq(ta, tb, false, ...)
+end
+
+function Tester:assertTableNe(ta, tb, ...)
+ return self:_assertTableEqOrNeq(ta, tb, true, ...)
+end
+
+function Tester:_assertTableEqOrNeq(ta, tb, negate, ...)
+ assert(type(ta) == 'table', "First argument should be a Table")
+ assert(type(tb) == 'table', "Second argument should be a Table")
+ return self:_eqOrNeq(ta, tb, negate, ...)
+end
+
+function Tester:assertError(f, ...)
+ return self:assertErrorObj(f, function() return true end, ...)
+end
+
+function Tester:assertNoError(f, ...)
+ local message = getMessage(...)
+ local status, err = pcall(f)
+ return self:_assert_sub(status,
+ string.format('%sERROR violation: err=%s', message,
+ tostring(err)))
+end
+
+function Tester:assertErrorMsg(f, errmsg, ...)
+ return self:assertErrorObj(f, function(err) return err == errmsg end, ...)
+end
+
+function Tester:assertErrorPattern(f, errPattern, ...)
+ local function errcomp(err)
+ return string.find(err, errPattern) ~= nil
+ end
+ return self:assertErrorObj(f, errcomp, ...)
+end
+
+function Tester:assertErrorObj(f, errcomp, ...)
+ local message = getMessage(...)
+ local status, err = pcall(f)
+ return self:_assert_sub((not status) and errcomp(err),
+ string.format('%sERROR violation: err=%s', message,
+ tostring(err)))
+end
+
+function Tester:add(f, name)
+ if type(f) == "table" then
+ assert(name == nil, "Name parameter is forbidden for a table of tests, "
+ .. "since its use is ambiguous")
+ if f.__isTestSuite then
+ f = f.__tests
+ else
+ self:_warning("Should use TestSuite rather than plain lua table")
+ end
+ for i, v in pairs(f) do
+ -- We forbid nested tests because the "expected" behaviour when a named
+ -- test is run in the case that the named test is in fact a table of
+ -- tests is not supported. Similar issue with _setUp and _tearDown
+ -- functions inside nested tests.
+ assert(type(v) ~= 'table', "Nested sets of tests are not supported")
+ self:add(v, i)
+ end
+ return self
+ end
+
+ assert(type(f) == 'function',
+ "Only tables of functions and functions supported")
+
+ if name == '_setUp' then
+ assert(not self._setUp, "Only one set-up function allowed")
+ self._setUp = f
+ elseif name == '_tearDown' then
+ assert(not self._tearDown, "Only one tear-down function allowed")
+ self._tearDown = f
+ else
+ name = name or 'unknown'
+ if self.tests[name] ~= nil then
+ error('Test with name ' .. name .. ' already exists!')
+ end
+ self.tests[name] = f
+ end
+ return self
+end
+
+function Tester:disable(testNames)
+ if type(testNames) == 'string' then
+ testNames = {testNames}
+ end
+ assert(type(testNames) == 'table', "Expecting name or list for disable")
+ for _, name in ipairs(testNames) do
+ assert(self.tests[name], "Unrecognized test '" .. name .. "'")
+ self.disabledTests[name] = true
+ end
+ return self
+end
+
+function Tester:run(testNames)
+ local tests = self:_getTests(testNames)
+ self.assertionPass = {}
+ self.assertionFail = {}
+ self.haveWarning = {}
+ self.testError = {}
+ for name in pairs(tests) do
+ self.assertionPass[name] = 0
+ self.assertionFail[name] = 0
+ self.testError[name] = 0
+ self._warningCount[name] = 0
+ end
+ self:_run(tests)
+ self:_report(tests)
+
+ -- Throws an error on test failure/error, so that test script returns
+ -- with nonzero return value.
+ for name in pairs(tests) do
+ assert(self.assertionFail[name] == 0,
+ 'An error was found while running tests!')
+ assert(self.testError[name] == 0,
+ 'An error was found while running tests!')
+ end
+
+ return 0
+end
+
+local function pluralize(num, str)
+ local stem = num .. ' ' .. str
+ if num == 1 then
+ return stem
+ else
+ return stem .. 's'
+ end
+end
+
+local NCOLS = 80
+local coloured
+local enable_colors, c = pcall(require, 'sys.colors')
+if arg and enable_colors then -- have we been invoked from the commandline?
+ coloured = function(str, colour)
+ return colour .. str .. c.none
+ end
+else
+ c = {}
+ coloured = function(str)
+ return str
+ end
+end
+
+function Tester:_run(tests)
+ local ntests = 0
+ for _ in pairs(tests) do
+ ntests = ntests + 1
+ end
+
+ local ntestsAsString = string.format('%u', ntests)
+ local cfmt = string.format('%%%uu/%u ', ntestsAsString:len(), ntestsAsString)
+ local cfmtlen = ntestsAsString:len() * 2 + 2
+
+ local function bracket(str)
+ return '[' .. str .. ']'
+ end
+
+ io.write('Running ' .. pluralize(ntests, 'test') .. '\n')
+ local i = 1
+ for name, fn in pairs(tests) do
+ self._currentTestName = name
+
+ -- TODO: compute max length of name and cut it down to size if needed
+ local strinit = coloured(string.format(cfmt, i), c.cyan)
+ .. self._currentTestName .. ' '
+ .. string.rep('.',
+ NCOLS - 6 - 2 -
+ cfmtlen - self._currentTestName:len())
+ .. ' '
+ io.write(strinit .. bracket(coloured('WAIT', c.cyan)))
+ io.flush()
+
+ local status, message, pass, skip
+ if self.disabledTests[name] then
+ skip = true
+ else
+ skip = false
+ if self._setUp then
+ self._setUp(name)
+ end
+ if self.rethrow then
+ status = true
+ local nerr = #self.errors
+ message = fn()
+ pass = nerr == #self.errors
+ else
+ status, message, pass = self:_pcall(fn)
+ end
+ if self._tearDown then
+ self._tearDown(name)
+ end
+ end
+
+ io.write('\r')
+ io.write(strinit)
+
+ if skip then
+ io.write(bracket(coloured('SKIP', c.yellow)))
+ elseif not status then
+ self.testError[name] = 1
+ io.write(bracket(coloured('ERROR', c.magenta)))
+ elseif not pass then
+ io.write(bracket(coloured('FAIL', c.red)))
+ else
+ io.write(bracket(coloured('PASS', c.green)))
+ if self._warningCount[name] > 0 then
+ io.write('\n' .. string.rep(' ', NCOLS - 10))
+ io.write(bracket(coloured('+warning', c.yellow)))
+ end
+ end
+ io.write('\n')
+ io.flush()
+
+ if self.earlyAbort and (i < ntests) and (not status or not pass)
+ and (not skip) then
+ io.write('Aborting on first error, not all tests have been executed\n')
+ break
+ end
+
+ i = i + 1
+
+ collectgarbage()
+ end
+end
+
+function Tester:_pcall(f)
+ local nerr = #self.errors
+ local stat, result = xpcall(f, debug.traceback)
+ if not stat then
+ self.errors[#self.errors + 1] =
+ self._currentTestName .. '\n Function call failed\n' .. result .. '\n'
+ end
+ return stat, result, stat and (nerr == #self.errors)
+end
+
+function Tester:_getTests(testNames)
+ if testNames == nil then
+ return self.tests
+ end
+ if type(testNames) == 'string' then
+ testNames = {testNames}
+ end
+ assert(type(testNames) == 'table',
+ "Only accept a name or table of test names (or nil for all tests)")
+
+ local function getMatchingNames(pattern)
+ local matchingNames = {}
+ for name in pairs(self.tests) do
+ if string.match(name, pattern) then
+ table.insert(matchingNames, name)
+ end
+ end
+ return matchingNames
+ end
+
+ local tests = {}
+ for _, pattern in ipairs(testNames) do
+ local matchingNames = getMatchingNames(pattern)
+ assert(#matchingNames > 0, "Couldn't find test '" .. pattern .. "'")
+ for _, name in ipairs(matchingNames) do
+ tests[name] = self.tests[name]
+ end
+ end
+ return tests
+end
+
+function Tester:_report(tests)
+ local ntests = 0
+ local nfailures = 0
+ local nerrors = 0
+ local nskipped = 0
+ local nwarnings = 0
+ self.countasserts = 0
+ for name in pairs(tests) do
+ ntests = ntests + 1
+ self.countasserts = self.countasserts + self.assertionFail[name]
+ + self.assertionPass[name]
+ if self.assertionFail[name] > 0 then
+ nfailures = nfailures + 1
+ end
+ if self.testError[name] > 0 then
+ nerrors = nerrors + 1
+ end
+ if self._warningCount[name] > 0 then
+ nwarnings = nwarnings + 1
+ end
+ if self.disabledTests[name] then
+ nskipped = nskipped + 1
+ end
+ end
+ if self._warningCount[''] then
+ nwarnings = nwarnings + self._warningCount['']
+ end
+
+ io.write('Completed ' .. pluralize(self.countasserts, 'assert'))
+ io.write(' in ' .. pluralize(ntests, 'test') .. ' with ')
+ io.write(coloured(pluralize(nfailures, 'failure'),
+ nfailures == 0 and c.green or c.red))
+ io.write(' and ')
+ io.write(coloured(pluralize(nerrors, 'error'),
+ nerrors == 0 and c.green or c.magenta))
+ if nwarnings > 0 then
+ io.write(' and ')
+ io.write(coloured(pluralize(nwarnings, 'warning'), c.yellow))
+ end
+ if nskipped > 0 then
+ io.write(' and ')
+ io.write(coloured(nskipped .. ' disabled', c.yellow))
+ end
+ io.write('\n')
+
+ -- Prints off a message separated by -----
+ local haveSection = false
+ local function addSection(text)
+ local function printDashes()
+ io.write(string.rep('-', NCOLS) .. '\n')
+ end
+ if not haveSection then
+ printDashes()
+ haveSection = true
+ end
+ io.write(text .. '\n')
+ printDashes()
+ end
+
+ if not self.summaryOnly then
+ for _, v in ipairs(self.errors) do
+ addSection(v)
+ end
+ for _, v in ipairs(self.warnings) do
+ addSection(v)
+ end
+ end
+end
+
+
+--[[ Tests for tensor equality between two tensors of matching sizes and types.
+
+Tests whether the maximum element-wise difference between `ta` and `tb` is less
+than or equal to `tolerance`.
+
+Arguments:
+* `ta` (tensor)
+* `tb` (tensor)
+* `tolerance` (number) maximum elementwise difference between `ta` and `tb`.
+* `negate` (boolean) if true, we invert success and failure.
+* `storage` (boolean) if true, we print an error message referring to Storages
+ rather than Tensors.
+
+Returns:
+1. success, boolean that indicates success
+2. failure_message, string or nil
+]]
+function check.areSameFormatTensorsEq(ta, tb, tolerance, negate, storage)
+ local function ensureHasAbs(t)
+ -- Byte, Char and Short Tensors don't have abs
+ return t.abs and t or t:double()
+ end
+
+ ta = ensureHasAbs(ta)
+ tb = ensureHasAbs(tb)
+
+ local diff = ta:clone():add(-1, tb):abs()
+ local err = diff:max()
+ local success = err <= tolerance
+ if negate then
+ success = not success
+ end
+
+ local errMessage
+ if not success then
+ local prefix = storage and 'Storage' or 'Tensor'
+ local violation = negate and 'NE(==)' or 'EQ(==)'
+ errMessage = string.format('%s%s violation: max diff=%s, tolerance=%s',
+ prefix,
+ violation,
+ tostring(err),
+ tostring(tolerance))
+ end
+
+ return success, errMessage
+end
+
+--[[ Tests for tensor equality.
+
+Tests whether the maximum element-wise difference between `ta` and `tb` is less
+than or equal to `tolerance`.
+
+Arguments:
+* `ta` (tensor)
+* `tb` (tensor)
+* `tolerance` (number) maximum elementwise difference between `ta` and `tb`.
+* `negate` (boolean) if negate is true, we invert success and failure.
+* `ignoreTensorDims` (boolean, default false) if true, then tensors of the same
+ size but different dimensions can still be considered equal, e.g.,
+ {{1}} == {1}. For backwards compatibility.
+
+Returns:
+1. success, boolean that indicates success
+2. failure_message, string or nil
+]]
+function check.areTensorsEq(ta, tb, tolerance, negate, ignoreTensorDims)
+ ignoreTensorDims = ignoreTensorDims or false
+
+ if not ignoreTensorDims and ta:dim() ~= tb:dim() then
+ return negate, 'The tensors have different dimensions'
+ end
+
+ if ta:type() ~= tb:type() then
+ return negate, 'The tensors have different types'
+ end
+
+ -- If we are comparing two empty tensors, return true.
+ -- This is needed because some functions below cannot be applied to tensors
+ -- of dimension 0.
+ if ta:dim() == 0 and tb:dim() == 0 then
+ return not negate, 'Both tensors are empty'
+ end
+
+ local sameSize
+ if ignoreTensorDims then
+ sameSize = ta:nElement() == tb:nElement()
+ else
+ sameSize = ta:isSameSizeAs(tb)
+ end
+ if not sameSize then
+ return negate, 'The tensors have different sizes'
+ end
+
+ return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, false)
+end
+
+local typesMatching = {
+ ['torch.ByteStorage'] = torch.ByteTensor,
+ ['torch.CharStorage'] = torch.CharTensor,
+ ['torch.ShortStorage'] = torch.ShortTensor,
+ ['torch.IntStorage'] = torch.IntTensor,
+ ['torch.LongStorage'] = torch.LongTensor,
+ ['torch.FloatStorage'] = torch.FloatTensor,
+ ['torch.DoubleStorage'] = torch.DoubleTensor,
+ ['torch.HalfStorage'] = torch.HalfTensor,
+}
+
+--[[ Tests for storage equality.
+
+Tests whether the maximum element-wise difference between `sa` and `sb` is less
+than or equal to `tolerance`.
+
+Arguments:
+* `sa` (storage)
+* `sb` (storage)
+* `tolerance` (number) maximum elementwise difference between `a` and `b`.
+* `negate` (boolean) if negate is true, we invert success and failure.
+
+Returns:
+1. success, boolean that indicates success
+2. failure_message, string or nil
+]]
+function check.areStoragesEq(sa, sb, tolerance, negate)
+ if sa:size() ~= sb:size() then
+ return negate, 'The storages have different sizes'
+ end
+
+ local typeOfsa = torch.type(sa)
+ local typeOfsb = torch.type(sb)
+
+ if typeOfsa ~= typeOfsb then
+ return negate, 'The storages have different types'
+ end
+
+ local ta = typesMatching[typeOfsa](sa)
+ local tb = typesMatching[typeOfsb](sb)
+
+ return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, true)
+end
+
+--[[ Tests for general (deep) equality.
+
+The types of `got` and `expected` must match.
+Tables are compared recursively. Keys and types of the associated values must
+match, recursively. Numbers are compared with the given tolerance.
+Torch tensors and storages are compared with the given tolerance on their
+elementwise difference. Other types are compared for strict equality with the
+regular Lua == operator.
+
+Arguments:
+* `got`
+* `expected`
+* `tolerance` (number) maximum elementwise difference between `a` and `b`.
+* `negate` (boolean) if negate is true, we invert success and failure.
+
+Returns:
+1. success, boolean that indicates success
+2. failure_message, string or nil
+]]
+function check.areEq(got, expected, tolerance, negate)
+ local errMessage
+ if type(got) ~= type(expected) then
+ if not negate then
+ errMessage = 'EQ failed: values have different types (first: '
+ .. type(got) .. ', second: ' .. type(expected) .. ')'
+ end
+ return negate, errMessage
+ elseif type(got) == 'number' then
+ local diff = math.abs(got - expected)
+ local ok = (diff <= tolerance)
+ if negate then
+ ok = not ok
+ end
+ if not ok then
+ if negate then
+ errMessage = string.format("NE failed: %s == %s",
+ tostring(got), tostring(expected))
+ else
+ errMessage = string.format("EQ failed: %s ~= %s",
+ tostring(got), tostring(expected))
+ end
+ if tolerance > 0 then
+ errMessage = errMessage .. " with tolerance=" .. tostring(tolerance)
+ end
+ end
+ return ok, errMessage
+ elseif type(expected) == "table" then
+ return check.areTablesEq(got, expected, tolerance, negate)
+ elseif torch.isTensor(got) then
+ return check.areTensorsEq(got, expected, tolerance, negate)
+ elseif torch.isStorage(got) then
+ return check.areStoragesEq(got, expected, tolerance, negate)
+ else
+ -- Below: we have the same type which is either userdata or a lua type
+ -- which is not a number.
+ local ok = (got == expected)
+ if negate then
+ ok = not ok
+ end
+ if not ok then
+ if negate then
+ errMessage = string.format("NE failed: %s (%s) == %s (%s)",
+ tostring(got), type(got),
+ tostring(expected), type(expected))
+ else
+ errMessage = string.format("EQ failed: %s (%s) ~= %s (%s)",
+ tostring(got), type(got),
+ tostring(expected), type(expected))
+ end
+ end
+ return ok, errMessage
+ end
+end
+
+--[[ Tests for (deep) table equality.
+
+Tables are compared recursively. Keys and types of the associated values must
+match, recursively. Numbers are compared with the given tolerance.
+Torch tensors and storages are compared with the given tolerance on their
+elementwise difference. Other types are compared for strict equality with the
+regular Lua == operator.
+
+Arguments:
+* `t1` (table)
+* `t2` (table)
+* `tolerance` (number) maximum elementwise difference between `a` and `b`.
+* `negate` (boolean) if negate is true, we invert success and failure.
+
+Returns:
+1. success, boolean that indicates success
+2. failure_message, string or nil
+]]
+function check.areTablesEq(t1, t2, tolerance, negate)
+ -- Implementation detail: Instead of doing a depth-first table comparison
+ -- check (for example, using recursion), let's do a breadth-first search
+ -- using a queue. Why? Because if we have two tables that are quite deep
+ -- (e.g., a gModule from nngraph), then if they are different then it's
+ -- more useful to the user to show how they differ at as-shallow-a-depth
+ -- as possible.
+ local queue = {}
+ queue._head = 1
+ queue._tail = 1
+ function queue.isEmpty()
+ return queue._tail == queue._head
+ end
+ function queue.pop()
+ queue._head = queue._head + 1
+ return queue[queue._head - 1]
+ end
+ function queue.push(value)
+ queue[queue._tail] = value
+ queue._tail = queue._tail + 1
+ end
+
+ queue.push({t1, t2})
+ while not queue.isEmpty() do
+ local location
+ t1, t2, location = unpack(queue.pop())
+
+ local function toSublocation(key)
+ local keyAsString = tostring(key)
+ return (location and location .. "." .. keyAsString) or keyAsString
+ end
+
+ for key, value1 in pairs(t1) do
+ local sublocation = toSublocation(key)
+ if t2[key] == nil then
+ return negate, string.format(
+ "Entry %s missing in second table (is %s in first)",
+ sublocation, tostring(value1))
+ end
+ local value2 = t2[key]
+ if type(value1) == 'table' and type(value2) == 'table' then
+ queue.push({value1, value2, sublocation})
+ else
+ local ok, message = check.areEq(value1, value2, tolerance, false)
+ if not ok then
+ message = 'At table location ' .. sublocation .. ': ' .. message
+ return negate, message
+ end
+ end
+ end
+
+ for key, value2 in pairs(t2) do
+ local sublocation = toSublocation(key)
+ if t1[key] == nil then
+ return negate, string.format(
+ "Entry %s missing in first table (is %s in second)",
+ sublocation, tostring(value2))
+ end
+ end
+ end
+ return not negate, 'The tables are equal'
+end