diff options
Diffstat (limited to 'vendor/gopkg.in/testfixtures.v2/testfixtures.go')
-rw-r--r-- | vendor/gopkg.in/testfixtures.v2/testfixtures.go | 90 |
1 files changed, 58 insertions, 32 deletions
diff --git a/vendor/gopkg.in/testfixtures.v2/testfixtures.go b/vendor/gopkg.in/testfixtures.v2/testfixtures.go index bd95580f2a..dfc59c1efc 100644 --- a/vendor/gopkg.in/testfixtures.v2/testfixtures.go +++ b/vendor/gopkg.in/testfixtures.v2/testfixtures.go @@ -2,7 +2,6 @@ package testfixtures import ( "database/sql" - "errors" "fmt" "io/ioutil" "path" @@ -33,22 +32,10 @@ type insertSQL struct { } var ( - // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{} - ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}") - - // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map. - ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map") - - // ErrKeyIsNotString is returned when a record is not of type string - ErrKeyIsNotString = errors.New("Record map key is not string") - - // ErrNotTestDatabase is returned when the database name doesn't contains "test" - ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`) - dbnameRegexp = regexp.MustCompile("(?i)test") ) -// NewFolder craetes a context for all fixtures in a given folder into the database: +// NewFolder creates a context for all fixtures in a given folder into the database: // NewFolder(db, &PostgreSQL{}, "my/fixtures/folder") func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) { fixtures, err := fixturesFromFolder(folderName) @@ -64,7 +51,7 @@ func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) { return c, nil } -// NewFiles craetes a context for all specified fixtures files into database: +// NewFiles creates a context for all specified fixtures files into database: // NewFiles(db, &PostgreSQL{}, // "fixtures/customers.yml", // "fixtures/orders.yml" @@ -102,27 +89,55 @@ func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, e return c, nil } +// DetectTestDatabase returns nil if databaseName matches regexp +// if err := fixtures.DetectTestDatabase(); err != nil { +// log.Fatal(err) +// } +func (c *Context) DetectTestDatabase() error { + dbName, err := c.helper.databaseName(c.db) + if err != nil { + return err + } + if !dbnameRegexp.MatchString(dbName) { + return ErrNotTestDatabase + } + return nil +} + // Load wipes and after load all fixtures in the database. // if err := fixtures.Load(); err != nil { // log.Fatal(err) // } func (c *Context) Load() error { if !skipDatabaseNameCheck { - if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) { - return ErrNotTestDatabase + if err := c.DetectTestDatabase(); err != nil { + return err } } err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error { for _, file := range c.fixturesFiles { + modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension()) + if err != nil { + return err + } + if !modified { + continue + } if err := file.delete(tx, c.helper); err != nil { return err } - err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error { - for _, i := range file.insertSQLs { + err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error { + for j, i := range file.insertSQLs { if _, err := tx.Exec(i.sql, i.params...); err != nil { - return err + return &InsertError{ + Err: err, + File: file.fileName, + Index: j, + SQL: i.sql, + Params: i.params, + } } } return nil @@ -133,7 +148,10 @@ func (c *Context) Load() error { } return nil }) - return err + if err != nil { + return err + } + return c.helper.afterLoad(c.db) } func (c *Context) buildInsertSQLs() error { @@ -204,25 +222,33 @@ func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{ sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr)) + // if string, try convert to SQL or time + // if map or array, convert to json + switch v := value.(type) { + case string: + if strings.HasPrefix(v, "RAW=") { + sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW=")) + continue + } + + if t, err := tryStrToDate(v); err == nil { + value = t + } + case []interface{}, map[interface{}]interface{}: + value = recursiveToJSON(v) + } + switch h.paramType() { case paramTypeDollar: sqlValues = append(sqlValues, fmt.Sprintf("$%d", i)) case paramTypeQuestion: sqlValues = append(sqlValues, "?") case paramTypeColon: - switch { - case isDateTime(value): - sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i)) - case isDate(value): - sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i)) - case isTime(value): - sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i)) - default: - sqlValues = append(sqlValues, fmt.Sprintf(":%d", i)) - } + sqlValues = append(sqlValues, fmt.Sprintf(":%d", i)) } - i++ + values = append(values, value) + i++ } sqlStr = fmt.Sprintf( |