You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

fixtures.go 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. //nolint:forbidigo
  4. package unittest
  5. import (
  6. "fmt"
  7. "os"
  8. "time"
  9. "code.gitea.io/gitea/models/db"
  10. "code.gitea.io/gitea/modules/auth/password/hash"
  11. "code.gitea.io/gitea/modules/setting"
  12. "github.com/go-testfixtures/testfixtures/v3"
  13. "xorm.io/xorm"
  14. "xorm.io/xorm/schemas"
  15. )
  16. var fixturesLoader *testfixtures.Loader
  17. // GetXORMEngine gets the XORM engine
  18. func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
  19. if len(engine) == 1 {
  20. return engine[0]
  21. }
  22. return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
  23. }
  24. // InitFixtures initialize test fixtures for a test database
  25. func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
  26. e := GetXORMEngine(engine...)
  27. var fixtureOptionFiles func(*testfixtures.Loader) error
  28. if opts.Dir != "" {
  29. fixtureOptionFiles = testfixtures.Directory(opts.Dir)
  30. } else {
  31. fixtureOptionFiles = testfixtures.Files(opts.Files...)
  32. }
  33. dialect := "unknown"
  34. switch e.Dialect().URI().DBType {
  35. case schemas.POSTGRES:
  36. dialect = "postgres"
  37. case schemas.MYSQL:
  38. dialect = "mysql"
  39. case schemas.MSSQL:
  40. dialect = "mssql"
  41. case schemas.SQLITE:
  42. dialect = "sqlite3"
  43. default:
  44. fmt.Println("Unsupported RDBMS for integration tests")
  45. os.Exit(1)
  46. }
  47. loaderOptions := []func(loader *testfixtures.Loader) error{
  48. testfixtures.Database(e.DB().DB),
  49. testfixtures.Dialect(dialect),
  50. testfixtures.DangerousSkipTestDatabaseCheck(),
  51. fixtureOptionFiles,
  52. }
  53. if e.Dialect().URI().DBType == schemas.POSTGRES {
  54. loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
  55. }
  56. fixturesLoader, err = testfixtures.New(loaderOptions...)
  57. if err != nil {
  58. return err
  59. }
  60. // register the dummy hash algorithm function used in the test fixtures
  61. _ = hash.Register("dummy", hash.NewDummyHasher)
  62. setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
  63. return err
  64. }
  65. // LoadFixtures load fixtures for a test database
  66. func LoadFixtures(engine ...*xorm.Engine) error {
  67. e := GetXORMEngine(engine...)
  68. var err error
  69. // (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times.
  70. for i := 0; i < 5; i++ {
  71. if err = fixturesLoader.Load(); err == nil {
  72. break
  73. }
  74. time.Sleep(200 * time.Millisecond)
  75. }
  76. if err != nil {
  77. fmt.Printf("LoadFixtures failed after retries: %v\n", err)
  78. }
  79. // Now if we're running postgres we need to tell it to update the sequences
  80. if e.Dialect().URI().DBType == schemas.POSTGRES {
  81. results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
  82. quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
  83. ', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
  84. quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
  85. FROM pg_class AS S,
  86. pg_depend AS D,
  87. pg_class AS T,
  88. pg_attribute AS C,
  89. pg_tables AS PGT
  90. WHERE S.relkind = 'S'
  91. AND S.oid = D.objid
  92. AND D.refobjid = T.oid
  93. AND D.refobjid = C.attrelid
  94. AND D.refobjsubid = C.attnum
  95. AND T.relname = PGT.tablename
  96. ORDER BY S.relname;`)
  97. if err != nil {
  98. fmt.Printf("Failed to generate sequence update: %v\n", err)
  99. return err
  100. }
  101. for _, r := range results {
  102. for _, value := range r {
  103. _, err = e.Exec(value)
  104. if err != nil {
  105. fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err)
  106. return err
  107. }
  108. }
  109. }
  110. }
  111. _ = hash.Register("dummy", hash.NewDummyHasher)
  112. setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
  113. return err
  114. }