aboutsummaryrefslogtreecommitdiffstats
path: root/models/unittest
diff options
context:
space:
mode:
Diffstat (limited to 'models/unittest')
-rw-r--r--models/unittest/consistency.go17
-rw-r--r--models/unittest/fixtures_loader.go84
-rw-r--r--models/unittest/fscopy.go2
-rw-r--r--models/unittest/testdb.go22
-rw-r--r--models/unittest/unit_tests.go4
5 files changed, 75 insertions, 54 deletions
diff --git a/models/unittest/consistency.go b/models/unittest/consistency.go
index 71839001be..364afb5c52 100644
--- a/models/unittest/consistency.go
+++ b/models/unittest/consistency.go
@@ -11,6 +11,7 @@ import (
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"xorm.io/builder"
)
@@ -24,7 +25,7 @@ const (
var consistencyCheckMap = make(map[string]func(t assert.TestingT, bean any))
// CheckConsistencyFor test that all matching database entries are consistent
-func CheckConsistencyFor(t assert.TestingT, beansToCheck ...any) {
+func CheckConsistencyFor(t require.TestingT, beansToCheck ...any) {
for _, bean := range beansToCheck {
sliceType := reflect.SliceOf(reflect.TypeOf(bean))
sliceValue := reflect.MakeSlice(sliceType, 0, 10)
@@ -42,13 +43,11 @@ func CheckConsistencyFor(t assert.TestingT, beansToCheck ...any) {
}
}
-func checkForConsistency(t assert.TestingT, bean any) {
+func checkForConsistency(t require.TestingT, bean any) {
tb, err := db.TableInfo(bean)
assert.NoError(t, err)
f := consistencyCheckMap[tb.Name]
- if f == nil {
- assert.FailNow(t, "unknown bean type: %#v", bean)
- }
+ require.NotNil(t, f, "unknown bean type: %#v", bean)
f(t, bean)
}
@@ -71,8 +70,8 @@ func init() {
AssertCountByCond(t, "follow", builder.Eq{"user_id": user.int("ID")}, user.int("NumFollowing"))
AssertCountByCond(t, "follow", builder.Eq{"follow_id": user.int("ID")}, user.int("NumFollowers"))
if user.int("Type") != modelsUserTypeOrganization {
- assert.EqualValues(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID"))
- assert.EqualValues(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID"))
+ assert.Equal(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID"))
+ assert.Equal(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID"))
}
}
@@ -119,7 +118,7 @@ func init() {
assert.EqualValues(t, issue.int("NumComments"), actual, "Unexpected number of comments for issue id: %d", issue.int("ID"))
if issue.bool("IsPull") {
prRow := AssertExistsAndLoadMap(t, "pull_request", builder.Eq{"issue_id": issue.int("ID")})
- assert.EqualValues(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID"))
+ assert.Equal(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID"))
}
}
@@ -127,7 +126,7 @@ func init() {
pr := reflectionWrap(bean)
issueRow := AssertExistsAndLoadMap(t, "issue", builder.Eq{"id": pr.int("IssueID")})
assert.True(t, parseBool(issueRow["is_pull"]))
- assert.EqualValues(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID"))
+ assert.Equal(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID"))
}
checkForMilestoneConsistency := func(t assert.TestingT, bean any) {
diff --git a/models/unittest/fixtures_loader.go b/models/unittest/fixtures_loader.go
index 14686caf63..0560da8349 100644
--- a/models/unittest/fixtures_loader.go
+++ b/models/unittest/fixtures_loader.go
@@ -12,13 +12,17 @@ import (
"slices"
"strings"
+ "code.gitea.io/gitea/models/db"
+
"gopkg.in/yaml.v3"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
-type fixtureItem struct {
- tableName string
+type FixtureItem struct {
+ fileFullPath string
+ tableName string
+
tableNameQuoted string
sqlInserts []string
sqlInsertArgs [][]any
@@ -27,10 +31,11 @@ type fixtureItem struct {
}
type fixturesLoaderInternal struct {
+ xormEngine *xorm.Engine
+ xormTableNames map[string]bool
db *sql.DB
dbType schemas.DBType
- files []string
- fixtures map[string]*fixtureItem
+ fixtures map[string]*FixtureItem
quoteObject func(string) string
paramPlaceholder func(idx int) string
}
@@ -59,29 +64,27 @@ func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err
return nil
}
-func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem, err error) {
- fixture := &fixtureItem{}
- fixture.tableName, _, _ = strings.Cut(filepath.Base(file), ".")
+func (f *fixturesLoaderInternal) prepareFixtureItem(fixture *FixtureItem) (err error) {
fixture.tableNameQuoted = f.quoteObject(fixture.tableName)
if f.dbType == schemas.MSSQL {
fixture.mssqlHasIdentityColumn, err = f.mssqlTableHasIdentityColumn(f.db, fixture.tableName)
if err != nil {
- return nil, err
+ return err
}
}
- data, err := os.ReadFile(file)
+ data, err := os.ReadFile(fixture.fileFullPath)
if err != nil {
- return nil, fmt.Errorf("failed to read file %q: %w", file, err)
+ return fmt.Errorf("failed to read file %q: %w", fixture.fileFullPath, err)
}
var rows []map[string]any
if err = yaml.Unmarshal(data, &rows); err != nil {
- return nil, fmt.Errorf("failed to unmarshal yaml data from %q: %w", file, err)
+ return fmt.Errorf("failed to unmarshal yaml data from %q: %w", fixture.fileFullPath, err)
}
if err = f.preprocessFixtureRow(rows); err != nil {
- return nil, fmt.Errorf("failed to preprocess fixture rows from %q: %w", file, err)
+ return fmt.Errorf("failed to preprocess fixture rows from %q: %w", fixture.fileFullPath, err)
}
var sqlBuf []byte
@@ -107,19 +110,17 @@ func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem
sqlBuf = sqlBuf[:0]
sqlArguments = sqlArguments[:0]
}
- return fixture, nil
+ return nil
}
-func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, file string) (err error) {
- fixture := f.fixtures[file]
- if fixture == nil {
- if fixture, err = f.prepareFixtureItem(file); err != nil {
+func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, fixture *FixtureItem) (err error) {
+ if fixture.tableNameQuoted == "" {
+ if err = f.prepareFixtureItem(fixture); err != nil {
return err
}
- f.fixtures[file] = fixture
}
- _, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", fixture.tableNameQuoted)) // sqlite3 doesn't support truncate
+ _, err = tx.Exec("DELETE FROM " + fixture.tableNameQuoted) // sqlite3 doesn't support truncate
if err != nil {
return err
}
@@ -147,15 +148,26 @@ func (f *fixturesLoaderInternal) Load() error {
}
defer func() { _ = tx.Rollback() }()
- for _, file := range f.files {
- if err := f.loadFixtures(tx, file); err != nil {
- return fmt.Errorf("failed to load fixtures from %s: %w", file, err)
+ for _, fixture := range f.fixtures {
+ if !f.xormTableNames[fixture.tableName] {
+ continue
+ }
+ if err := f.loadFixtures(tx, fixture); err != nil {
+ return fmt.Errorf("failed to load fixtures from %s: %w", fixture.fileFullPath, err)
}
}
- return tx.Commit()
+ if err = tx.Commit(); err != nil {
+ return err
+ }
+ for xormTableName := range f.xormTableNames {
+ if f.fixtures[xormTableName] == nil {
+ _, _ = f.xormEngine.Exec("DELETE FROM `" + xormTableName + "`")
+ }
+ }
+ return nil
}
-func FixturesFileFullPaths(dir string, files []string) ([]string, error) {
+func FixturesFileFullPaths(dir string, files []string) (map[string]*FixtureItem, error) {
if files != nil && len(files) == 0 {
return nil, nil // load nothing
}
@@ -169,20 +181,25 @@ func FixturesFileFullPaths(dir string, files []string) ([]string, error) {
files = append(files, e.Name())
}
}
- for i, file := range files {
- if !filepath.IsAbs(file) {
- files[i] = filepath.Join(dir, file)
+ fixtureItems := map[string]*FixtureItem{}
+ for _, file := range files {
+ fileFillPath := file
+ if !filepath.IsAbs(fileFillPath) {
+ fileFillPath = filepath.Join(dir, file)
}
+ tableName, _, _ := strings.Cut(filepath.Base(file), ".")
+ fixtureItems[tableName] = &FixtureItem{fileFullPath: fileFillPath, tableName: tableName}
}
- return files, nil
+ return fixtureItems, nil
}
func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, error) {
- files, err := FixturesFileFullPaths(opts.Dir, opts.Files)
+ fixtureItems, err := FixturesFileFullPaths(opts.Dir, opts.Files)
if err != nil {
return nil, fmt.Errorf("failed to get fixtures files: %w", err)
}
- f := &fixturesLoaderInternal{db: x.DB().DB, dbType: x.Dialect().URI().DBType, files: files, fixtures: map[string]*fixtureItem{}}
+
+ f := &fixturesLoaderInternal{xormEngine: x, db: x.DB().DB, dbType: x.Dialect().URI().DBType, fixtures: fixtureItems}
switch f.dbType {
case schemas.SQLITE:
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
@@ -197,5 +214,12 @@ func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, er
f.quoteObject = func(s string) string { return fmt.Sprintf("[%s]", s) }
f.paramPlaceholder = func(idx int) string { return "?" }
}
+
+ xormBeans, _ := db.NamesToBean()
+ f.xormTableNames = map[string]bool{}
+ for _, bean := range xormBeans {
+ f.xormTableNames[db.TableName(bean)] = true
+ }
+
return f, nil
}
diff --git a/models/unittest/fscopy.go b/models/unittest/fscopy.go
index b7ba6b7ef5..98b01815bd 100644
--- a/models/unittest/fscopy.go
+++ b/models/unittest/fscopy.go
@@ -28,7 +28,7 @@ func SyncFile(srcPath, destPath string) error {
}
if src.Size() == dest.Size() &&
- src.ModTime() == dest.ModTime() &&
+ src.ModTime().Equal(dest.ModTime()) &&
src.Mode() == dest.Mode() {
return nil
}
diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go
index 7a9ca9698d..cb60cf5f85 100644
--- a/models/unittest/testdb.go
+++ b/models/unittest/testdb.go
@@ -20,6 +20,7 @@ import (
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/setting/config"
"code.gitea.io/gitea/modules/storage"
+ "code.gitea.io/gitea/modules/tempdir"
"code.gitea.io/gitea/modules/test"
"code.gitea.io/gitea/modules/util"
@@ -35,8 +36,8 @@ func fatalTestError(fmtStr string, args ...any) {
os.Exit(1)
}
-// InitSettings initializes config provider and load common settings for tests
-func InitSettings() {
+// InitSettingsForTesting initializes config provider and load common settings for tests
+func InitSettingsForTesting() {
setting.IsInTesting = true
log.OsExiter = func(code int) {
if code != 0 {
@@ -75,7 +76,7 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) {
testOpts := util.OptionalArg(testOptsArg, &TestOptions{})
giteaRoot = test.SetupGiteaRoot()
setting.CustomPath = filepath.Join(giteaRoot, "custom")
- InitSettings()
+ InitSettingsForTesting()
fixturesOpts := FixturesOptions{Dir: filepath.Join(giteaRoot, "models", "fixtures"), Files: testOpts.FixtureFiles}
if err := CreateTestEngine(fixturesOpts); err != nil {
@@ -92,15 +93,19 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) {
setting.SSH.Domain = "try.gitea.io"
setting.Database.Type = "sqlite3"
setting.Repository.DefaultBranch = "master" // many test code still assume that default branch is called "master"
- repoRootPath, err := os.MkdirTemp(os.TempDir(), "repos")
+ repoRootPath, cleanup1, err := tempdir.OsTempDir("gitea-test").MkdirTempRandom("repos")
if err != nil {
fatalTestError("TempDir: %v\n", err)
}
+ defer cleanup1()
+
setting.RepoRootPath = repoRootPath
- appDataPath, err := os.MkdirTemp(os.TempDir(), "appdata")
+ appDataPath, cleanup2, err := tempdir.OsTempDir("gitea-test").MkdirTempRandom("appdata")
if err != nil {
fatalTestError("TempDir: %v\n", err)
}
+ defer cleanup2()
+
setting.AppDataPath = appDataPath
setting.AppWorkPath = giteaRoot
setting.StaticRootPath = giteaRoot
@@ -153,13 +158,6 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) {
fatalTestError("tear down failed: %v\n", err)
}
}
-
- if err = util.RemoveAll(repoRootPath); err != nil {
- fatalTestError("util.RemoveAll: %v\n", err)
- }
- if err = util.RemoveAll(appDataPath); err != nil {
- fatalTestError("util.RemoveAll: %v\n", err)
- }
os.Exit(exitStatus)
}
diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go
index 1c5595aef8..4a4cec40ae 100644
--- a/models/unittest/unit_tests.go
+++ b/models/unittest/unit_tests.go
@@ -153,9 +153,9 @@ func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) {
goDB := x.DB().DB
sql, ok := sqlOrBean.(string)
if !ok {
- sql = fmt.Sprintf("SELECT * FROM %s", db.TableName(sqlOrBean))
+ sql = "SELECT * FROM " + db.TableName(sqlOrBean)
} else if !strings.Contains(sql, " ") {
- sql = fmt.Sprintf("SELECT * FROM %s", sql)
+ sql = "SELECT * FROM " + sql
}
rows, err := goDB.Query(sql, sqlArgs...)
require.NoError(t, err)