diff options
Diffstat (limited to 'models/unittest')
-rw-r--r-- | models/unittest/consistency.go | 17 | ||||
-rw-r--r-- | models/unittest/fixtures_loader.go | 84 | ||||
-rw-r--r-- | models/unittest/fscopy.go | 2 | ||||
-rw-r--r-- | models/unittest/testdb.go | 22 | ||||
-rw-r--r-- | models/unittest/unit_tests.go | 4 |
5 files changed, 75 insertions, 54 deletions
diff --git a/models/unittest/consistency.go b/models/unittest/consistency.go index 71839001be..364afb5c52 100644 --- a/models/unittest/consistency.go +++ b/models/unittest/consistency.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/models/db" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "xorm.io/builder" ) @@ -24,7 +25,7 @@ const ( var consistencyCheckMap = make(map[string]func(t assert.TestingT, bean any)) // CheckConsistencyFor test that all matching database entries are consistent -func CheckConsistencyFor(t assert.TestingT, beansToCheck ...any) { +func CheckConsistencyFor(t require.TestingT, beansToCheck ...any) { for _, bean := range beansToCheck { sliceType := reflect.SliceOf(reflect.TypeOf(bean)) sliceValue := reflect.MakeSlice(sliceType, 0, 10) @@ -42,13 +43,11 @@ func CheckConsistencyFor(t assert.TestingT, beansToCheck ...any) { } } -func checkForConsistency(t assert.TestingT, bean any) { +func checkForConsistency(t require.TestingT, bean any) { tb, err := db.TableInfo(bean) assert.NoError(t, err) f := consistencyCheckMap[tb.Name] - if f == nil { - assert.FailNow(t, "unknown bean type: %#v", bean) - } + require.NotNil(t, f, "unknown bean type: %#v", bean) f(t, bean) } @@ -71,8 +70,8 @@ func init() { AssertCountByCond(t, "follow", builder.Eq{"user_id": user.int("ID")}, user.int("NumFollowing")) AssertCountByCond(t, "follow", builder.Eq{"follow_id": user.int("ID")}, user.int("NumFollowers")) if user.int("Type") != modelsUserTypeOrganization { - assert.EqualValues(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID")) - assert.EqualValues(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID")) + assert.Equal(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID")) + assert.Equal(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID")) } } @@ -119,7 +118,7 @@ func init() { assert.EqualValues(t, issue.int("NumComments"), actual, "Unexpected number of comments for issue id: %d", issue.int("ID")) if issue.bool("IsPull") { prRow := AssertExistsAndLoadMap(t, "pull_request", builder.Eq{"issue_id": issue.int("ID")}) - assert.EqualValues(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID")) + assert.Equal(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID")) } } @@ -127,7 +126,7 @@ func init() { pr := reflectionWrap(bean) issueRow := AssertExistsAndLoadMap(t, "issue", builder.Eq{"id": pr.int("IssueID")}) assert.True(t, parseBool(issueRow["is_pull"])) - assert.EqualValues(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID")) + assert.Equal(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID")) } checkForMilestoneConsistency := func(t assert.TestingT, bean any) { diff --git a/models/unittest/fixtures_loader.go b/models/unittest/fixtures_loader.go index 14686caf63..0560da8349 100644 --- a/models/unittest/fixtures_loader.go +++ b/models/unittest/fixtures_loader.go @@ -12,13 +12,17 @@ import ( "slices" "strings" + "code.gitea.io/gitea/models/db" + "gopkg.in/yaml.v3" "xorm.io/xorm" "xorm.io/xorm/schemas" ) -type fixtureItem struct { - tableName string +type FixtureItem struct { + fileFullPath string + tableName string + tableNameQuoted string sqlInserts []string sqlInsertArgs [][]any @@ -27,10 +31,11 @@ type fixtureItem struct { } type fixturesLoaderInternal struct { + xormEngine *xorm.Engine + xormTableNames map[string]bool db *sql.DB dbType schemas.DBType - files []string - fixtures map[string]*fixtureItem + fixtures map[string]*FixtureItem quoteObject func(string) string paramPlaceholder func(idx int) string } @@ -59,29 +64,27 @@ func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err return nil } -func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem, err error) { - fixture := &fixtureItem{} - fixture.tableName, _, _ = strings.Cut(filepath.Base(file), ".") +func (f *fixturesLoaderInternal) prepareFixtureItem(fixture *FixtureItem) (err error) { fixture.tableNameQuoted = f.quoteObject(fixture.tableName) if f.dbType == schemas.MSSQL { fixture.mssqlHasIdentityColumn, err = f.mssqlTableHasIdentityColumn(f.db, fixture.tableName) if err != nil { - return nil, err + return err } } - data, err := os.ReadFile(file) + data, err := os.ReadFile(fixture.fileFullPath) if err != nil { - return nil, fmt.Errorf("failed to read file %q: %w", file, err) + return fmt.Errorf("failed to read file %q: %w", fixture.fileFullPath, err) } var rows []map[string]any if err = yaml.Unmarshal(data, &rows); err != nil { - return nil, fmt.Errorf("failed to unmarshal yaml data from %q: %w", file, err) + return fmt.Errorf("failed to unmarshal yaml data from %q: %w", fixture.fileFullPath, err) } if err = f.preprocessFixtureRow(rows); err != nil { - return nil, fmt.Errorf("failed to preprocess fixture rows from %q: %w", file, err) + return fmt.Errorf("failed to preprocess fixture rows from %q: %w", fixture.fileFullPath, err) } var sqlBuf []byte @@ -107,19 +110,17 @@ func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem sqlBuf = sqlBuf[:0] sqlArguments = sqlArguments[:0] } - return fixture, nil + return nil } -func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, file string) (err error) { - fixture := f.fixtures[file] - if fixture == nil { - if fixture, err = f.prepareFixtureItem(file); err != nil { +func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, fixture *FixtureItem) (err error) { + if fixture.tableNameQuoted == "" { + if err = f.prepareFixtureItem(fixture); err != nil { return err } - f.fixtures[file] = fixture } - _, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", fixture.tableNameQuoted)) // sqlite3 doesn't support truncate + _, err = tx.Exec("DELETE FROM " + fixture.tableNameQuoted) // sqlite3 doesn't support truncate if err != nil { return err } @@ -147,15 +148,26 @@ func (f *fixturesLoaderInternal) Load() error { } defer func() { _ = tx.Rollback() }() - for _, file := range f.files { - if err := f.loadFixtures(tx, file); err != nil { - return fmt.Errorf("failed to load fixtures from %s: %w", file, err) + for _, fixture := range f.fixtures { + if !f.xormTableNames[fixture.tableName] { + continue + } + if err := f.loadFixtures(tx, fixture); err != nil { + return fmt.Errorf("failed to load fixtures from %s: %w", fixture.fileFullPath, err) } } - return tx.Commit() + if err = tx.Commit(); err != nil { + return err + } + for xormTableName := range f.xormTableNames { + if f.fixtures[xormTableName] == nil { + _, _ = f.xormEngine.Exec("DELETE FROM `" + xormTableName + "`") + } + } + return nil } -func FixturesFileFullPaths(dir string, files []string) ([]string, error) { +func FixturesFileFullPaths(dir string, files []string) (map[string]*FixtureItem, error) { if files != nil && len(files) == 0 { return nil, nil // load nothing } @@ -169,20 +181,25 @@ func FixturesFileFullPaths(dir string, files []string) ([]string, error) { files = append(files, e.Name()) } } - for i, file := range files { - if !filepath.IsAbs(file) { - files[i] = filepath.Join(dir, file) + fixtureItems := map[string]*FixtureItem{} + for _, file := range files { + fileFillPath := file + if !filepath.IsAbs(fileFillPath) { + fileFillPath = filepath.Join(dir, file) } + tableName, _, _ := strings.Cut(filepath.Base(file), ".") + fixtureItems[tableName] = &FixtureItem{fileFullPath: fileFillPath, tableName: tableName} } - return files, nil + return fixtureItems, nil } func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, error) { - files, err := FixturesFileFullPaths(opts.Dir, opts.Files) + fixtureItems, err := FixturesFileFullPaths(opts.Dir, opts.Files) if err != nil { return nil, fmt.Errorf("failed to get fixtures files: %w", err) } - f := &fixturesLoaderInternal{db: x.DB().DB, dbType: x.Dialect().URI().DBType, files: files, fixtures: map[string]*fixtureItem{}} + + f := &fixturesLoaderInternal{xormEngine: x, db: x.DB().DB, dbType: x.Dialect().URI().DBType, fixtures: fixtureItems} switch f.dbType { case schemas.SQLITE: f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) } @@ -197,5 +214,12 @@ func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, er f.quoteObject = func(s string) string { return fmt.Sprintf("[%s]", s) } f.paramPlaceholder = func(idx int) string { return "?" } } + + xormBeans, _ := db.NamesToBean() + f.xormTableNames = map[string]bool{} + for _, bean := range xormBeans { + f.xormTableNames[db.TableName(bean)] = true + } + return f, nil } diff --git a/models/unittest/fscopy.go b/models/unittest/fscopy.go index b7ba6b7ef5..98b01815bd 100644 --- a/models/unittest/fscopy.go +++ b/models/unittest/fscopy.go @@ -28,7 +28,7 @@ func SyncFile(srcPath, destPath string) error { } if src.Size() == dest.Size() && - src.ModTime() == dest.ModTime() && + src.ModTime().Equal(dest.ModTime()) && src.Mode() == dest.Mode() { return nil } diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go index 7a9ca9698d..cb60cf5f85 100644 --- a/models/unittest/testdb.go +++ b/models/unittest/testdb.go @@ -20,6 +20,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/setting/config" "code.gitea.io/gitea/modules/storage" + "code.gitea.io/gitea/modules/tempdir" "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/modules/util" @@ -35,8 +36,8 @@ func fatalTestError(fmtStr string, args ...any) { os.Exit(1) } -// InitSettings initializes config provider and load common settings for tests -func InitSettings() { +// InitSettingsForTesting initializes config provider and load common settings for tests +func InitSettingsForTesting() { setting.IsInTesting = true log.OsExiter = func(code int) { if code != 0 { @@ -75,7 +76,7 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) { testOpts := util.OptionalArg(testOptsArg, &TestOptions{}) giteaRoot = test.SetupGiteaRoot() setting.CustomPath = filepath.Join(giteaRoot, "custom") - InitSettings() + InitSettingsForTesting() fixturesOpts := FixturesOptions{Dir: filepath.Join(giteaRoot, "models", "fixtures"), Files: testOpts.FixtureFiles} if err := CreateTestEngine(fixturesOpts); err != nil { @@ -92,15 +93,19 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) { setting.SSH.Domain = "try.gitea.io" setting.Database.Type = "sqlite3" setting.Repository.DefaultBranch = "master" // many test code still assume that default branch is called "master" - repoRootPath, err := os.MkdirTemp(os.TempDir(), "repos") + repoRootPath, cleanup1, err := tempdir.OsTempDir("gitea-test").MkdirTempRandom("repos") if err != nil { fatalTestError("TempDir: %v\n", err) } + defer cleanup1() + setting.RepoRootPath = repoRootPath - appDataPath, err := os.MkdirTemp(os.TempDir(), "appdata") + appDataPath, cleanup2, err := tempdir.OsTempDir("gitea-test").MkdirTempRandom("appdata") if err != nil { fatalTestError("TempDir: %v\n", err) } + defer cleanup2() + setting.AppDataPath = appDataPath setting.AppWorkPath = giteaRoot setting.StaticRootPath = giteaRoot @@ -153,13 +158,6 @@ func MainTest(m *testing.M, testOptsArg ...*TestOptions) { fatalTestError("tear down failed: %v\n", err) } } - - if err = util.RemoveAll(repoRootPath); err != nil { - fatalTestError("util.RemoveAll: %v\n", err) - } - if err = util.RemoveAll(appDataPath); err != nil { - fatalTestError("util.RemoveAll: %v\n", err) - } os.Exit(exitStatus) } diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go index 1c5595aef8..4a4cec40ae 100644 --- a/models/unittest/unit_tests.go +++ b/models/unittest/unit_tests.go @@ -153,9 +153,9 @@ func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) { goDB := x.DB().DB sql, ok := sqlOrBean.(string) if !ok { - sql = fmt.Sprintf("SELECT * FROM %s", db.TableName(sqlOrBean)) + sql = "SELECT * FROM " + db.TableName(sqlOrBean) } else if !strings.Contains(sql, " ") { - sql = fmt.Sprintf("SELECT * FROM %s", sql) + sql = "SELECT * FROM " + sql } rows, err := goDB.Query(sql, sqlArgs...) require.NoError(t, err) |