You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

testfixtures.go 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. package testfixtures // import "github.com/go-testfixtures/testfixtures/v3"
import (
	"bytes"
	"database/sql"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"text/template"
	"time"

	"gopkg.in/yaml.v2"
)
  16. // Loader is the responsible to loading fixtures.
  17. type Loader struct {
  18. db *sql.DB
  19. helper helper
  20. fixturesFiles []*fixtureFile
  21. skipTestDatabaseCheck bool
  22. location *time.Location
  23. template bool
  24. templateFuncs template.FuncMap
  25. templateLeftDelim string
  26. templateRightDelim string
  27. templateOptions []string
  28. templateData interface{}
  29. }
  30. type fixtureFile struct {
  31. path string
  32. fileName string
  33. content []byte
  34. insertSQLs []insertSQL
  35. }
  36. type insertSQL struct {
  37. sql string
  38. params []interface{}
  39. }
  40. var (
  41. testDatabaseRegexp = regexp.MustCompile("(?i)test")
  42. errDatabaseIsRequired = fmt.Errorf("testfixtures: database is required")
  43. errDialectIsRequired = fmt.Errorf("testfixtures: dialect is required")
  44. )
  45. // New instantiates a new Loader instance. The "Database" and "Driver"
  46. // options are required.
  47. func New(options ...func(*Loader) error) (*Loader, error) {
  48. l := &Loader{
  49. templateLeftDelim: "{{",
  50. templateRightDelim: "}}",
  51. templateOptions: []string{"missingkey=zero"},
  52. }
  53. for _, option := range options {
  54. if err := option(l); err != nil {
  55. return nil, err
  56. }
  57. }
  58. if l.db == nil {
  59. return nil, errDatabaseIsRequired
  60. }
  61. if l.helper == nil {
  62. return nil, errDialectIsRequired
  63. }
  64. if err := l.helper.init(l.db); err != nil {
  65. return nil, err
  66. }
  67. if err := l.buildInsertSQLs(); err != nil {
  68. return nil, err
  69. }
  70. return l, nil
  71. }
  72. // Database sets an existing sql.DB instant to Loader.
  73. func Database(db *sql.DB) func(*Loader) error {
  74. return func(l *Loader) error {
  75. l.db = db
  76. return nil
  77. }
  78. }
  79. // Dialect informs Loader about which database dialect you're using.
  80. //
  81. // Possible options are "postgresql", "timescaledb", "mysql", "mariadb",
  82. // "sqlite" and "sqlserver".
  83. func Dialect(dialect string) func(*Loader) error {
  84. return func(l *Loader) error {
  85. h, err := helperForDialect(dialect)
  86. if err != nil {
  87. return err
  88. }
  89. l.helper = h
  90. return nil
  91. }
  92. }
  93. func helperForDialect(dialect string) (helper, error) {
  94. switch dialect {
  95. case "postgres", "postgresql", "timescaledb":
  96. return &postgreSQL{}, nil
  97. case "mysql", "mariadb":
  98. return &mySQL{}, nil
  99. case "sqlite", "sqlite3":
  100. return &sqlite{}, nil
  101. case "mssql", "sqlserver":
  102. return &sqlserver{}, nil
  103. default:
  104. return nil, fmt.Errorf(`testfixtures: unrecognized dialect "%s"`, dialect)
  105. }
  106. }
  107. // UseAlterConstraint If true, the contraint disabling will do
  108. // using ALTER CONTRAINT sintax, only allowed in PG >= 9.4.
  109. // If false, the constraint disabling will use DISABLE TRIGGER ALL,
  110. // which requires SUPERUSER privileges.
  111. //
  112. // Only valid for PostgreSQL. Returns an error otherwise.
  113. func UseAlterConstraint() func(*Loader) error {
  114. return func(l *Loader) error {
  115. pgHelper, ok := l.helper.(*postgreSQL)
  116. if !ok {
  117. return fmt.Errorf("testfixtures: UseAlterConstraint is only valid for PostgreSQL databases")
  118. }
  119. pgHelper.useAlterConstraint = true
  120. return nil
  121. }
  122. }
  123. // SkipResetSequences prevents Loader from reseting sequences after loading
  124. // fixtures.
  125. //
  126. // Only valid for PostgreSQL. Returns an error otherwise.
  127. func SkipResetSequences() func(*Loader) error {
  128. return func(l *Loader) error {
  129. pgHelper, ok := l.helper.(*postgreSQL)
  130. if !ok {
  131. return fmt.Errorf("testfixtures: SkipResetSequences is only valid for PostgreSQL databases")
  132. }
  133. pgHelper.skipResetSequences = true
  134. return nil
  135. }
  136. }
  137. // ResetSequencesTo sets the value the sequences will be reset to.
  138. //
  139. // Defaults to 10000.
  140. //
  141. // Only valid for PostgreSQL. Returns an error otherwise.
  142. func ResetSequencesTo(value int64) func(*Loader) error {
  143. return func(l *Loader) error {
  144. pgHelper, ok := l.helper.(*postgreSQL)
  145. if !ok {
  146. return fmt.Errorf("testfixtures: ResetSequencesTo is only valid for PostgreSQL databases")
  147. }
  148. pgHelper.resetSequencesTo = value
  149. return nil
  150. }
  151. }
  152. // DangerousSkipTestDatabaseCheck will make Loader not check if the database
  153. // name contains "test". Use with caution!
  154. func DangerousSkipTestDatabaseCheck() func(*Loader) error {
  155. return func(l *Loader) error {
  156. l.skipTestDatabaseCheck = true
  157. return nil
  158. }
  159. }
  160. // Directory informs Loader to load YAML files from a given directory.
  161. func Directory(dir string) func(*Loader) error {
  162. return func(l *Loader) error {
  163. fixtures, err := l.fixturesFromDir(dir)
  164. if err != nil {
  165. return err
  166. }
  167. l.fixturesFiles = append(l.fixturesFiles, fixtures...)
  168. return nil
  169. }
  170. }
  171. // Files informs Loader to load a given set of YAML files.
  172. func Files(files ...string) func(*Loader) error {
  173. return func(l *Loader) error {
  174. fixtures, err := l.fixturesFromFiles(files...)
  175. if err != nil {
  176. return err
  177. }
  178. l.fixturesFiles = append(l.fixturesFiles, fixtures...)
  179. return nil
  180. }
  181. }
  182. // Paths inform Loader to load a given set of YAML files and directories.
  183. func Paths(paths ...string) func(*Loader) error {
  184. return func(l *Loader) error {
  185. fixtures, err := l.fixturesFromPaths(paths...)
  186. if err != nil {
  187. return err
  188. }
  189. l.fixturesFiles = append(l.fixturesFiles, fixtures...)
  190. return nil
  191. }
  192. }
  193. // Location makes Loader use the given location by default when parsing
  194. // dates. If not given, by default it uses the value of time.Local.
  195. func Location(location *time.Location) func(*Loader) error {
  196. return func(l *Loader) error {
  197. l.location = location
  198. return nil
  199. }
  200. }
  201. // Template makes loader process each YAML file as an template using the
  202. // text/template package.
  203. //
  204. // For more information on how templates work in Go please read:
  205. // https://golang.org/pkg/text/template/
  206. //
  207. // If not given the YAML files are parsed as is.
  208. func Template() func(*Loader) error {
  209. return func(l *Loader) error {
  210. l.template = true
  211. return nil
  212. }
  213. }
  214. // TemplateFuncs allow choosing which functions will be available
  215. // when processing templates.
  216. //
  217. // For more information see: https://golang.org/pkg/text/template/#Template.Funcs
  218. func TemplateFuncs(funcs template.FuncMap) func(*Loader) error {
  219. return func(l *Loader) error {
  220. if !l.template {
  221. return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateFuns() option`)
  222. }
  223. l.templateFuncs = funcs
  224. return nil
  225. }
  226. }
  227. // TemplateDelims allow choosing which delimiters will be used for templating.
  228. // This defaults to "{{" and "}}".
  229. //
  230. // For more information see https://golang.org/pkg/text/template/#Template.Delims
  231. func TemplateDelims(left, right string) func(*Loader) error {
  232. return func(l *Loader) error {
  233. if !l.template {
  234. return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateDelims() option`)
  235. }
  236. l.templateLeftDelim = left
  237. l.templateRightDelim = right
  238. return nil
  239. }
  240. }
  241. // TemplateOptions allows you to specific which text/template options will
  242. // be enabled when processing templates.
  243. //
  244. // This defaults to "missingkey=zero". Check the available options here:
  245. // https://golang.org/pkg/text/template/#Template.Option
  246. func TemplateOptions(options ...string) func(*Loader) error {
  247. return func(l *Loader) error {
  248. if !l.template {
  249. return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateOptions() option`)
  250. }
  251. l.templateOptions = options
  252. return nil
  253. }
  254. }
  255. // TemplateData allows you to specify which data will be available
  256. // when processing templates. Data is accesible by prefixing it with a "."
  257. // like {{.MyKey}}.
  258. func TemplateData(data interface{}) func(*Loader) error {
  259. return func(l *Loader) error {
  260. if !l.template {
  261. return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateData() option`)
  262. }
  263. l.templateData = data
  264. return nil
  265. }
  266. }
  267. // EnsureTestDatabase returns an error if the database name does not contains
  268. // "test".
  269. func (l *Loader) EnsureTestDatabase() error {
  270. dbName, err := l.helper.databaseName(l.db)
  271. if err != nil {
  272. return err
  273. }
  274. if !testDatabaseRegexp.MatchString(dbName) {
  275. return fmt.Errorf(`testfixtures: database "%s" does not appear to be a test database`, dbName)
  276. }
  277. return nil
  278. }
  279. // Load wipes and after load all fixtures in the database.
  280. // if err := fixtures.Load(); err != nil {
  281. // ...
  282. // }
  283. func (l *Loader) Load() error {
  284. if !l.skipTestDatabaseCheck {
  285. if err := l.EnsureTestDatabase(); err != nil {
  286. return err
  287. }
  288. }
  289. err := l.helper.disableReferentialIntegrity(l.db, func(tx *sql.Tx) error {
  290. for _, file := range l.fixturesFiles {
  291. modified, err := l.helper.isTableModified(tx, file.fileNameWithoutExtension())
  292. if err != nil {
  293. return err
  294. }
  295. if !modified {
  296. continue
  297. }
  298. if err := file.delete(tx, l.helper); err != nil {
  299. return err
  300. }
  301. err = l.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
  302. for j, i := range file.insertSQLs {
  303. if _, err := tx.Exec(i.sql, i.params...); err != nil {
  304. return &InsertError{
  305. Err: err,
  306. File: file.fileName,
  307. Index: j,
  308. SQL: i.sql,
  309. Params: i.params,
  310. }
  311. }
  312. }
  313. return nil
  314. })
  315. if err != nil {
  316. return err
  317. }
  318. }
  319. return nil
  320. })
  321. if err != nil {
  322. return err
  323. }
  324. return l.helper.afterLoad(l.db)
  325. }
  326. // InsertError will be returned if any error happens on database while
  327. // inserting the record.
  328. type InsertError struct {
  329. Err error
  330. File string
  331. Index int
  332. SQL string
  333. Params []interface{}
  334. }
  335. func (e *InsertError) Error() string {
  336. return fmt.Sprintf(
  337. "testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
  338. e.Err,
  339. e.File,
  340. e.Index,
  341. e.SQL,
  342. e.Params,
  343. )
  344. }
  345. func (l *Loader) buildInsertSQLs() error {
  346. for _, f := range l.fixturesFiles {
  347. var records interface{}
  348. if err := yaml.Unmarshal(f.content, &records); err != nil {
  349. return fmt.Errorf("testfixtures: could not unmarshal YAML: %w", err)
  350. }
  351. switch records := records.(type) {
  352. case []interface{}:
  353. f.insertSQLs = make([]insertSQL, 0, len(records))
  354. for _, record := range records {
  355. recordMap, ok := record.(map[interface{}]interface{})
  356. if !ok {
  357. return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
  358. }
  359. sql, values, err := l.buildInsertSQL(f, recordMap)
  360. if err != nil {
  361. return err
  362. }
  363. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  364. }
  365. case map[interface{}]interface{}:
  366. f.insertSQLs = make([]insertSQL, 0, len(records))
  367. for _, record := range records {
  368. recordMap, ok := record.(map[interface{}]interface{})
  369. if !ok {
  370. return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
  371. }
  372. sql, values, err := l.buildInsertSQL(f, recordMap)
  373. if err != nil {
  374. return err
  375. }
  376. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  377. }
  378. default:
  379. return fmt.Errorf("testfixtures: fixture is not a slice or map")
  380. }
  381. }
  382. return nil
  383. }
  384. func (f *fixtureFile) fileNameWithoutExtension() string {
  385. return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
  386. }
  387. func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
  388. if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension()))); err != nil {
  389. return fmt.Errorf(`testfixtures: could not clean table "%s": %w`, f.fileNameWithoutExtension(), err)
  390. }
  391. return nil
  392. }
  393. func (l *Loader) buildInsertSQL(f *fixtureFile, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
  394. var (
  395. sqlColumns = make([]string, 0, len(record))
  396. sqlValues = make([]string, 0, len(record))
  397. i = 1
  398. )
  399. for key, value := range record {
  400. keyStr, ok := key.(string)
  401. if !ok {
  402. err = fmt.Errorf("testfixtures: record map key is not a string")
  403. return
  404. }
  405. sqlColumns = append(sqlColumns, l.helper.quoteKeyword(keyStr))
  406. // if string, try convert to SQL or time
  407. // if map or array, convert to json
  408. switch v := value.(type) {
  409. case string:
  410. if strings.HasPrefix(v, "RAW=") {
  411. sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
  412. continue
  413. }
  414. if t, err := l.tryStrToDate(v); err == nil {
  415. value = t
  416. }
  417. case []interface{}, map[interface{}]interface{}:
  418. value = recursiveToJSON(v)
  419. }
  420. switch l.helper.paramType() {
  421. case paramTypeDollar:
  422. sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
  423. case paramTypeQuestion:
  424. sqlValues = append(sqlValues, "?")
  425. case paramTypeAtSign:
  426. sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))
  427. }
  428. values = append(values, value)
  429. i++
  430. }
  431. sqlStr = fmt.Sprintf(
  432. "INSERT INTO %s (%s) VALUES (%s)",
  433. l.helper.quoteKeyword(f.fileNameWithoutExtension()),
  434. strings.Join(sqlColumns, ", "),
  435. strings.Join(sqlValues, ", "),
  436. )
  437. return
  438. }
  439. func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
  440. fileinfos, err := ioutil.ReadDir(dir)
  441. if err != nil {
  442. return nil, fmt.Errorf(`testfixtures: could not stat directory "%s": %w`, dir, err)
  443. }
  444. files := make([]*fixtureFile, 0, len(fileinfos))
  445. for _, fileinfo := range fileinfos {
  446. fileExt := filepath.Ext(fileinfo.Name())
  447. if !fileinfo.IsDir() && (fileExt == ".yml" || fileExt == ".yaml") {
  448. fixture := &fixtureFile{
  449. path: path.Join(dir, fileinfo.Name()),
  450. fileName: fileinfo.Name(),
  451. }
  452. fixture.content, err = ioutil.ReadFile(fixture.path)
  453. if err != nil {
  454. return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
  455. }
  456. if err := l.processFileTemplate(fixture); err != nil {
  457. return nil, err
  458. }
  459. files = append(files, fixture)
  460. }
  461. }
  462. return files, nil
  463. }
  464. func (l *Loader) fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
  465. var (
  466. fixtureFiles = make([]*fixtureFile, 0, len(fileNames))
  467. err error
  468. )
  469. for _, f := range fileNames {
  470. fixture := &fixtureFile{
  471. path: f,
  472. fileName: filepath.Base(f),
  473. }
  474. fixture.content, err = ioutil.ReadFile(fixture.path)
  475. if err != nil {
  476. return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
  477. }
  478. if err := l.processFileTemplate(fixture); err != nil {
  479. return nil, err
  480. }
  481. fixtureFiles = append(fixtureFiles, fixture)
  482. }
  483. return fixtureFiles, nil
  484. }
  485. func (l *Loader) fixturesFromPaths(paths ...string) ([]*fixtureFile, error) {
  486. fixtureExtractor := func(p string, isDir bool) ([]*fixtureFile, error) {
  487. if isDir {
  488. return l.fixturesFromDir(p)
  489. }
  490. return l.fixturesFromFiles(p)
  491. }
  492. var fixtureFiles []*fixtureFile
  493. for _, p := range paths {
  494. f, err := os.Stat(p)
  495. if err != nil {
  496. return nil, fmt.Errorf(`testfixtures: could not stat path "%s": %w`, p, err)
  497. }
  498. fixtures, err := fixtureExtractor(p, f.IsDir())
  499. if err != nil {
  500. return nil, err
  501. }
  502. fixtureFiles = append(fixtureFiles, fixtures...)
  503. }
  504. return fixtureFiles, nil
  505. }
  506. func (l *Loader) processFileTemplate(f *fixtureFile) error {
  507. if !l.template {
  508. return nil
  509. }
  510. t := template.New("").
  511. Funcs(l.templateFuncs).
  512. Delims(l.templateLeftDelim, l.templateRightDelim).
  513. Option(l.templateOptions...)
  514. t, err := t.Parse(string(f.content))
  515. if err != nil {
  516. return fmt.Errorf(`textfixtures: error on parsing template in %s: %w`, f.fileName, err)
  517. }
  518. var buffer bytes.Buffer
  519. if err := t.Execute(&buffer, l.templateData); err != nil {
  520. return fmt.Errorf(`textfixtures: error on executing template in %s: %w`, f.fileName, err)
  521. }
  522. f.content = buffer.Bytes()
  523. return nil
  524. }