aboutsummaryrefslogtreecommitdiffstats
path: root/models/unittest/unit_tests.go
blob: d47bceea1ea1358b9ed0083322529a9080b95e18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright 2016 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package unittest

import (
	"math"

	"code.gitea.io/gitea/models/db"

	"github.com/stretchr/testify/assert"
	"xorm.io/builder"
)

// Code in this file is mainly used by unittest.CheckConsistencyFor, which for various reasons is not part of the unit tests.
// If CheckConsistencyFor can be decoupled into separate unit test code in the future, this file can be moved into the unittest package too.

// NonexistentID is an ID that will never exist: it is the largest value an
// int64 can hold (presumably never assigned to a fixture row).
const NonexistentID int64 = math.MaxInt64

type (
	// testCond is a WHERE condition together with its arguments; values are
	// created by Cond and consumed by the query helpers in this file.
	testCond struct {
		query any   // condition expression handed to Where
		args  []any // extra arguments handed to Where alongside query
	}

	// testOrderBy is an ORDER BY clause created by OrderBy.
	testOrderBy string
)

// Cond builds a test query condition from a WHERE expression and the
// arguments that accompany it. The result is meant to be passed as one of
// the variadic conditions of the lookup helpers in this package.
func Cond(query any, args ...any) any {
	c := new(testCond)
	c.query, c.args = query, args
	return c
}

// OrderBy builds an ORDER BY clause for a test query. Passing it among the
// conditions of the lookup helpers overrides their default "id" ordering.
func OrderBy(orderBy string) any {
	clause := testOrderBy(orderBy)
	return clause
}

// whereOrderConditions applies conditions to e and finishes the query with an
// ORDER BY clause.
//
// Cond(...) values become Where calls with their arguments, an OrderBy(...)
// value replaces the ordering (the last one wins), and any other value is
// handed to e.Where as-is. Ordering defaults to "id" because a query without
// ORDER BY does not return rows in a deterministic order.
func whereOrderConditions(e db.Engine, conditions []any) db.Engine {
	order := "id"
	for _, c := range conditions {
		if tc, ok := c.(*testCond); ok {
			e = e.Where(tc.query, tc.args...)
			continue
		}
		if ob, ok := c.(testOrderBy); ok {
			order = string(ob)
			continue
		}
		e = e.Where(c)
	}
	return e.OrderBy(order)
}

// LoadBeanIfExists loads beans from fixture database if exist
func LoadBeanIfExists(bean any, conditions ...any) (bool, error) {
	e := db.GetEngine(db.DefaultContext)
	return whereOrderConditions(e, conditions).Get(bean)
}

// BeanExists reports whether a row matching bean and conditions exists in the
// test database. A lookup error is reported through t.
func BeanExists(t assert.TestingT, bean any, conditions ...any) bool {
	found, loadErr := LoadBeanIfExists(bean, conditions...)
	assert.NoError(t, loadErr)
	return found
}

// AssertExistsAndLoadBean asserts that a row matching bean and conditions
// exists in the test database, loads it into bean and returns bean.
// A lookup error or a missing row is reported through t; bean is returned
// either way.
func AssertExistsAndLoadBean[T any](t assert.TestingT, bean T, conditions ...any) T {
	found, loadErr := LoadBeanIfExists(bean, conditions...)
	assert.NoError(t, loadErr)

	const msg = "Expected to find %+v (of type %T, with conditions %+v), but did not"
	assert.True(t, found, msg, bean, bean, conditions)
	return bean
}

// AssertExistsAndLoadMap asserts that exactly one row of table matches
// conditions and returns that row as a column-name → string-value map.
// It returns nil, after reporting through t, when the row count is not one.
func AssertExistsAndLoadMap(t assert.TestingT, table string, conditions ...any) map[string]string {
	query := whereOrderConditions(db.GetEngine(db.DefaultContext).Table(table), conditions)
	rows, err := query.Query()
	assert.NoError(t, err)

	found := len(rows)
	assert.True(t, found == 1,
		"Expected to find one row in %s (with conditions %+v), but found %d",
		table, conditions, found,
	)
	if found != 1 {
		return nil
	}

	row := make(map[string]string, len(rows[0]))
	for column, raw := range rows[0] {
		row[column] = string(raw)
	}
	return row
}

// GetCount returns the number of rows in bean's table that match all of
// conditions, using the default database context.
//
// Each condition is either a Cond(...) value (applied with its arguments) or
// any other value, which is handed to Where unchanged. OrderBy(...) values are
// accepted for symmetry with the loading helpers but ignored: ordering cannot
// change a count, and without this case the ORDER BY text would be passed to
// Where as if it were a filter.
// A query error is reported through t via assert.NoError.
func GetCount(t assert.TestingT, bean any, conditions ...any) int {
	e := db.GetEngine(db.DefaultContext)
	for _, condition := range conditions {
		switch cond := condition.(type) {
		case *testCond:
			e = e.Where(cond.query, cond.args...)
		case testOrderBy:
			// Ordering is irrelevant for COUNT; deliberately skipped.
		default:
			e = e.Where(cond)
		}
	}
	count, err := e.Count(bean)
	assert.NoError(t, err)
	return int(count)
}

// AssertNotExistsBean asserts that no row matching bean and conditions exists
// in the test database.
//
// If a matching row is found, it has been loaded into bean by the lookup, so
// the failure message shows the offending row as well as the conditions.
func AssertNotExistsBean(t assert.TestingT, bean any, conditions ...any) {
	exists, err := LoadBeanIfExists(bean, conditions...)
	assert.NoError(t, err)
	assert.False(t, exists,
		"Expected not to find a bean of type %T (with conditions %+v), but found %+v",
		bean, conditions, bean)
}

// AssertExistsIf asserts that the existence of a row matching bean and
// conditions equals expected: present when expected is true, absent otherwise.
func AssertExistsIf(t assert.TestingT, expected bool, bean any, conditions ...any) {
	actual, lookupErr := LoadBeanIfExists(bean, conditions...)
	assert.NoError(t, lookupErr)
	assert.Equal(t, expected, actual)
}

// AssertSuccessfulInsert inserts beans through db.Insert in the default
// database context and asserts that the insert returned no error.
func AssertSuccessfulInsert(t assert.TestingT, beans ...any) {
	assert.NoError(t, db.Insert(db.DefaultContext, beans...))
}

// AssertCount asserts that bean's table holds expected rows in total (no
// extra conditions are applied). expected is compared with EqualValues, so
// any integer type is accepted.
func AssertCount(t assert.TestingT, bean, expected any) {
	actual := GetCount(t, bean)
	assert.EqualValues(t, expected, actual)
}

// AssertInt64InRange asserts that value lies in the closed interval
// [low, high]; both bounds are inclusive.
func AssertInt64InRange(t assert.TestingT, low, high, value int64) {
	inRange := low <= value && value <= high
	assert.True(t, inRange, "Expected value in range [%d, %d], found %d", low, high, value)
}

// GetCountByCond returns the number of rows in tableName that satisfy cond,
// using the default database context. A query error is reported through t.
func GetCountByCond(t assert.TestingT, tableName string, cond builder.Cond) int64 {
	count, err := db.GetEngine(db.DefaultContext).Table(tableName).Where(cond).Count()
	assert.NoError(t, err)
	return count
}

// AssertCountByCond asserts that exactly expected rows of tableName satisfy
// cond. On failure, EqualValues reports the expected and actual counts; the
// message adds which table and condition were counted.
func AssertCountByCond(t assert.TestingT, tableName string, cond builder.Cond, expected int) {
	assert.EqualValues(t, expected, GetCountByCond(t, tableName, cond),
		"Failed consistency test, unexpected count of rows in table %s matching condition %+v", tableName, cond)
}