123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599 |
- package testfixtures // import "github.com/go-testfixtures/testfixtures/v3"
-
import (
	"bytes"
	"database/sql"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"text/template"
	"time"

	"gopkg.in/yaml.v2"
)
-
- // Loader is the responsible to loading fixtures.
- type Loader struct {
- db *sql.DB
- helper helper
- fixturesFiles []*fixtureFile
-
- skipTestDatabaseCheck bool
- location *time.Location
-
- template bool
- templateFuncs template.FuncMap
- templateLeftDelim string
- templateRightDelim string
- templateOptions []string
- templateData interface{}
- }
-
- type fixtureFile struct {
- path string
- fileName string
- content []byte
- insertSQLs []insertSQL
- }
-
// insertSQL is a ready-to-run INSERT statement paired with the arguments
// bound to its placeholders.
type insertSQL struct {
	sql    string        // statement text with dialect-specific placeholders
	params []interface{} // arguments passed to tx.Exec, in placeholder order
}
-
// testDatabaseRegexp matches database names containing "test" in any letter
// case; EnsureTestDatabase rejects names that do not match.
var testDatabaseRegexp = regexp.MustCompile("(?i)test")

// Errors returned by New when a mandatory option was not supplied.
var (
	errDatabaseIsRequired = fmt.Errorf("testfixtures: database is required")
	errDialectIsRequired  = fmt.Errorf("testfixtures: dialect is required")
)
-
- // New instantiates a new Loader instance. The "Database" and "Driver"
- // options are required.
- func New(options ...func(*Loader) error) (*Loader, error) {
- l := &Loader{
- templateLeftDelim: "{{",
- templateRightDelim: "}}",
- templateOptions: []string{"missingkey=zero"},
- }
-
- for _, option := range options {
- if err := option(l); err != nil {
- return nil, err
- }
- }
-
- if l.db == nil {
- return nil, errDatabaseIsRequired
- }
- if l.helper == nil {
- return nil, errDialectIsRequired
- }
-
- if err := l.helper.init(l.db); err != nil {
- return nil, err
- }
- if err := l.buildInsertSQLs(); err != nil {
- return nil, err
- }
-
- return l, nil
- }
-
- // Database sets an existing sql.DB instant to Loader.
- func Database(db *sql.DB) func(*Loader) error {
- return func(l *Loader) error {
- l.db = db
- return nil
- }
- }
-
- // Dialect informs Loader about which database dialect you're using.
- //
- // Possible options are "postgresql", "timescaledb", "mysql", "mariadb",
- // "sqlite" and "sqlserver".
- func Dialect(dialect string) func(*Loader) error {
- return func(l *Loader) error {
- h, err := helperForDialect(dialect)
- if err != nil {
- return err
- }
- l.helper = h
- return nil
- }
- }
-
- func helperForDialect(dialect string) (helper, error) {
- switch dialect {
- case "postgres", "postgresql", "timescaledb":
- return &postgreSQL{}, nil
- case "mysql", "mariadb":
- return &mySQL{}, nil
- case "sqlite", "sqlite3":
- return &sqlite{}, nil
- case "mssql", "sqlserver":
- return &sqlserver{}, nil
- default:
- return nil, fmt.Errorf(`testfixtures: unrecognized dialect "%s"`, dialect)
- }
- }
-
- // UseAlterConstraint If true, the contraint disabling will do
- // using ALTER CONTRAINT sintax, only allowed in PG >= 9.4.
- // If false, the constraint disabling will use DISABLE TRIGGER ALL,
- // which requires SUPERUSER privileges.
- //
- // Only valid for PostgreSQL. Returns an error otherwise.
- func UseAlterConstraint() func(*Loader) error {
- return func(l *Loader) error {
- pgHelper, ok := l.helper.(*postgreSQL)
- if !ok {
- return fmt.Errorf("testfixtures: UseAlterConstraint is only valid for PostgreSQL databases")
- }
- pgHelper.useAlterConstraint = true
- return nil
- }
- }
-
- // SkipResetSequences prevents Loader from reseting sequences after loading
- // fixtures.
- //
- // Only valid for PostgreSQL. Returns an error otherwise.
- func SkipResetSequences() func(*Loader) error {
- return func(l *Loader) error {
- pgHelper, ok := l.helper.(*postgreSQL)
- if !ok {
- return fmt.Errorf("testfixtures: SkipResetSequences is only valid for PostgreSQL databases")
- }
- pgHelper.skipResetSequences = true
- return nil
- }
- }
-
- // ResetSequencesTo sets the value the sequences will be reset to.
- //
- // Defaults to 10000.
- //
- // Only valid for PostgreSQL. Returns an error otherwise.
- func ResetSequencesTo(value int64) func(*Loader) error {
- return func(l *Loader) error {
- pgHelper, ok := l.helper.(*postgreSQL)
- if !ok {
- return fmt.Errorf("testfixtures: ResetSequencesTo is only valid for PostgreSQL databases")
- }
- pgHelper.resetSequencesTo = value
- return nil
- }
- }
-
- // DangerousSkipTestDatabaseCheck will make Loader not check if the database
- // name contains "test". Use with caution!
- func DangerousSkipTestDatabaseCheck() func(*Loader) error {
- return func(l *Loader) error {
- l.skipTestDatabaseCheck = true
- return nil
- }
- }
-
- // Directory informs Loader to load YAML files from a given directory.
- func Directory(dir string) func(*Loader) error {
- return func(l *Loader) error {
- fixtures, err := l.fixturesFromDir(dir)
- if err != nil {
- return err
- }
- l.fixturesFiles = append(l.fixturesFiles, fixtures...)
- return nil
- }
- }
-
- // Files informs Loader to load a given set of YAML files.
- func Files(files ...string) func(*Loader) error {
- return func(l *Loader) error {
- fixtures, err := l.fixturesFromFiles(files...)
- if err != nil {
- return err
- }
- l.fixturesFiles = append(l.fixturesFiles, fixtures...)
- return nil
- }
- }
-
- // Paths inform Loader to load a given set of YAML files and directories.
- func Paths(paths ...string) func(*Loader) error {
- return func(l *Loader) error {
- fixtures, err := l.fixturesFromPaths(paths...)
- if err != nil {
- return err
- }
- l.fixturesFiles = append(l.fixturesFiles, fixtures...)
- return nil
- }
- }
-
- // Location makes Loader use the given location by default when parsing
- // dates. If not given, by default it uses the value of time.Local.
- func Location(location *time.Location) func(*Loader) error {
- return func(l *Loader) error {
- l.location = location
- return nil
- }
- }
-
- // Template makes loader process each YAML file as an template using the
- // text/template package.
- //
- // For more information on how templates work in Go please read:
- // https://golang.org/pkg/text/template/
- //
- // If not given the YAML files are parsed as is.
- func Template() func(*Loader) error {
- return func(l *Loader) error {
- l.template = true
- return nil
- }
- }
-
- // TemplateFuncs allow choosing which functions will be available
- // when processing templates.
- //
- // For more information see: https://golang.org/pkg/text/template/#Template.Funcs
- func TemplateFuncs(funcs template.FuncMap) func(*Loader) error {
- return func(l *Loader) error {
- if !l.template {
- return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateFuns() option`)
- }
-
- l.templateFuncs = funcs
- return nil
- }
- }
-
- // TemplateDelims allow choosing which delimiters will be used for templating.
- // This defaults to "{{" and "}}".
- //
- // For more information see https://golang.org/pkg/text/template/#Template.Delims
- func TemplateDelims(left, right string) func(*Loader) error {
- return func(l *Loader) error {
- if !l.template {
- return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateDelims() option`)
- }
-
- l.templateLeftDelim = left
- l.templateRightDelim = right
- return nil
- }
- }
-
- // TemplateOptions allows you to specific which text/template options will
- // be enabled when processing templates.
- //
- // This defaults to "missingkey=zero". Check the available options here:
- // https://golang.org/pkg/text/template/#Template.Option
- func TemplateOptions(options ...string) func(*Loader) error {
- return func(l *Loader) error {
- if !l.template {
- return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateOptions() option`)
- }
-
- l.templateOptions = options
- return nil
- }
- }
-
- // TemplateData allows you to specify which data will be available
- // when processing templates. Data is accesible by prefixing it with a "."
- // like {{.MyKey}}.
- func TemplateData(data interface{}) func(*Loader) error {
- return func(l *Loader) error {
- if !l.template {
- return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateData() option`)
- }
-
- l.templateData = data
- return nil
- }
- }
-
- // EnsureTestDatabase returns an error if the database name does not contains
- // "test".
- func (l *Loader) EnsureTestDatabase() error {
- dbName, err := l.helper.databaseName(l.db)
- if err != nil {
- return err
- }
- if !testDatabaseRegexp.MatchString(dbName) {
- return fmt.Errorf(`testfixtures: database "%s" does not appear to be a test database`, dbName)
- }
- return nil
- }
-
- // Load wipes and after load all fixtures in the database.
- // if err := fixtures.Load(); err != nil {
- // ...
- // }
- func (l *Loader) Load() error {
- if !l.skipTestDatabaseCheck {
- if err := l.EnsureTestDatabase(); err != nil {
- return err
- }
- }
-
- err := l.helper.disableReferentialIntegrity(l.db, func(tx *sql.Tx) error {
- for _, file := range l.fixturesFiles {
- modified, err := l.helper.isTableModified(tx, file.fileNameWithoutExtension())
- if err != nil {
- return err
- }
- if !modified {
- continue
- }
- if err := file.delete(tx, l.helper); err != nil {
- return err
- }
-
- err = l.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
- for j, i := range file.insertSQLs {
- if _, err := tx.Exec(i.sql, i.params...); err != nil {
- return &InsertError{
- Err: err,
- File: file.fileName,
- Index: j,
- SQL: i.sql,
- Params: i.params,
- }
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- return l.helper.afterLoad(l.db)
- }
-
// InsertError is returned by Load when executing one of the INSERT
// statements built from a fixture file fails.
type InsertError struct {
	Err    error         // error returned by tx.Exec
	File   string        // name of the fixture file the statement came from
	Index  int           // zero-based position of the statement within that file
	SQL    string        // the INSERT statement that failed
	Params []interface{} // the arguments bound to SQL
}

// Error implements the error interface, describing the failed record.
func (e *InsertError) Error() string {
	return fmt.Sprintf(
		"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
		e.Err,
		e.File,
		e.Index,
		e.SQL,
		e.Params,
	)
}

// Unwrap returns the underlying database error, so callers can inspect it
// with errors.Is and errors.As (e.g. to check for a driver-specific error).
func (e *InsertError) Unwrap() error {
	return e.Err
}
-
- func (l *Loader) buildInsertSQLs() error {
- for _, f := range l.fixturesFiles {
- var records interface{}
- if err := yaml.Unmarshal(f.content, &records); err != nil {
- return fmt.Errorf("testfixtures: could not unmarshal YAML: %w", err)
- }
-
- switch records := records.(type) {
- case []interface{}:
- f.insertSQLs = make([]insertSQL, 0, len(records))
-
- for _, record := range records {
- recordMap, ok := record.(map[interface{}]interface{})
- if !ok {
- return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
- }
-
- sql, values, err := l.buildInsertSQL(f, recordMap)
- if err != nil {
- return err
- }
-
- f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
- }
- case map[interface{}]interface{}:
- f.insertSQLs = make([]insertSQL, 0, len(records))
-
- for _, record := range records {
- recordMap, ok := record.(map[interface{}]interface{})
- if !ok {
- return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")
- }
-
- sql, values, err := l.buildInsertSQL(f, recordMap)
- if err != nil {
- return err
- }
-
- f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
- }
- default:
- return fmt.Errorf("testfixtures: fixture is not a slice or map")
- }
- }
-
- return nil
- }
-
- func (f *fixtureFile) fileNameWithoutExtension() string {
- return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
- }
-
- func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
- if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension()))); err != nil {
- return fmt.Errorf(`testfixtures: could not clean table "%s": %w`, f.fileNameWithoutExtension(), err)
- }
- return nil
- }
-
- func (l *Loader) buildInsertSQL(f *fixtureFile, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
- var (
- sqlColumns = make([]string, 0, len(record))
- sqlValues = make([]string, 0, len(record))
- i = 1
- )
- for key, value := range record {
- keyStr, ok := key.(string)
- if !ok {
- err = fmt.Errorf("testfixtures: record map key is not a string")
- return
- }
-
- sqlColumns = append(sqlColumns, l.helper.quoteKeyword(keyStr))
-
- // if string, try convert to SQL or time
- // if map or array, convert to json
- switch v := value.(type) {
- case string:
- if strings.HasPrefix(v, "RAW=") {
- sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
- continue
- }
-
- if t, err := l.tryStrToDate(v); err == nil {
- value = t
- }
- case []interface{}, map[interface{}]interface{}:
- value = recursiveToJSON(v)
- }
-
- switch l.helper.paramType() {
- case paramTypeDollar:
- sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
- case paramTypeQuestion:
- sqlValues = append(sqlValues, "?")
- case paramTypeAtSign:
- sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))
- }
-
- values = append(values, value)
- i++
- }
-
- sqlStr = fmt.Sprintf(
- "INSERT INTO %s (%s) VALUES (%s)",
- l.helper.quoteKeyword(f.fileNameWithoutExtension()),
- strings.Join(sqlColumns, ", "),
- strings.Join(sqlValues, ", "),
- )
- return
- }
-
- func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
- fileinfos, err := ioutil.ReadDir(dir)
- if err != nil {
- return nil, fmt.Errorf(`testfixtures: could not stat directory "%s": %w`, dir, err)
- }
-
- files := make([]*fixtureFile, 0, len(fileinfos))
-
- for _, fileinfo := range fileinfos {
- fileExt := filepath.Ext(fileinfo.Name())
- if !fileinfo.IsDir() && (fileExt == ".yml" || fileExt == ".yaml") {
- fixture := &fixtureFile{
- path: path.Join(dir, fileinfo.Name()),
- fileName: fileinfo.Name(),
- }
- fixture.content, err = ioutil.ReadFile(fixture.path)
- if err != nil {
- return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
- }
- if err := l.processFileTemplate(fixture); err != nil {
- return nil, err
- }
- files = append(files, fixture)
- }
- }
- return files, nil
- }
-
- func (l *Loader) fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
- var (
- fixtureFiles = make([]*fixtureFile, 0, len(fileNames))
- err error
- )
-
- for _, f := range fileNames {
- fixture := &fixtureFile{
- path: f,
- fileName: filepath.Base(f),
- }
- fixture.content, err = ioutil.ReadFile(fixture.path)
- if err != nil {
- return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
- }
- if err := l.processFileTemplate(fixture); err != nil {
- return nil, err
- }
- fixtureFiles = append(fixtureFiles, fixture)
- }
-
- return fixtureFiles, nil
- }
-
- func (l *Loader) fixturesFromPaths(paths ...string) ([]*fixtureFile, error) {
- fixtureExtractor := func(p string, isDir bool) ([]*fixtureFile, error) {
- if isDir {
- return l.fixturesFromDir(p)
- }
-
- return l.fixturesFromFiles(p)
- }
-
- var fixtureFiles []*fixtureFile
-
- for _, p := range paths {
- f, err := os.Stat(p)
- if err != nil {
- return nil, fmt.Errorf(`testfixtures: could not stat path "%s": %w`, p, err)
- }
-
- fixtures, err := fixtureExtractor(p, f.IsDir())
- if err != nil {
- return nil, err
- }
-
- fixtureFiles = append(fixtureFiles, fixtures...)
- }
-
- return fixtureFiles, nil
- }
-
- func (l *Loader) processFileTemplate(f *fixtureFile) error {
- if !l.template {
- return nil
- }
-
- t := template.New("").
- Funcs(l.templateFuncs).
- Delims(l.templateLeftDelim, l.templateRightDelim).
- Option(l.templateOptions...)
- t, err := t.Parse(string(f.content))
- if err != nil {
- return fmt.Errorf(`textfixtures: error on parsing template in %s: %w`, f.fileName, err)
- }
-
- var buffer bytes.Buffer
- if err := t.Execute(&buffer, l.templateData); err != nil {
- return fmt.Errorf(`textfixtures: error on executing template in %s: %w`, f.fileName, err)
- }
-
- f.content = buffer.Bytes()
- return nil
- }
|