aboutsummaryrefslogtreecommitdiffstats
path: root/test/lua/rspamd_assertions.lua
blob: 8e483431de09e056e8cf635eeb782366f7b30026 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
local logger = require "rspamd_logger"
local telescope = require "telescope"
local util  = require 'lua_util'

--- Predicate for `rspamd_eq`: compares `tbl.expect` and `tbl.actual` with `==`.
-- @param tbl table holding the `expect` and `actual` fields
-- @return boolean result of `expect == actual`
local function rspamd_assert_equals(tbl)
  local expected, actual = tbl.expect, tbl.actual
  return expected == actual
end

--- Failure message builder for `rspamd_eq`.
-- Formats the actual and expected values with `logger.slog`, which fills
-- the `%1` / `%2` placeholders.
-- @param _ unused first argument
-- @param tbl table holding the `expect` and `actual` fields
-- @return string produced by `logger.slog`
local function rspamd_assert_equals_msg(_, tbl)
  local template = "Failed asserting that \n  (actual)   : %1 \n equals to\n  (expected) : %2"
  local actual, expected = tbl.actual, tbl.expect
  return logger.slog(template, actual, expected)
end

--- Predicate for `rspamd_table_eq`: deep comparison via `lua_util.table_cmp`.
-- @param tbl table holding the `expect` and `actual` fields
-- @return whatever `util.table_cmp(expect, actual)` returns
local function rspamd_assert_table_equals(tbl)
  local expected, actual = tbl.expect, tbl.actual
  return util.table_cmp(expected, actual)
end

--- Predicate for `rspamd_table_eq_sorted`.
-- Deep-copies both tables, runs `util.deepsort` on each copy, and then
-- compares the copies with `util.table_cmp`. The caller's tables are never
-- modified because only the copies are sorted.
-- @param tbl table holding the `expect` and `actual` fields
-- @return whatever `util.table_cmp` returns for the sorted copies
local function rspamd_assert_table_equals_sorted(tbl)
  local function sorted_copy(t)
    local copy = util.deepcopy(t)
    util.deepsort(copy)
    return copy
  end

  return util.table_cmp(sorted_copy(tbl.expect), sorted_copy(tbl.actual))
end

--- Strict "less than" over table keys that may be of different Lua types.
-- Plain `<` raises an error when comparing e.g. a number with a string, which
-- happens for any table mixing an array part with named fields. Keys are
-- ordered first by type name, so numbers come before strings. Numbers and
-- strings are then compared with `<`. Any other key type (booleans, tables,
-- ...) is compared by its `tostring` form.
-- @param a first key
-- @param b second key
-- @return true when `a` sorts before `b`
local function compare_keys(a, b)
  local ta, tb = type(a), type(b)
  if ta ~= tb then
    return ta < tb
  end
  if ta == 'number' or ta == 'string' then
    return a < b
  end
  return tostring(a) < tostring(b)
end

--- Return an array of the keys of `t`, sorted with `compare_keys`.
-- For tables whose keys are all strings or all numbers the order is the
-- same as a plain `table.sort`.
-- @param t table whose keys are collected
-- @return new array of keys
local function table_keys_sorted(t)
  local keys = {}

  for k in pairs(t) do
    keys[#keys + 1] = k
  end
  table.sort(keys, compare_keys)
  return keys
end

--- Render a single `[key] = value` entry of the diff.
-- When both values are equal, one unmarked line is returned. It carries one
-- extra leading space so it lines up with marked lines. Otherwise a
-- `-` line shows the expected value and a `+` line shows the actual value.
-- Each marked line also includes the value's type. A side whose value is
-- `nil` (the key is absent there) is omitted.
-- @param level nesting depth; each level adds two spaces of indentation
-- @param key the table key being rendered
-- @param v_expect expected value, or nil when the key is missing from expect
-- @param v_actual actual value, or nil when the key is missing from actual
-- @return string, possibly two lines joined by "\n"
local function format_line(level, key, v_expect, v_actual)
  if v_expect == v_actual then
    local prefix = string.rep(' ', level * 2 + 1)
    return string.format("%s[%s] = %s", prefix, tostring(key), tostring(v_expect))
  end

  local prefix = string.rep(' ', level * 2)
  local ret = {}
  -- Compare against nil explicitly: `false` is a real value and must still
  -- show up in the diff (a bare truthiness test would hide it).
  if v_expect ~= nil then
    ret[#ret + 1] = string.format("-%s[%s] = %s: %s", prefix, tostring(key), type(v_expect), tostring(v_expect))
  end
  if v_actual ~= nil then
    ret[#ret + 1] = string.format("+%s[%s] = %s: %s", prefix, tostring(key), type(v_actual), tostring(v_actual))
  end
  return table.concat(ret, "\n")
end

--- Render the opening line `[key] = {` of a nested table in the diff.
-- Indentation is two spaces per level plus one, matching unmarked lines.
-- @param level nesting depth
-- @param key key under which the nested table lives
-- @return string
local function format_table_begin(level, key)
  local indent = (' '):rep(2 * level + 1)
  return indent .. '[' .. tostring(key) .. '] = {'
end

--- Render the closing `}` of a nested table in the diff.
-- Uses the same indentation as `format_table_begin` for the same level.
-- @param level nesting depth
-- @return string
local function format_table_end(level)
  return (' '):rep(2 * level + 1) .. '}'
end

--- Failure message builder for the table-equality assertions.
-- Starts from the generic "actual vs expected" message and appends a
-- line-by-line diff of `tbl.expect` and `tbl.actual`. Diff lines are marked
-- `-` for expected and `+` for actual. Nested tables on both sides that
-- differ (per `util.table_cmp`) are expanded recursively.
-- @param _ unused first argument, forwarded to rspamd_assert_equals_msg
-- @param tbl table holding the `expect` and `actual` fields
-- @return string message
local function rspamd_assert_table_diff_msg(_, tbl)
  local msg = rspamd_assert_equals_msg(_, tbl)

  -- A structural diff only makes sense between two tables. For anything
  -- else, return the plain message instead of letting pairs() raise an error
  -- while the original assertion failure is being reported.
  if type(tbl.expect) ~= 'table' or type(tbl.actual) ~= 'table' then
    return msg
  end

  local avoid_loops = {}
  local diff = {}

  local function recurse(expect, actual, level)
    -- stop on 'actual' tables that were already visited (cycles)
    if avoid_loops[actual] then
      return
    end
    avoid_loops[actual] = true

    -- Walk the sorted union of both key sets so each key is reported exactly
    -- once: unchanged, changed, removed (only in expect) or added (only in
    -- actual). Presence is decided by key, never by comparing values.
    local all_keys = {}
    for k in pairs(expect) do
      all_keys[k] = true
    end
    for k in pairs(actual) do
      all_keys[k] = true
    end

    for _, key in ipairs(table_keys_sorted(all_keys)) do
      local v_expect, v_actual = expect[key], actual[key]

      if type(v_expect) == 'table' and type(v_actual) == 'table' then
        if util.table_cmp(v_expect, v_actual) then
          -- equal subtables: pass the same value twice so it prints as unchanged
          diff[#diff + 1] = format_line(level, key, v_expect, v_expect)
        else
          diff[#diff + 1] = format_table_begin(level, key)
          recurse(v_expect, v_actual, level + 1)
          diff[#diff + 1] = format_table_end(level)
        end
      else
        -- a nil on either side means the key is missing on that side
        diff[#diff + 1] = format_line(level, key, v_expect, v_actual)
      end
    end
  end
  recurse(tbl.expect, tbl.actual, 0)

  return string.format("%s\n===== diff (-expect, +actual) ======\n%s", msg, table.concat(diff, "\n"))
end

-- Register the custom assertions with telescope, in this order. Each entry
-- is { assertion name, failure-message builder, predicate }. Both table
-- assertions report failures through rspamd_assert_table_diff_msg rather
-- than the plain rspamd_assert_equals_msg.
for _, spec in ipairs({
  { "rspamd_eq", rspamd_assert_equals_msg, rspamd_assert_equals },
  { "rspamd_table_eq", rspamd_assert_table_diff_msg, rspamd_assert_table_equals },
  { "rspamd_table_eq_sorted", rspamd_assert_table_diff_msg, rspamd_assert_table_equals_sorted },
}) do
  telescope.make_assertion(spec[1], spec[2], spec[3])
end