You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

fixtures.go 3.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. // Copyright 2021 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package unittest
  5. import (
  6. "fmt"
  7. "os"
  8. "time"
  9. "code.gitea.io/gitea/models/db"
  10. "github.com/go-testfixtures/testfixtures/v3"
  11. "xorm.io/xorm"
  12. "xorm.io/xorm/schemas"
  13. )
// fixtures is the package-level loader. InitFixtures assigns it and
// LoadFixtures calls its Load method.
var fixtures *testfixtures.Loader
  15. // GetXORMEngine gets the XORM engine
  16. func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
  17. if len(engine) == 1 {
  18. return engine[0]
  19. }
  20. return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
  21. }
  22. // InitFixtures initialize test fixtures for a test database
  23. func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
  24. e := GetXORMEngine(engine...)
  25. var testfiles func(*testfixtures.Loader) error
  26. if opts.Dir != "" {
  27. testfiles = testfixtures.Directory(opts.Dir)
  28. } else {
  29. testfiles = testfixtures.Files(opts.Files...)
  30. }
  31. dialect := "unknown"
  32. switch e.Dialect().URI().DBType {
  33. case schemas.POSTGRES:
  34. dialect = "postgres"
  35. case schemas.MYSQL:
  36. dialect = "mysql"
  37. case schemas.MSSQL:
  38. dialect = "mssql"
  39. case schemas.SQLITE:
  40. dialect = "sqlite3"
  41. default:
  42. fmt.Println("Unsupported RDBMS for integration tests")
  43. os.Exit(1)
  44. }
  45. loaderOptions := []func(loader *testfixtures.Loader) error{
  46. testfixtures.Database(e.DB().DB),
  47. testfixtures.Dialect(dialect),
  48. testfixtures.DangerousSkipTestDatabaseCheck(),
  49. testfiles,
  50. }
  51. if e.Dialect().URI().DBType == schemas.POSTGRES {
  52. loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
  53. }
  54. fixtures, err = testfixtures.New(loaderOptions...)
  55. if err != nil {
  56. return err
  57. }
  58. return err
  59. }
  60. // LoadFixtures load fixtures for a test database
  61. func LoadFixtures(engine ...*xorm.Engine) error {
  62. e := GetXORMEngine(engine...)
  63. var err error
  64. // Database transaction conflicts could occur and result in ROLLBACK
  65. // As a simple workaround, we just retry 20 times.
  66. for i := 0; i < 20; i++ {
  67. err = fixtures.Load()
  68. if err == nil {
  69. break
  70. }
  71. time.Sleep(200 * time.Millisecond)
  72. }
  73. if err != nil {
  74. fmt.Printf("LoadFixtures failed after retries: %v\n", err)
  75. }
  76. // Now if we're running postgres we need to tell it to update the sequences
  77. if e.Dialect().URI().DBType == schemas.POSTGRES {
  78. results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
  79. quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
  80. ', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
  81. quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
  82. FROM pg_class AS S,
  83. pg_depend AS D,
  84. pg_class AS T,
  85. pg_attribute AS C,
  86. pg_tables AS PGT
  87. WHERE S.relkind = 'S'
  88. AND S.oid = D.objid
  89. AND D.refobjid = T.oid
  90. AND D.refobjid = C.attrelid
  91. AND D.refobjsubid = C.attnum
  92. AND T.relname = PGT.tablename
  93. ORDER BY S.relname;`)
  94. if err != nil {
  95. fmt.Printf("Failed to generate sequence update: %v\n", err)
  96. return err
  97. }
  98. for _, r := range results {
  99. for _, value := range r {
  100. _, err = e.Exec(value)
  101. if err != nil {
  102. fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err)
  103. return err
  104. }
  105. }
  106. }
  107. }
  108. return err
  109. }