summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMura Li <typeless@users.noreply.github.com>2018-10-03 03:20:02 +0800
committertechknowlogick <hello@techknowlogick.com>2018-10-02 15:20:02 -0400
commitdba955be7c92c1c8b967add8a6d362ea7ec2da67 (patch)
tree221deeaf67f594fe67db7e2124a75a8d16d30d31
parentb8d048fa0d2cdd0781051990d34cb3681a8ae114 (diff)
downloadgitea-dba955be7c92c1c8b967add8a6d362ea7ec2da67.tar.gz
gitea-dba955be7c92c1c8b967add8a6d362ea7ec2da67.zip
Upgrade gopkg.in/testfixtures.v2 (#4999)
-rw-r--r--Gopkg.lock6
-rw-r--r--vendor/gopkg.in/testfixtures.v2/deprecated.go28
-rw-r--r--vendor/gopkg.in/testfixtures.v2/errors.go41
-rw-r--r--vendor/gopkg.in/testfixtures.v2/generate.go110
-rw-r--r--vendor/gopkg.in/testfixtures.v2/helper.go43
-rw-r--r--vendor/gopkg.in/testfixtures.v2/json.go44
-rw-r--r--vendor/gopkg.in/testfixtures.v2/mysql.go101
-rw-r--r--vendor/gopkg.in/testfixtures.v2/oracle.go83
-rw-r--r--vendor/gopkg.in/testfixtures.v2/postgresql.go150
-rw-r--r--vendor/gopkg.in/testfixtures.v2/sqlite.go50
-rw-r--r--vendor/gopkg.in/testfixtures.v2/sqlserver.go75
-rw-r--r--vendor/gopkg.in/testfixtures.v2/testfixtures.go90
-rw-r--r--vendor/gopkg.in/testfixtures.v2/time.go52
13 files changed, 705 insertions, 168 deletions
diff --git a/Gopkg.lock b/Gopkg.lock
index e6eb721fd3..39b44c65dc 100644
--- a/Gopkg.lock
+++ b/Gopkg.lock
@@ -885,12 +885,12 @@
version = "v2.3.2"
[[projects]]
- digest = "1:e2144032dcf8e856fb733151391669dc2d32d79ebb69e3c4151efa605f5e8a01"
+ digest = "1:9c541fc507676a69ea8aaed1af53278a5241d26ce0f192c993fec2ac5b78f795"
name = "gopkg.in/testfixtures.v2"
packages = ["."]
pruneopts = "NUT"
- revision = "b9ef14dc461bf934d8df2dfc6f1f456be5664cca"
- version = "v2.0.0"
+ revision = "fa3fb89109b0b31957a5430cef3e93e535de362b"
+ version = "v2.5.0"
[[projects]]
digest = "1:ad6f94355d292690137613735965bd3688844880fdab90eccf66321910344942"
diff --git a/vendor/gopkg.in/testfixtures.v2/deprecated.go b/vendor/gopkg.in/testfixtures.v2/deprecated.go
index b83eeef436..16e0969e33 100644
--- a/vendor/gopkg.in/testfixtures.v2/deprecated.go
+++ b/vendor/gopkg.in/testfixtures.v2/deprecated.go
@@ -5,22 +5,38 @@ import (
)
type (
- DataBaseHelper Helper // Deprecated: Use Helper instead
+ // DataBaseHelper is the helper interface
+ // Deprecated: Use Helper instead
+ DataBaseHelper Helper
- PostgreSQLHelper struct { // Deprecated: Use PostgreSQL{} instead
+ // PostgreSQLHelper is the PostgreSQL helper
+ // Deprecated: Use PostgreSQL{} instead
+ PostgreSQLHelper struct {
PostgreSQL
UseAlterConstraint bool
}
- MySQLHelper struct { // Deprecated: Use MySQL{} instead
+
+ // MySQLHelper is the MySQL helper
+ // Deprecated: Use MySQL{} instead
+ MySQLHelper struct {
MySQL
}
- SQLiteHelper struct { // Deprecated: Use SQLite{} instead
+
+ // SQLiteHelper is the SQLite helper
+ // Deprecated: Use SQLite{} instead
+ SQLiteHelper struct {
SQLite
}
- SQLServerHelper struct { // Deprecated: Use SQLServer{} instead
+
+ // SQLServerHelper is the SQLServer helper
+ // Deprecated: Use SQLServer{} instead
+ SQLServerHelper struct {
SQLServer
}
- OracleHelper struct { // Deprecated: Use Oracle{} instead
+
+ // OracleHelper is the Oracle helper
+ // Deprecated: Use Oracle{} instead
+ OracleHelper struct {
Oracle
}
)
diff --git a/vendor/gopkg.in/testfixtures.v2/errors.go b/vendor/gopkg.in/testfixtures.v2/errors.go
new file mode 100644
index 0000000000..17eb284c6c
--- /dev/null
+++ b/vendor/gopkg.in/testfixtures.v2/errors.go
@@ -0,0 +1,41 @@
+package testfixtures
+
+import (
+ "errors"
+ "fmt"
+)
+
+var (
+ // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
+ ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
+
+ // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
+ ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
+
+ // ErrKeyIsNotString is returned when a record is not of type string
+ ErrKeyIsNotString = errors.New("Record map key is not string")
+
+ // ErrNotTestDatabase is returned when the database name doesn't contains "test"
+ ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
+)
+
+// InsertError will be returned if any error happens on database while
+// inserting the record
+type InsertError struct {
+ Err error
+ File string
+ Index int
+ SQL string
+ Params []interface{}
+}
+
+func (e *InsertError) Error() string {
+ return fmt.Sprintf(
+ "testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
+ e.Err,
+ e.File,
+ e.Index,
+ e.SQL,
+ e.Params,
+ )
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/generate.go b/vendor/gopkg.in/testfixtures.v2/generate.go
new file mode 100644
index 0000000000..844814007c
--- /dev/null
+++ b/vendor/gopkg.in/testfixtures.v2/generate.go
@@ -0,0 +1,110 @@
+package testfixtures
+
+import (
+ "database/sql"
+ "fmt"
+ "os"
+ "path"
+ "unicode/utf8"
+
+ "gopkg.in/yaml.v2"
+)
+
+// TableInfo is settings for generating a fixture for table.
+type TableInfo struct {
+ Name string // Table name
+ Where string // A condition for extracting records. If this value is empty, extracts all records.
+}
+
+func (ti *TableInfo) whereClause() string {
+ if ti.Where == "" {
+ return ""
+ }
+ return fmt.Sprintf(" WHERE %s", ti.Where)
+}
+
+// GenerateFixtures generates fixtures for the current contents of a database, and saves
+// them to the specified directory
+func GenerateFixtures(db *sql.DB, helper Helper, dir string) error {
+ tables, err := helper.tableNames(db)
+ if err != nil {
+ return err
+ }
+ for _, table := range tables {
+ filename := path.Join(dir, table+".yml")
+ if err := generateFixturesForTable(db, helper, &TableInfo{Name: table}, filename); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// GenerateFixturesForTables generates fixtures for the current contents of specified tables in a database, and saves
+// them to the specified directory
+func GenerateFixturesForTables(db *sql.DB, tables []*TableInfo, helper Helper, dir string) error {
+ for _, table := range tables {
+ filename := path.Join(dir, table.Name+".yml")
+ if err := generateFixturesForTable(db, helper, table, filename); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func generateFixturesForTable(db *sql.DB, h Helper, table *TableInfo, filename string) error {
+ query := fmt.Sprintf("SELECT * FROM %s%s", h.quoteKeyword(table.Name), table.whereClause())
+ rows, err := db.Query(query)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ columns, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+
+ fixtures := make([]interface{}, 0, 10)
+ for rows.Next() {
+ entries := make([]interface{}, len(columns))
+ entryPtrs := make([]interface{}, len(entries))
+ for i := range entries {
+ entryPtrs[i] = &entries[i]
+ }
+ if err := rows.Scan(entryPtrs...); err != nil {
+ return err
+ }
+
+ entryMap := make(map[string]interface{}, len(entries))
+ for i, column := range columns {
+ entryMap[column] = convertValue(entries[i])
+ }
+ fixtures = append(fixtures, entryMap)
+ }
+ if err = rows.Err(); err != nil {
+ return err
+ }
+
+ f, err := os.Create(filename)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ marshaled, err := yaml.Marshal(fixtures)
+ if err != nil {
+ return err
+ }
+ _, err = f.Write(marshaled)
+ return err
+}
+
+func convertValue(value interface{}) interface{} {
+ switch v := value.(type) {
+ case []byte:
+ if utf8.Valid(v) {
+ return string(v)
+ }
+ }
+ return value
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/helper.go b/vendor/gopkg.in/testfixtures.v2/helper.go
index f1c19f29d4..bd1ebba62e 100644
--- a/vendor/gopkg.in/testfixtures.v2/helper.go
+++ b/vendor/gopkg.in/testfixtures.v2/helper.go
@@ -18,21 +18,56 @@ type Helper interface {
init(*sql.DB) error
disableReferentialIntegrity(*sql.DB, loadFunction) error
paramType() int
- databaseName(*sql.DB) string
+ databaseName(queryable) (string, error)
+ tableNames(queryable) ([]string, error)
+ isTableModified(queryable, string) (bool, error)
+ afterLoad(queryable) error
quoteKeyword(string) string
whileInsertOnTable(*sql.Tx, string, func() error) error
}
+type queryable interface {
+ Exec(string, ...interface{}) (sql.Result, error)
+ Query(string, ...interface{}) (*sql.Rows, error)
+ QueryRow(string, ...interface{}) *sql.Row
+}
+
+// batchSplitter is an interface with method which returns byte slice for
+// splitting SQL batches. This need to split sql statements and run its
+// separately.
+//
+// For Microsoft SQL Server batch splitter is "GO". For details see
+// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go
+type batchSplitter interface {
+ splitter() []byte
+}
+
+var (
+ _ Helper = &MySQL{}
+ _ Helper = &PostgreSQL{}
+ _ Helper = &SQLite{}
+ _ Helper = &Oracle{}
+ _ Helper = &SQLServer{}
+)
+
type baseHelper struct{}
-func (*baseHelper) init(_ *sql.DB) error {
+func (baseHelper) init(_ *sql.DB) error {
return nil
}
-func (*baseHelper) quoteKeyword(str string) string {
+func (baseHelper) quoteKeyword(str string) string {
return fmt.Sprintf(`"%s"`, str)
}
-func (*baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
+func (baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
return fn()
}
+
+func (baseHelper) isTableModified(_ queryable, _ string) (bool, error) {
+ return true, nil
+}
+
+func (baseHelper) afterLoad(_ queryable) error {
+ return nil
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/json.go b/vendor/gopkg.in/testfixtures.v2/json.go
new file mode 100644
index 0000000000..f954a17a7f
--- /dev/null
+++ b/vendor/gopkg.in/testfixtures.v2/json.go
@@ -0,0 +1,44 @@
+package testfixtures
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+)
+
+var (
+ _ driver.Valuer = jsonArray{}
+ _ driver.Valuer = jsonMap{}
+)
+
+type jsonArray []interface{}
+
+func (a jsonArray) Value() (driver.Value, error) {
+ return json.Marshal(a)
+}
+
+type jsonMap map[string]interface{}
+
+func (m jsonMap) Value() (driver.Value, error) {
+ return json.Marshal(m)
+}
+
+// Go refuses to convert map[interface{}]interface{} to JSON because JSON only support string keys
+// So it's necessary to recursively convert all map[interface]interface{} to map[string]interface{}
+func recursiveToJSON(v interface{}) (r interface{}) {
+ switch v := v.(type) {
+ case []interface{}:
+ for i, e := range v {
+ v[i] = recursiveToJSON(e)
+ }
+ r = jsonArray(v)
+ case map[interface{}]interface{}:
+ newMap := make(map[string]interface{}, len(v))
+ for k, e := range v {
+ newMap[k.(string)] = recursiveToJSON(e)
+ }
+ r = jsonMap(newMap)
+ default:
+ r = v
+ }
+ return
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/mysql.go b/vendor/gopkg.in/testfixtures.v2/mysql.go
index 3c96c1b293..8a3e5c9bfe 100644
--- a/vendor/gopkg.in/testfixtures.v2/mysql.go
+++ b/vendor/gopkg.in/testfixtures.v2/mysql.go
@@ -8,6 +8,18 @@ import (
// MySQL is the MySQL helper for this package
type MySQL struct {
baseHelper
+ tables []string
+ tablesChecksum map[string]int64
+}
+
+func (h *MySQL) init(db *sql.DB) error {
+ var err error
+ h.tables, err = h.tableNames(db)
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (*MySQL) paramType() int {
@@ -18,28 +30,105 @@ func (*MySQL) quoteKeyword(str string) string {
return fmt.Sprintf("`%s`", str)
}
-func (*MySQL) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT DATABASE()").Scan(&dbName)
- return
+func (*MySQL) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT DATABASE()").Scan(&dbName)
+ return dbName, err
}
-func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *MySQL) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = ?
+ AND table_type = 'BASE TABLE';
+ `
+ dbName, err := h.databaseName(q)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := q.Query(query, dbName)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+
+}
+
+func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// re-enable after load
- defer db.Exec("SET FOREIGN_KEY_CHECKS = 1")
+ defer func() {
+ if _, err2 := db.Exec("SET FOREIGN_KEY_CHECKS = 1"); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
tx, err := db.Begin()
if err != nil {
return err
}
+ defer tx.Rollback()
if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil {
return err
}
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
+
+func (h *MySQL) isTableModified(q queryable, tableName string) (bool, error) {
+ checksum, err := h.getChecksum(q, tableName)
+ if err != nil {
+ return true, err
+ }
+
+ oldChecksum := h.tablesChecksum[tableName]
+
+ return oldChecksum == 0 || checksum != oldChecksum, nil
+}
+
+func (h *MySQL) afterLoad(q queryable) error {
+ if h.tablesChecksum != nil {
+ return nil
+ }
+
+ h.tablesChecksum = make(map[string]int64, len(h.tables))
+ for _, t := range h.tables {
+ checksum, err := h.getChecksum(q, t)
+ if err != nil {
+ return err
+ }
+ h.tablesChecksum[t] = checksum
+ }
+ return nil
+}
+
+func (h *MySQL) getChecksum(q queryable, tableName string) (int64, error) {
+ sql := fmt.Sprintf("CHECKSUM TABLE %s", h.quoteKeyword(tableName))
+ var (
+ table string
+ checksum int64
+ )
+ if err := q.QueryRow(sql).Scan(&table, &checksum); err != nil {
+ return 0, err
+ }
+ return checksum, nil
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/oracle.go b/vendor/gopkg.in/testfixtures.v2/oracle.go
index 59600ebfc5..af5c92ddb3 100644
--- a/vendor/gopkg.in/testfixtures.v2/oracle.go
+++ b/vendor/gopkg.in/testfixtures.v2/oracle.go
@@ -43,54 +43,90 @@ func (*Oracle) quoteKeyword(str string) string {
return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
}
-func (*Oracle) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
- return
+func (*Oracle) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
+ return dbName, err
}
-func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) {
- constraints := make([]oracleConstraint, 0)
- rows, err := db.Query(`
- SELECT table_name, constraint_name
- FROM user_constraints
- WHERE constraint_type = 'R'
- AND status = 'ENABLED'
- `)
+func (*Oracle) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT TABLE_NAME
+ FROM USER_TABLES
+ `
+ rows, err := q.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+
+}
+
+func (*Oracle) getEnabledConstraints(q queryable) ([]oracleConstraint, error) {
+ var constraints []oracleConstraint
+ rows, err := q.Query(`
+ SELECT table_name, constraint_name
+ FROM user_constraints
+ WHERE constraint_type = 'R'
+ AND status = 'ENABLED'
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
for rows.Next() {
var constraint oracleConstraint
rows.Scan(&constraint.tableName, &constraint.constraintName)
constraints = append(constraints, constraint)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return constraints, nil
}
-func (*Oracle) getSequences(db *sql.DB) ([]string, error) {
- sequences := make([]string, 0)
- rows, err := db.Query("SELECT sequence_name FROM user_sequences")
+func (*Oracle) getSequences(q queryable) ([]string, error) {
+ var sequences []string
+ rows, err := q.Query("SELECT sequence_name FROM user_sequences")
if err != nil {
return nil, err
}
-
defer rows.Close()
+
for rows.Next() {
var sequence string
- rows.Scan(&sequence)
+ if err = rows.Scan(&sequence); err != nil {
+ return nil, err
+ }
sequences = append(sequences, sequence)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return sequences, nil
}
-func (h *Oracle) resetSequences(db *sql.DB) error {
+func (h *Oracle) resetSequences(q queryable) error {
for _, sequence := range h.sequences {
- _, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
+ _, err := q.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
if err != nil {
return err
}
- _, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
+ _, err = q.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
if err != nil {
return err
}
@@ -98,11 +134,14 @@ func (h *Oracle) resetSequences(db *sql.DB) error {
return nil
}
-func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// re-enable after load
defer func() {
for _, c := range h.enabledConstraints {
- db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+ _, err2 := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+ if err2 != nil && err == nil {
+ err = err2
+ }
}
}()
@@ -118,9 +157,9 @@ func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) er
if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
diff --git a/vendor/gopkg.in/testfixtures.v2/postgresql.go b/vendor/gopkg.in/testfixtures.v2/postgresql.go
index ecc5a5cfa8..5386cacb55 100644
--- a/vendor/gopkg.in/testfixtures.v2/postgresql.go
+++ b/vendor/gopkg.in/testfixtures.v2/postgresql.go
@@ -3,6 +3,7 @@ package testfixtures
import (
"database/sql"
"fmt"
+ "strings"
)
// PostgreSQL is the PG helper for this package
@@ -18,6 +19,7 @@ type PostgreSQL struct {
tables []string
sequences []string
nonDeferrableConstraints []pgConstraint
+ tablesChecksum map[string]string
}
type pgConstraint struct {
@@ -28,7 +30,7 @@ type pgConstraint struct {
func (h *PostgreSQL) init(db *sql.DB) error {
var err error
- h.tables, err = h.getTables(db)
+ h.tables, err = h.tableNames(db)
if err != nil {
return err
}
@@ -50,44 +52,57 @@ func (*PostgreSQL) paramType() int {
return paramTypeDollar
}
-func (*PostgreSQL) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT current_database()").Scan(&dbName)
- return
+func (*PostgreSQL) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT current_database()").Scan(&dbName)
+ return dbName, err
}
-func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) {
+func (h *PostgreSQL) tableNames(q queryable) ([]string, error) {
var tables []string
sql := `
-SELECT table_name
-FROM information_schema.tables
-WHERE table_schema = 'public'
- AND table_type = 'BASE TABLE';
-`
- rows, err := db.Query(sql)
+ SELECT pg_namespace.nspname || '.' || pg_class.relname
+ FROM pg_class
+ INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+ WHERE pg_class.relkind = 'r'
+ AND pg_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
+ AND pg_namespace.nspname NOT LIKE 'pg_toast%';
+ `
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
-
defer rows.Close()
+
for rows.Next() {
var table string
- rows.Scan(&table)
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
tables = append(tables, table)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return tables, nil
}
-func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
- var sequences []string
+func (h *PostgreSQL) getSequences(q queryable) ([]string, error) {
+ const sql = `
+ SELECT pg_namespace.nspname || '.' || pg_class.relname AS sequence_name
+ FROM pg_class
+ INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+ WHERE pg_class.relkind = 'S'
+ `
- sql := "SELECT relname FROM pg_class WHERE relkind = 'S'"
- rows, err := db.Query(sql)
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
-
defer rows.Close()
+
+ var sequences []string
for rows.Next() {
var sequence string
if err = rows.Scan(&sequence); err != nil {
@@ -95,18 +110,22 @@ func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
}
sequences = append(sequences, sequence)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return sequences, nil
}
-func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) {
+func (*PostgreSQL) getNonDeferrableConstraints(q queryable) ([]pgConstraint, error) {
var constraints []pgConstraint
sql := `
-SELECT table_name, constraint_name
-FROM information_schema.table_constraints
-WHERE constraint_type = 'FOREIGN KEY'
- AND is_deferrable = 'NO'`
- rows, err := db.Query(sql)
+ SELECT table_schema || '.' || table_name, constraint_name
+ FROM information_schema.table_constraints
+ WHERE constraint_type = 'FOREIGN KEY'
+ AND is_deferrable = 'NO'
+ `
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
@@ -114,23 +133,27 @@ WHERE constraint_type = 'FOREIGN KEY'
defer rows.Close()
for rows.Next() {
var constraint pgConstraint
- err = rows.Scan(&constraint.tableName, &constraint.constraintName)
- if err != nil {
+ if err = rows.Scan(&constraint.tableName, &constraint.constraintName); err != nil {
return nil, err
}
constraints = append(constraints, constraint)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return constraints, nil
}
-func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) (err error) {
defer func() {
// re-enable triggers after load
var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
}
- db.Exec(sql)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
+ }
}()
tx, err := db.Begin()
@@ -154,14 +177,16 @@ func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
return tx.Commit()
}
-func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) (err error) {
defer func() {
// ensure constraint being not deferrable again after load
var sql string
for _, constraint := range h.nonDeferrableConstraints {
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
}
- db.Exec(sql)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
+ }
}()
var sql string
@@ -176,28 +201,31 @@ func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction)
if err != nil {
return err
}
+ defer tx.Rollback()
if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
- return nil
+ return err
}
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
-func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// ensure sequences being reset after load
- defer h.resetSequences(db)
+ defer func() {
+ if err2 := h.resetSequences(db); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
if h.UseAlterConstraint {
return h.makeConstraintsDeferrable(db, loadFn)
- } else {
- return h.disableTriggers(db, loadFn)
}
+ return h.disableTriggers(db, loadFn)
}
func (h *PostgreSQL) resetSequences(db *sql.DB) error {
@@ -209,3 +237,53 @@ func (h *PostgreSQL) resetSequences(db *sql.DB) error {
}
return nil
}
+
+func (h *PostgreSQL) isTableModified(q queryable, tableName string) (bool, error) {
+ checksum, err := h.getChecksum(q, tableName)
+ if err != nil {
+ return false, err
+ }
+
+ oldChecksum := h.tablesChecksum[tableName]
+
+ return oldChecksum == "" || checksum != oldChecksum, nil
+}
+
+func (h *PostgreSQL) afterLoad(q queryable) error {
+ if h.tablesChecksum != nil {
+ return nil
+ }
+
+ h.tablesChecksum = make(map[string]string, len(h.tables))
+ for _, t := range h.tables {
+ checksum, err := h.getChecksum(q, t)
+ if err != nil {
+ return err
+ }
+ h.tablesChecksum[t] = checksum
+ }
+ return nil
+}
+
+func (h *PostgreSQL) getChecksum(q queryable, tableName string) (string, error) {
+ sqlStr := fmt.Sprintf(`
+ SELECT md5(CAST((array_agg(t.*)) AS TEXT))
+ FROM %s AS t
+ `,
+ h.quoteKeyword(tableName),
+ )
+
+ var checksum sql.NullString
+ if err := q.QueryRow(sqlStr).Scan(&checksum); err != nil {
+ return "", err
+ }
+ return checksum.String, nil
+}
+
+func (*PostgreSQL) quoteKeyword(s string) string {
+ parts := strings.Split(s, ".")
+ for i, p := range parts {
+ parts[i] = fmt.Sprintf(`"%s"`, p)
+ }
+ return strings.Join(parts, ".")
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/sqlite.go b/vendor/gopkg.in/testfixtures.v2/sqlite.go
index 4d7fa4fdab..150f014108 100644
--- a/vendor/gopkg.in/testfixtures.v2/sqlite.go
+++ b/vendor/gopkg.in/testfixtures.v2/sqlite.go
@@ -14,23 +14,59 @@ func (*SQLite) paramType() int {
return paramTypeQuestion
}
-func (*SQLite) databaseName(db *sql.DB) (dbName string) {
+func (*SQLite) databaseName(q queryable) (string, error) {
var seq int
- var main string
- db.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+ var main, dbName string
+ err := q.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+ if err != nil {
+ return "", err
+ }
dbName = filepath.Base(dbName)
- return
+ return dbName, nil
}
-func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
- tx, err := db.Begin()
+func (*SQLite) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT name
+ FROM sqlite_master
+ WHERE type = 'table';
+ `
+ rows, err := q.Query(query)
if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+}
+
+func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
+ defer func() {
+ if _, err2 := db.Exec("PRAGMA defer_foreign_keys = OFF"); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
+
+ if _, err = db.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
return err
}
- if _, err = tx.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
+ tx, err := db.Begin()
+ if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
return err
diff --git a/vendor/gopkg.in/testfixtures.v2/sqlserver.go b/vendor/gopkg.in/testfixtures.v2/sqlserver.go
index 1399f5c400..d2fc854797 100644
--- a/vendor/gopkg.in/testfixtures.v2/sqlserver.go
+++ b/vendor/gopkg.in/testfixtures.v2/sqlserver.go
@@ -3,6 +3,7 @@ package testfixtures
import (
"database/sql"
"fmt"
+ "strings"
)
// SQLServer is the helper for SQL Server for this package.
@@ -16,7 +17,7 @@ type SQLServer struct {
func (h *SQLServer) init(db *sql.DB) error {
var err error
- h.tables, err = h.getTables(db)
+ h.tables, err = h.tableNames(db)
if err != nil {
return err
}
@@ -28,46 +29,62 @@ func (*SQLServer) paramType() int {
return paramTypeQuestion
}
-func (*SQLServer) quoteKeyword(str string) string {
- return fmt.Sprintf("[%s]", str)
+func (*SQLServer) quoteKeyword(s string) string {
+ parts := strings.Split(s, ".")
+ for i, p := range parts {
+ parts[i] = fmt.Sprintf(`[%s]`, p)
+ }
+ return strings.Join(parts, ".")
}
-func (*SQLServer) databaseName(db *sql.DB) (dbname string) {
- db.QueryRow("SELECT DB_NAME()").Scan(&dbname)
- return
+func (*SQLServer) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
+ return dbName, err
}
-func (*SQLServer) getTables(db *sql.DB) ([]string, error) {
- rows, err := db.Query("SELECT table_name FROM information_schema.tables")
+func (*SQLServer) tableNames(q queryable) ([]string, error) {
+ rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables")
if err != nil {
return nil, err
}
-
- tables := make([]string, 0)
defer rows.Close()
+
+ var tables []string
for rows.Next() {
var table string
- rows.Scan(&table)
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
tables = append(tables, table)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return tables, nil
}
-func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool {
+func (h *SQLServer) tableHasIdentityColumn(q queryable, tableName string) bool {
sql := `
-SELECT COUNT(*)
-FROM SYS.IDENTITY_COLUMNS
-WHERE OBJECT_NAME(OBJECT_ID) = ?
-`
+ SELECT COUNT(*)
+ FROM SYS.IDENTITY_COLUMNS
+ WHERE OBJECT_ID = OBJECT_ID(?)
+ `
var count int
- tx.QueryRow(sql, tableName).Scan(&count)
+ q.QueryRow(sql, h.quoteKeyword(tableName)).Scan(&count)
return count > 0
}
-func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error {
+func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
if h.tableHasIdentityColumn(tx, tableName) {
- defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+ defer func() {
+ _, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+ if err2 != nil && err == nil {
+ err = err2
+ }
+ }()
+
_, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
if err != nil {
return err
@@ -76,19 +93,19 @@ func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() e
return fn()
}
-func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// ensure the triggers are re-enable after all
defer func() {
- sql := ""
+ var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
}
- if _, err := db.Exec(sql); err != nil {
- fmt.Printf("Error on re-enabling constraints: %v\n", err)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
}
}()
- sql := ""
+ var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
}
@@ -100,11 +117,19 @@ func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction)
if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
+
+// splitter is a batchSplitter interface implementation. We need it for
+// SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
+// could not be executed in the same batch.
+// See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
+func (*SQLServer) splitter() []byte {
+ return []byte("GO\n")
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/testfixtures.go b/vendor/gopkg.in/testfixtures.v2/testfixtures.go
index bd95580f2a..dfc59c1efc 100644
--- a/vendor/gopkg.in/testfixtures.v2/testfixtures.go
+++ b/vendor/gopkg.in/testfixtures.v2/testfixtures.go
@@ -2,7 +2,6 @@ package testfixtures
import (
"database/sql"
- "errors"
"fmt"
"io/ioutil"
"path"
@@ -33,22 +32,10 @@ type insertSQL struct {
}
var (
- // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
- ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
-
- // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
- ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
-
- // ErrKeyIsNotString is returned when a record is not of type string
- ErrKeyIsNotString = errors.New("Record map key is not string")
-
- // ErrNotTestDatabase is returned when the database name doesn't contains "test"
- ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
-
dbnameRegexp = regexp.MustCompile("(?i)test")
)
-// NewFolder craetes a context for all fixtures in a given folder into the database:
+// NewFolder creates a context for all fixtures in a given folder into the database:
// NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
fixtures, err := fixturesFromFolder(folderName)
@@ -64,7 +51,7 @@ func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
return c, nil
}
-// NewFiles craetes a context for all specified fixtures files into database:
+// NewFiles creates a context for all specified fixtures files into database:
// NewFiles(db, &PostgreSQL{},
// "fixtures/customers.yml",
// "fixtures/orders.yml"
@@ -102,27 +89,55 @@ func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, e
return c, nil
}
+// DetectTestDatabase returns nil if databaseName matches regexp
+// if err := fixtures.DetectTestDatabase(); err != nil {
+// log.Fatal(err)
+// }
+func (c *Context) DetectTestDatabase() error {
+ dbName, err := c.helper.databaseName(c.db)
+ if err != nil {
+ return err
+ }
+ if !dbnameRegexp.MatchString(dbName) {
+ return ErrNotTestDatabase
+ }
+ return nil
+}
+
// Load wipes and after load all fixtures in the database.
// if err := fixtures.Load(); err != nil {
// log.Fatal(err)
// }
func (c *Context) Load() error {
if !skipDatabaseNameCheck {
- if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) {
- return ErrNotTestDatabase
+ if err := c.DetectTestDatabase(); err != nil {
+ return err
}
}
err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
for _, file := range c.fixturesFiles {
+ modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
+ if err != nil {
+ return err
+ }
+ if !modified {
+ continue
+ }
if err := file.delete(tx, c.helper); err != nil {
return err
}
- err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
- for _, i := range file.insertSQLs {
+ err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
+ for j, i := range file.insertSQLs {
if _, err := tx.Exec(i.sql, i.params...); err != nil {
- return err
+ return &InsertError{
+ Err: err,
+ File: file.fileName,
+ Index: j,
+ SQL: i.sql,
+ Params: i.params,
+ }
}
}
return nil
@@ -133,7 +148,10 @@ func (c *Context) Load() error {
}
return nil
})
- return err
+ if err != nil {
+ return err
+ }
+ return c.helper.afterLoad(c.db)
}
func (c *Context) buildInsertSQLs() error {
@@ -204,25 +222,33 @@ func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{
sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
+ // if string, try convert to SQL or time
+ // if map or array, convert to json
+ switch v := value.(type) {
+ case string:
+ if strings.HasPrefix(v, "RAW=") {
+ sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
+ continue
+ }
+
+ if t, err := tryStrToDate(v); err == nil {
+ value = t
+ }
+ case []interface{}, map[interface{}]interface{}:
+ value = recursiveToJSON(v)
+ }
+
switch h.paramType() {
case paramTypeDollar:
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
case paramTypeQuestion:
sqlValues = append(sqlValues, "?")
case paramTypeColon:
- switch {
- case isDateTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i))
- case isDate(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i))
- case isTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i))
- default:
- sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
- }
+ sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
}
- i++
+
values = append(values, value)
+ i++
}
sqlStr = fmt.Sprintf(
diff --git a/vendor/gopkg.in/testfixtures.v2/time.go b/vendor/gopkg.in/testfixtures.v2/time.go
index 6796707500..8c5cba1d03 100644
--- a/vendor/gopkg.in/testfixtures.v2/time.go
+++ b/vendor/gopkg.in/testfixtures.v2/time.go
@@ -1,36 +1,34 @@
package testfixtures
-import "regexp"
-
-var (
- regexpDate = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d")
- regexpDateTime = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d \\d\\d:\\d\\d:\\d\\d")
- regexpTime = regexp.MustCompile("\\d\\d:\\d\\d:\\d\\d")
+import (
+ "errors"
+ "time"
)
-func isDate(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
- }
-
- return regexpDate.MatchString(str)
+var timeFormats = []string{
+ "2006-01-02",
+ "2006-01-02 15:04",
+ "2006-01-02 15:04:05",
+ "20060102",
+ "20060102 15:04",
+ "20060102 15:04:05",
+ "02/01/2006",
+ "02/01/2006 15:04",
+ "02/01/2006 15:04:05",
+ "2006-01-02T15:04-07:00",
+ "2006-01-02T15:04:05-07:00",
}
-func isDateTime(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
- }
+// ErrCouldNotConvertToTime is returns when a string is not a reconizable time format
+var ErrCouldNotConvertToTime = errors.New("Could not convert string to time")
- return regexpDateTime.MatchString(str)
-}
-
-func isTime(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
+func tryStrToDate(s string) (time.Time, error) {
+ for _, f := range timeFormats {
+ t, err := time.ParseInLocation(f, s, time.Local)
+ if err != nil {
+ continue
+ }
+ return t, nil
}
-
- return regexpTime.MatchString(str)
+ return time.Time{}, ErrCouldNotConvertToTime
}