aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/gopkg.in/testfixtures.v2/testfixtures.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/gopkg.in/testfixtures.v2/testfixtures.go')
-rw-r--r--vendor/gopkg.in/testfixtures.v2/testfixtures.go90
1 files changed, 58 insertions, 32 deletions
diff --git a/vendor/gopkg.in/testfixtures.v2/testfixtures.go b/vendor/gopkg.in/testfixtures.v2/testfixtures.go
index bd95580f2a..dfc59c1efc 100644
--- a/vendor/gopkg.in/testfixtures.v2/testfixtures.go
+++ b/vendor/gopkg.in/testfixtures.v2/testfixtures.go
@@ -2,7 +2,6 @@ package testfixtures
import (
"database/sql"
- "errors"
"fmt"
"io/ioutil"
"path"
@@ -33,22 +32,10 @@ type insertSQL struct {
}
var (
- // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
- ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
-
- // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
- ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
-
- // ErrKeyIsNotString is returned when a record is not of type string
- ErrKeyIsNotString = errors.New("Record map key is not string")
-
- // ErrNotTestDatabase is returned when the database name doesn't contains "test"
- ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
-
dbnameRegexp = regexp.MustCompile("(?i)test")
)
-// NewFolder craetes a context for all fixtures in a given folder into the database:
+// NewFolder creates a context for all fixtures in a given folder into the database:
// NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
fixtures, err := fixturesFromFolder(folderName)
@@ -64,7 +51,7 @@ func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
return c, nil
}
-// NewFiles craetes a context for all specified fixtures files into database:
+// NewFiles creates a context for all specified fixtures files into database:
// NewFiles(db, &PostgreSQL{},
// "fixtures/customers.yml",
// "fixtures/orders.yml"
@@ -102,27 +89,55 @@ func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, e
return c, nil
}
+// DetectTestDatabase returns nil if databaseName matches regexp
+// if err := fixtures.DetectTestDatabase(); err != nil {
+// log.Fatal(err)
+// }
+func (c *Context) DetectTestDatabase() error {
+ dbName, err := c.helper.databaseName(c.db)
+ if err != nil {
+ return err
+ }
+ if !dbnameRegexp.MatchString(dbName) {
+ return ErrNotTestDatabase
+ }
+ return nil
+}
+
// Load wipes and after load all fixtures in the database.
// if err := fixtures.Load(); err != nil {
// log.Fatal(err)
// }
func (c *Context) Load() error {
if !skipDatabaseNameCheck {
- if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) {
- return ErrNotTestDatabase
+ if err := c.DetectTestDatabase(); err != nil {
+ return err
}
}
err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
for _, file := range c.fixturesFiles {
+ modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
+ if err != nil {
+ return err
+ }
+ if !modified {
+ continue
+ }
if err := file.delete(tx, c.helper); err != nil {
return err
}
- err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
- for _, i := range file.insertSQLs {
+ err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
+ for j, i := range file.insertSQLs {
if _, err := tx.Exec(i.sql, i.params...); err != nil {
- return err
+ return &InsertError{
+ Err: err,
+ File: file.fileName,
+ Index: j,
+ SQL: i.sql,
+ Params: i.params,
+ }
}
}
return nil
@@ -133,7 +148,10 @@ func (c *Context) Load() error {
}
return nil
})
- return err
+ if err != nil {
+ return err
+ }
+ return c.helper.afterLoad(c.db)
}
func (c *Context) buildInsertSQLs() error {
@@ -204,25 +222,33 @@ func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{
sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
+ // if string, try convert to SQL or time
+ // if map or array, convert to json
+ switch v := value.(type) {
+ case string:
+ if strings.HasPrefix(v, "RAW=") {
+ sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
+ continue
+ }
+
+ if t, err := tryStrToDate(v); err == nil {
+ value = t
+ }
+ case []interface{}, map[interface{}]interface{}:
+ value = recursiveToJSON(v)
+ }
+
switch h.paramType() {
case paramTypeDollar:
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
case paramTypeQuestion:
sqlValues = append(sqlValues, "?")
case paramTypeColon:
- switch {
- case isDateTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i))
- case isDate(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i))
- case isTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i))
- default:
- sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
- }
+ sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
}
- i++
+
values = append(values, value)
+ i++
}
sqlStr = fmt.Sprintf(