You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.go 8.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. // Copyright 2017 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. //nolint:forbidigo
  4. package tests
  5. import (
  6. "context"
  7. "database/sql"
  8. "fmt"
  9. "os"
  10. "path"
  11. "path/filepath"
  12. "testing"
  13. "code.gitea.io/gitea/models/db"
  14. packages_model "code.gitea.io/gitea/models/packages"
  15. "code.gitea.io/gitea/models/unittest"
  16. "code.gitea.io/gitea/modules/base"
  17. "code.gitea.io/gitea/modules/git"
  18. "code.gitea.io/gitea/modules/graceful"
  19. "code.gitea.io/gitea/modules/log"
  20. repo_module "code.gitea.io/gitea/modules/repository"
  21. "code.gitea.io/gitea/modules/setting"
  22. "code.gitea.io/gitea/modules/storage"
  23. "code.gitea.io/gitea/modules/testlogger"
  24. "code.gitea.io/gitea/modules/util"
  25. "code.gitea.io/gitea/routers"
  26. "github.com/stretchr/testify/assert"
  27. )
  28. func exitf(format string, args ...interface{}) {
  29. fmt.Printf(format+"\n", args...)
  30. os.Exit(1)
  31. }
  32. func InitTest(requireGitea bool) {
  33. log.RegisterEventWriter("test", testlogger.NewTestLoggerWriter)
  34. giteaRoot := base.SetupGiteaRoot()
  35. if giteaRoot == "" {
  36. exitf("Environment variable $GITEA_ROOT not set")
  37. }
  38. setting.IsInTesting = true
  39. setting.AppWorkPath = giteaRoot
  40. setting.CustomPath = filepath.Join(setting.AppWorkPath, "custom")
  41. if requireGitea {
  42. giteaBinary := "gitea"
  43. if setting.IsWindows {
  44. giteaBinary += ".exe"
  45. }
  46. setting.AppPath = path.Join(giteaRoot, giteaBinary)
  47. if _, err := os.Stat(setting.AppPath); err != nil {
  48. exitf("Could not find gitea binary at %s", setting.AppPath)
  49. }
  50. }
  51. giteaConf := os.Getenv("GITEA_CONF")
  52. if giteaConf == "" {
  53. // By default, use sqlite.ini for testing, then IDE like GoLand can start the test process with debugger.
  54. // It's easier for developers to debug bugs step by step with a debugger.
  55. // Notice: when doing "ssh push", Gitea executes sub processes, debugger won't work for the sub processes.
  56. giteaConf = "tests/sqlite.ini"
  57. _ = os.Setenv("GITEA_CONF", giteaConf)
  58. fmt.Printf("Environment variable $GITEA_CONF not set, use default: %s\n", giteaConf)
  59. if !setting.EnableSQLite3 {
  60. exitf(`sqlite3 requires: import _ "github.com/mattn/go-sqlite3" or -tags sqlite,sqlite_unlock_notify`)
  61. }
  62. }
  63. if !path.IsAbs(giteaConf) {
  64. setting.CustomConf = filepath.Join(giteaRoot, giteaConf)
  65. } else {
  66. setting.CustomConf = giteaConf
  67. }
  68. unittest.InitSettings()
  69. setting.Repository.DefaultBranch = "master" // many test code still assume that default branch is called "master"
  70. _ = util.RemoveAll(repo_module.LocalCopyPath())
  71. if err := git.InitFull(context.Background()); err != nil {
  72. log.Fatal("git.InitOnceWithSync: %v", err)
  73. }
  74. setting.LoadDBSetting()
  75. if err := storage.Init(); err != nil {
  76. exitf("Init storage failed: %v", err)
  77. }
  78. switch {
  79. case setting.Database.Type.IsMySQL():
  80. connType := "tcp"
  81. if len(setting.Database.Host) > 0 && setting.Database.Host[0] == '/' { // looks like a unix socket
  82. connType = "unix"
  83. }
  84. db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/",
  85. setting.Database.User, setting.Database.Passwd, connType, setting.Database.Host))
  86. defer db.Close()
  87. if err != nil {
  88. log.Fatal("sql.Open: %v", err)
  89. }
  90. if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", setting.Database.Name)); err != nil {
  91. log.Fatal("db.Exec: %v", err)
  92. }
  93. case setting.Database.Type.IsPostgreSQL():
  94. var db *sql.DB
  95. var err error
  96. if setting.Database.Host[0] == '/' {
  97. db, err = sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@/%s?sslmode=%s&host=%s",
  98. setting.Database.User, setting.Database.Passwd, setting.Database.Name, setting.Database.SSLMode, setting.Database.Host))
  99. } else {
  100. db, err = sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s",
  101. setting.Database.User, setting.Database.Passwd, setting.Database.Host, setting.Database.Name, setting.Database.SSLMode))
  102. }
  103. defer db.Close()
  104. if err != nil {
  105. log.Fatal("sql.Open: %v", err)
  106. }
  107. dbrows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", setting.Database.Name))
  108. if err != nil {
  109. log.Fatal("db.Query: %v", err)
  110. }
  111. defer dbrows.Close()
  112. if !dbrows.Next() {
  113. if _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", setting.Database.Name)); err != nil {
  114. log.Fatal("db.Exec: CREATE DATABASE: %v", err)
  115. }
  116. }
  117. // Check if we need to setup a specific schema
  118. if len(setting.Database.Schema) == 0 {
  119. break
  120. }
  121. db.Close()
  122. if setting.Database.Host[0] == '/' {
  123. db, err = sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@/%s?sslmode=%s&host=%s",
  124. setting.Database.User, setting.Database.Passwd, setting.Database.Name, setting.Database.SSLMode, setting.Database.Host))
  125. } else {
  126. db, err = sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s",
  127. setting.Database.User, setting.Database.Passwd, setting.Database.Host, setting.Database.Name, setting.Database.SSLMode))
  128. }
  129. // This is a different db object; requires a different Close()
  130. defer db.Close()
  131. if err != nil {
  132. log.Fatal("sql.Open: %v", err)
  133. }
  134. schrows, err := db.Query(fmt.Sprintf("SELECT 1 FROM information_schema.schemata WHERE schema_name = '%s'", setting.Database.Schema))
  135. if err != nil {
  136. log.Fatal("db.Query: %v", err)
  137. }
  138. defer schrows.Close()
  139. if !schrows.Next() {
  140. // Create and setup a DB schema
  141. if _, err = db.Exec(fmt.Sprintf("CREATE SCHEMA %s", setting.Database.Schema)); err != nil {
  142. log.Fatal("db.Exec: CREATE SCHEMA: %v", err)
  143. }
  144. }
  145. case setting.Database.Type.IsMSSQL():
  146. host, port := setting.ParseMSSQLHostPort(setting.Database.Host)
  147. db, err := sql.Open("mssql", fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;",
  148. host, port, "master", setting.Database.User, setting.Database.Passwd))
  149. if err != nil {
  150. log.Fatal("sql.Open: %v", err)
  151. }
  152. if _, err := db.Exec(fmt.Sprintf("If(db_id(N'%s') IS NULL) BEGIN CREATE DATABASE %s; END;", setting.Database.Name, setting.Database.Name)); err != nil {
  153. log.Fatal("db.Exec: %v", err)
  154. }
  155. defer db.Close()
  156. }
  157. routers.InitWebInstalled(graceful.GetManager().HammerContext())
  158. }
  159. func PrepareTestEnv(t testing.TB, skip ...int) func() {
  160. t.Helper()
  161. ourSkip := 2
  162. if len(skip) > 0 {
  163. ourSkip += skip[0]
  164. }
  165. deferFn := PrintCurrentTest(t, ourSkip)
  166. // load database fixtures
  167. assert.NoError(t, unittest.LoadFixtures())
  168. // load git repo fixtures
  169. assert.NoError(t, util.RemoveAll(setting.RepoRootPath))
  170. assert.NoError(t, unittest.CopyDir(path.Join(filepath.Dir(setting.AppPath), "tests/gitea-repositories-meta"), setting.RepoRootPath))
  171. ownerDirs, err := os.ReadDir(setting.RepoRootPath)
  172. if err != nil {
  173. assert.NoError(t, err, "unable to read the new repo root: %v\n", err)
  174. }
  175. for _, ownerDir := range ownerDirs {
  176. if !ownerDir.Type().IsDir() {
  177. continue
  178. }
  179. repoDirs, err := os.ReadDir(filepath.Join(setting.RepoRootPath, ownerDir.Name()))
  180. if err != nil {
  181. assert.NoError(t, err, "unable to read the new repo root: %v\n", err)
  182. }
  183. for _, repoDir := range repoDirs {
  184. _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "pack"), 0o755)
  185. _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "info"), 0o755)
  186. _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "heads"), 0o755)
  187. _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "tag"), 0o755)
  188. }
  189. }
  190. // load LFS object fixtures
  191. // (LFS storage can be on any of several backends, including remote servers, so we init it with the storage API)
  192. lfsFixtures, err := storage.NewStorage(setting.LocalStorageType, &setting.Storage{
  193. Path: filepath.Join(filepath.Dir(setting.AppPath), "tests/gitea-lfs-meta"),
  194. })
  195. assert.NoError(t, err)
  196. assert.NoError(t, storage.Clean(storage.LFS))
  197. assert.NoError(t, lfsFixtures.IterateObjects("", func(path string, _ storage.Object) error {
  198. _, err := storage.Copy(storage.LFS, path, lfsFixtures, path)
  199. return err
  200. }))
  201. // clear all package data
  202. assert.NoError(t, db.TruncateBeans(db.DefaultContext,
  203. &packages_model.Package{},
  204. &packages_model.PackageVersion{},
  205. &packages_model.PackageFile{},
  206. &packages_model.PackageBlob{},
  207. &packages_model.PackageProperty{},
  208. &packages_model.PackageBlobUpload{},
  209. &packages_model.PackageCleanupRule{},
  210. ))
  211. assert.NoError(t, storage.Clean(storage.Packages))
  212. return deferFn
  213. }
  214. func PrintCurrentTest(t testing.TB, skip ...int) func() {
  215. t.Helper()
  216. actualSkip := 1
  217. if len(skip) > 0 {
  218. actualSkip = skip[0] + 1
  219. }
  220. return testlogger.PrintCurrentTest(t, actualSkip)
  221. }
  222. // Printf takes a format and args and prints the string to os.Stdout
  223. func Printf(format string, args ...interface{}) {
  224. testlogger.Printf(format, args...)
  225. }