]> source.dussan.org Git - gitea.git/commitdiff
Upgrade gopkg.in/testfixtures.v2 (#4999)
authorMura Li <typeless@users.noreply.github.com>
Tue, 2 Oct 2018 19:20:02 +0000 (03:20 +0800)
committertechknowlogick <hello@techknowlogick.com>
Tue, 2 Oct 2018 19:20:02 +0000 (15:20 -0400)
13 files changed:
Gopkg.lock
vendor/gopkg.in/testfixtures.v2/deprecated.go
vendor/gopkg.in/testfixtures.v2/errors.go [new file with mode: 0644]
vendor/gopkg.in/testfixtures.v2/generate.go [new file with mode: 0644]
vendor/gopkg.in/testfixtures.v2/helper.go
vendor/gopkg.in/testfixtures.v2/json.go [new file with mode: 0644]
vendor/gopkg.in/testfixtures.v2/mysql.go
vendor/gopkg.in/testfixtures.v2/oracle.go
vendor/gopkg.in/testfixtures.v2/postgresql.go
vendor/gopkg.in/testfixtures.v2/sqlite.go
vendor/gopkg.in/testfixtures.v2/sqlserver.go
vendor/gopkg.in/testfixtures.v2/testfixtures.go
vendor/gopkg.in/testfixtures.v2/time.go

index e6eb721fd346445d1a7599d7a8e463eb0c0a675e..39b44c65dc84efc9e48fcab99e944d7500904648 100644 (file)
   version = "v2.3.2"
 
 [[projects]]
-  digest = "1:e2144032dcf8e856fb733151391669dc2d32d79ebb69e3c4151efa605f5e8a01"
+  digest = "1:9c541fc507676a69ea8aaed1af53278a5241d26ce0f192c993fec2ac5b78f795"
   name = "gopkg.in/testfixtures.v2"
   packages = ["."]
   pruneopts = "NUT"
-  revision = "b9ef14dc461bf934d8df2dfc6f1f456be5664cca"
-  version = "v2.0.0"
+  revision = "fa3fb89109b0b31957a5430cef3e93e535de362b"
+  version = "v2.5.0"
 
 [[projects]]
   digest = "1:ad6f94355d292690137613735965bd3688844880fdab90eccf66321910344942"
index b83eeef436f7e0bd60a5dbef854d4453f1d90e7e..16e0969e3394efa7934c7528d235c0d4bdaa3f60 100644 (file)
@@ -5,22 +5,38 @@ import (
 )
 
 type (
-       DataBaseHelper Helper // Deprecated: Use Helper instead
+       // DataBaseHelper is the helper interface
+       // Deprecated: Use Helper instead
+       DataBaseHelper Helper
 
-       PostgreSQLHelper struct { // Deprecated: Use PostgreSQL{} instead
+       // PostgreSQLHelper is the PostgreSQL helper
+       // Deprecated: Use PostgreSQL{} instead
+       PostgreSQLHelper struct {
                PostgreSQL
                UseAlterConstraint bool
        }
-       MySQLHelper struct { // Deprecated: Use MySQL{} instead
+
+       // MySQLHelper is the MySQL helper
+       // Deprecated: Use MySQL{} instead
+       MySQLHelper struct {
                MySQL
        }
-       SQLiteHelper struct { // Deprecated: Use SQLite{} instead
+
+       // SQLiteHelper is the SQLite helper
+       // Deprecated: Use SQLite{} instead
+       SQLiteHelper struct {
                SQLite
        }
-       SQLServerHelper struct { // Deprecated: Use SQLServer{} instead
+
+       // SQLServerHelper is the SQLServer helper
+       // Deprecated: Use SQLServer{} instead
+       SQLServerHelper struct {
                SQLServer
        }
-       OracleHelper struct { // Deprecated: Use Oracle{} instead
+
+       // OracleHelper is the Oracle helper
+       // Deprecated: Use Oracle{} instead
+       OracleHelper struct {
                Oracle
        }
 )
diff --git a/vendor/gopkg.in/testfixtures.v2/errors.go b/vendor/gopkg.in/testfixtures.v2/errors.go
new file mode 100644 (file)
index 0000000..17eb284
--- /dev/null
@@ -0,0 +1,41 @@
+package testfixtures
+
+import (
+       "errors"
+       "fmt"
+)
+
+var (
+       // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
+       ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
+
+       // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
+       ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
+
+       // ErrKeyIsNotString is returned when a record is not of type string
+       ErrKeyIsNotString = errors.New("Record map key is not string")
+
+       // ErrNotTestDatabase is returned when the database name doesn't contains "test"
+       ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
+)
+
+// InsertError will be returned if any error happens on database while
+// inserting the record
+type InsertError struct {
+       Err    error
+       File   string
+       Index  int
+       SQL    string
+       Params []interface{}
+}
+
+func (e *InsertError) Error() string {
+       return fmt.Sprintf(
+               "testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
+               e.Err,
+               e.File,
+               e.Index,
+               e.SQL,
+               e.Params,
+       )
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/generate.go b/vendor/gopkg.in/testfixtures.v2/generate.go
new file mode 100644 (file)
index 0000000..8448140
--- /dev/null
@@ -0,0 +1,110 @@
+package testfixtures
+
+import (
+       "database/sql"
+       "fmt"
+       "os"
+       "path"
+       "unicode/utf8"
+
+       "gopkg.in/yaml.v2"
+)
+
+// TableInfo is settings for generating a fixture for table.
+type TableInfo struct {
+       Name  string // Table name
+       Where string // A condition for extracting records. If this value is empty, extracts all records.
+}
+
+func (ti *TableInfo) whereClause() string {
+       if ti.Where == "" {
+               return ""
+       }
+       return fmt.Sprintf(" WHERE %s", ti.Where)
+}
+
+// GenerateFixtures generates fixtures for the current contents of a database, and saves
+// them to the specified directory
+func GenerateFixtures(db *sql.DB, helper Helper, dir string) error {
+       tables, err := helper.tableNames(db)
+       if err != nil {
+               return err
+       }
+       for _, table := range tables {
+               filename := path.Join(dir, table+".yml")
+               if err := generateFixturesForTable(db, helper, &TableInfo{Name: table}, filename); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+// GenerateFixturesForTables generates fixtures for the current contents of specified tables in a database, and saves
+// them to the specified directory
+func GenerateFixturesForTables(db *sql.DB, tables []*TableInfo, helper Helper, dir string) error {
+       for _, table := range tables {
+               filename := path.Join(dir, table.Name+".yml")
+               if err := generateFixturesForTable(db, helper, table, filename); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func generateFixturesForTable(db *sql.DB, h Helper, table *TableInfo, filename string) error {
+       query := fmt.Sprintf("SELECT * FROM %s%s", h.quoteKeyword(table.Name), table.whereClause())
+       rows, err := db.Query(query)
+       if err != nil {
+               return err
+       }
+       defer rows.Close()
+
+       columns, err := rows.Columns()
+       if err != nil {
+               return err
+       }
+
+       fixtures := make([]interface{}, 0, 10)
+       for rows.Next() {
+               entries := make([]interface{}, len(columns))
+               entryPtrs := make([]interface{}, len(entries))
+               for i := range entries {
+                       entryPtrs[i] = &entries[i]
+               }
+               if err := rows.Scan(entryPtrs...); err != nil {
+                       return err
+               }
+
+               entryMap := make(map[string]interface{}, len(entries))
+               for i, column := range columns {
+                       entryMap[column] = convertValue(entries[i])
+               }
+               fixtures = append(fixtures, entryMap)
+       }
+       if err = rows.Err(); err != nil {
+               return err
+       }
+
+       f, err := os.Create(filename)
+       if err != nil {
+               return err
+       }
+       defer f.Close()
+
+       marshaled, err := yaml.Marshal(fixtures)
+       if err != nil {
+               return err
+       }
+       _, err = f.Write(marshaled)
+       return err
+}
+
+func convertValue(value interface{}) interface{} {
+       switch v := value.(type) {
+       case []byte:
+               if utf8.Valid(v) {
+                       return string(v)
+               }
+       }
+       return value
+}
index f1c19f29d4e9f0c0e79b40ae0dc1c91d3382626e..bd1ebba62ef494374f40c4097878da61303062f8 100644 (file)
@@ -18,21 +18,56 @@ type Helper interface {
        init(*sql.DB) error
        disableReferentialIntegrity(*sql.DB, loadFunction) error
        paramType() int
-       databaseName(*sql.DB) string
+       databaseName(queryable) (string, error)
+       tableNames(queryable) ([]string, error)
+       isTableModified(queryable, string) (bool, error)
+       afterLoad(queryable) error
        quoteKeyword(string) string
        whileInsertOnTable(*sql.Tx, string, func() error) error
 }
 
+type queryable interface {
+       Exec(string, ...interface{}) (sql.Result, error)
+       Query(string, ...interface{}) (*sql.Rows, error)
+       QueryRow(string, ...interface{}) *sql.Row
+}
+
+// batchSplitter is an interface with method which returns byte slice for
+// splitting SQL batches. This need to split sql statements and run its
+// separately.
+//
+// For Microsoft SQL Server batch splitter is "GO". For details see
+// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go
+type batchSplitter interface {
+       splitter() []byte
+}
+
+var (
+       _ Helper = &MySQL{}
+       _ Helper = &PostgreSQL{}
+       _ Helper = &SQLite{}
+       _ Helper = &Oracle{}
+       _ Helper = &SQLServer{}
+)
+
 type baseHelper struct{}
 
-func (*baseHelper) init(_ *sql.DB) error {
+func (baseHelper) init(_ *sql.DB) error {
        return nil
 }
 
-func (*baseHelper) quoteKeyword(str string) string {
+func (baseHelper) quoteKeyword(str string) string {
        return fmt.Sprintf(`"%s"`, str)
 }
 
-func (*baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
+func (baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
        return fn()
 }
+
+func (baseHelper) isTableModified(_ queryable, _ string) (bool, error) {
+       return true, nil
+}
+
+func (baseHelper) afterLoad(_ queryable) error {
+       return nil
+}
diff --git a/vendor/gopkg.in/testfixtures.v2/json.go b/vendor/gopkg.in/testfixtures.v2/json.go
new file mode 100644 (file)
index 0000000..f954a17
--- /dev/null
@@ -0,0 +1,44 @@
+package testfixtures
+
+import (
+       "database/sql/driver"
+       "encoding/json"
+)
+
+var (
+       _ driver.Valuer = jsonArray{}
+       _ driver.Valuer = jsonMap{}
+)
+
+type jsonArray []interface{}
+
+func (a jsonArray) Value() (driver.Value, error) {
+       return json.Marshal(a)
+}
+
+type jsonMap map[string]interface{}
+
+func (m jsonMap) Value() (driver.Value, error) {
+       return json.Marshal(m)
+}
+
+// Go refuses to convert map[interface{}]interface{} to JSON because JSON only support string keys
+// So it's necessary to recursively convert all map[interface]interface{} to map[string]interface{}
+func recursiveToJSON(v interface{}) (r interface{}) {
+       switch v := v.(type) {
+       case []interface{}:
+               for i, e := range v {
+                       v[i] = recursiveToJSON(e)
+               }
+               r = jsonArray(v)
+       case map[interface{}]interface{}:
+               newMap := make(map[string]interface{}, len(v))
+               for k, e := range v {
+                       newMap[k.(string)] = recursiveToJSON(e)
+               }
+               r = jsonMap(newMap)
+       default:
+               r = v
+       }
+       return
+}
index 3c96c1b293f98489dfe1c0a5db022864b3fc25b7..8a3e5c9bfe9e95623f6670b05144092ad77f2435 100644 (file)
@@ -8,6 +8,18 @@ import (
 // MySQL is the MySQL helper for this package
 type MySQL struct {
        baseHelper
+       tables         []string
+       tablesChecksum map[string]int64
+}
+
+func (h *MySQL) init(db *sql.DB) error {
+       var err error
+       h.tables, err = h.tableNames(db)
+       if err != nil {
+               return err
+       }
+
+       return nil
 }
 
 func (*MySQL) paramType() int {
@@ -18,28 +30,105 @@ func (*MySQL) quoteKeyword(str string) string {
        return fmt.Sprintf("`%s`", str)
 }
 
-func (*MySQL) databaseName(db *sql.DB) (dbName string) {
-       db.QueryRow("SELECT DATABASE()").Scan(&dbName)
-       return
+func (*MySQL) databaseName(q queryable) (string, error) {
+       var dbName string
+       err := q.QueryRow("SELECT DATABASE()").Scan(&dbName)
+       return dbName, err
 }
 
-func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *MySQL) tableNames(q queryable) ([]string, error) {
+       query := `
+               SELECT table_name
+               FROM information_schema.tables
+               WHERE table_schema = ?
+                 AND table_type = 'BASE TABLE';
+       `
+       dbName, err := h.databaseName(q)
+       if err != nil {
+               return nil, err
+       }
+
+       rows, err := q.Query(query, dbName)
+       if err != nil {
+               return nil, err
+       }
+       defer rows.Close()
+
+       var tables []string
+       for rows.Next() {
+               var table string
+               if err = rows.Scan(&table); err != nil {
+                       return nil, err
+               }
+               tables = append(tables, table)
+       }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
+       return tables, nil
+
+}
+
+func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
        // re-enable after load
-       defer db.Exec("SET FOREIGN_KEY_CHECKS = 1")
+       defer func() {
+               if _, err2 := db.Exec("SET FOREIGN_KEY_CHECKS = 1"); err2 != nil && err == nil {
+                       err = err2
+               }
+       }()
 
        tx, err := db.Begin()
        if err != nil {
                return err
        }
+       defer tx.Rollback()
 
        if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil {
                return err
        }
 
        if err = loadFn(tx); err != nil {
-               tx.Rollback()
                return err
        }
 
        return tx.Commit()
 }
+
+func (h *MySQL) isTableModified(q queryable, tableName string) (bool, error) {
+       checksum, err := h.getChecksum(q, tableName)
+       if err != nil {
+               return true, err
+       }
+
+       oldChecksum := h.tablesChecksum[tableName]
+
+       return oldChecksum == 0 || checksum != oldChecksum, nil
+}
+
+func (h *MySQL) afterLoad(q queryable) error {
+       if h.tablesChecksum != nil {
+               return nil
+       }
+
+       h.tablesChecksum = make(map[string]int64, len(h.tables))
+       for _, t := range h.tables {
+               checksum, err := h.getChecksum(q, t)
+               if err != nil {
+                       return err
+               }
+               h.tablesChecksum[t] = checksum
+       }
+       return nil
+}
+
+func (h *MySQL) getChecksum(q queryable, tableName string) (int64, error) {
+       sql := fmt.Sprintf("CHECKSUM TABLE %s", h.quoteKeyword(tableName))
+       var (
+               table    string
+               checksum int64
+       )
+       if err := q.QueryRow(sql).Scan(&table, &checksum); err != nil {
+               return 0, err
+       }
+       return checksum, nil
+}
index 59600ebfc52693a4278dd65151a9ed086c4afe06..af5c92ddb3cd6eba1ebb91520ac2964a7d42edde 100644 (file)
@@ -43,54 +43,90 @@ func (*Oracle) quoteKeyword(str string) string {
        return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
 }
 
-func (*Oracle) databaseName(db *sql.DB) (dbName string) {
-       db.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
-       return
+func (*Oracle) databaseName(q queryable) (string, error) {
+       var dbName string
+       err := q.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
+       return dbName, err
 }
 
-func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) {
-       constraints := make([]oracleConstraint, 0)
-       rows, err := db.Query(`
-        SELECT table_name, constraint_name
-        FROM user_constraints
-        WHERE constraint_type = 'R'
-          AND status = 'ENABLED'
-    `)
+func (*Oracle) tableNames(q queryable) ([]string, error) {
+       query := `
+               SELECT TABLE_NAME
+               FROM USER_TABLES
+       `
+       rows, err := q.Query(query)
        if err != nil {
                return nil, err
        }
        defer rows.Close()
+
+       var tables []string
+       for rows.Next() {
+               var table string
+               if err = rows.Scan(&table); err != nil {
+                       return nil, err
+               }
+               tables = append(tables, table)
+       }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
+       return tables, nil
+
+}
+
+func (*Oracle) getEnabledConstraints(q queryable) ([]oracleConstraint, error) {
+       var constraints []oracleConstraint
+       rows, err := q.Query(`
+               SELECT table_name, constraint_name
+               FROM user_constraints
+               WHERE constraint_type = 'R'
+                 AND status = 'ENABLED'
+       `)
+       if err != nil {
+               return nil, err
+       }
+       defer rows.Close()
+
        for rows.Next() {
                var constraint oracleConstraint
                rows.Scan(&constraint.tableName, &constraint.constraintName)
                constraints = append(constraints, constraint)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return constraints, nil
 }
 
-func (*Oracle) getSequences(db *sql.DB) ([]string, error) {
-       sequences := make([]string, 0)
-       rows, err := db.Query("SELECT sequence_name FROM user_sequences")
+func (*Oracle) getSequences(q queryable) ([]string, error) {
+       var sequences []string
+       rows, err := q.Query("SELECT sequence_name FROM user_sequences")
        if err != nil {
                return nil, err
        }
-
        defer rows.Close()
+
        for rows.Next() {
                var sequence string
-               rows.Scan(&sequence)
+               if err = rows.Scan(&sequence); err != nil {
+                       return nil, err
+               }
                sequences = append(sequences, sequence)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return sequences, nil
 }
 
-func (h *Oracle) resetSequences(db *sql.DB) error {
+func (h *Oracle) resetSequences(q queryable) error {
        for _, sequence := range h.sequences {
-               _, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
+               _, err := q.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
                if err != nil {
                        return err
                }
-               _, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
+               _, err = q.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
                if err != nil {
                        return err
                }
@@ -98,11 +134,14 @@ func (h *Oracle) resetSequences(db *sql.DB) error {
        return nil
 }
 
-func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
        // re-enable after load
        defer func() {
                for _, c := range h.enabledConstraints {
-                       db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+                       _, err2 := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+                       if err2 != nil && err == nil {
+                               err = err2
+                       }
                }
        }()
 
@@ -118,9 +157,9 @@ func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) er
        if err != nil {
                return err
        }
+       defer tx.Rollback()
 
        if err = loadFn(tx); err != nil {
-               tx.Rollback()
                return err
        }
 
index ecc5a5cfa8c1162ef407a092ca76b6248703b779..5386cacb55fb2d1067a14d9ceedf9c47c16b8f61 100644 (file)
@@ -3,6 +3,7 @@ package testfixtures
 import (
        "database/sql"
        "fmt"
+       "strings"
 )
 
 // PostgreSQL is the PG helper for this package
@@ -18,6 +19,7 @@ type PostgreSQL struct {
        tables                   []string
        sequences                []string
        nonDeferrableConstraints []pgConstraint
+       tablesChecksum           map[string]string
 }
 
 type pgConstraint struct {
@@ -28,7 +30,7 @@ type pgConstraint struct {
 func (h *PostgreSQL) init(db *sql.DB) error {
        var err error
 
-       h.tables, err = h.getTables(db)
+       h.tables, err = h.tableNames(db)
        if err != nil {
                return err
        }
@@ -50,44 +52,57 @@ func (*PostgreSQL) paramType() int {
        return paramTypeDollar
 }
 
-func (*PostgreSQL) databaseName(db *sql.DB) (dbName string) {
-       db.QueryRow("SELECT current_database()").Scan(&dbName)
-       return
+func (*PostgreSQL) databaseName(q queryable) (string, error) {
+       var dbName string
+       err := q.QueryRow("SELECT current_database()").Scan(&dbName)
+       return dbName, err
 }
 
-func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) {
+func (h *PostgreSQL) tableNames(q queryable) ([]string, error) {
        var tables []string
 
        sql := `
-SELECT table_name
-FROM information_schema.tables
-WHERE table_schema = 'public'
-  AND table_type = 'BASE TABLE';
-`
-       rows, err := db.Query(sql)
+               SELECT pg_namespace.nspname || '.' || pg_class.relname
+               FROM pg_class
+               INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+               WHERE pg_class.relkind = 'r'
+                 AND pg_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
+                 AND pg_namespace.nspname NOT LIKE 'pg_toast%';
+       `
+       rows, err := q.Query(sql)
        if err != nil {
                return nil, err
        }
-
        defer rows.Close()
+
        for rows.Next() {
                var table string
-               rows.Scan(&table)
+               if err = rows.Scan(&table); err != nil {
+                       return nil, err
+               }
                tables = append(tables, table)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return tables, nil
 }
 
-func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
-       var sequences []string
+func (h *PostgreSQL) getSequences(q queryable) ([]string, error) {
+       const sql = `
+               SELECT pg_namespace.nspname || '.' || pg_class.relname AS sequence_name
+               FROM pg_class
+               INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+               WHERE pg_class.relkind = 'S'
+       `
 
-       sql := "SELECT relname FROM pg_class WHERE relkind = 'S'"
-       rows, err := db.Query(sql)
+       rows, err := q.Query(sql)
        if err != nil {
                return nil, err
        }
-
        defer rows.Close()
+
+       var sequences []string
        for rows.Next() {
                var sequence string
                if err = rows.Scan(&sequence); err != nil {
@@ -95,18 +110,22 @@ func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
                }
                sequences = append(sequences, sequence)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return sequences, nil
 }
 
-func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) {
+func (*PostgreSQL) getNonDeferrableConstraints(q queryable) ([]pgConstraint, error) {
        var constraints []pgConstraint
 
        sql := `
-SELECT table_name, constraint_name
-FROM information_schema.table_constraints
-WHERE constraint_type = 'FOREIGN KEY'
-  AND is_deferrable = 'NO'`
-       rows, err := db.Query(sql)
+               SELECT table_schema || '.' || table_name, constraint_name
+               FROM information_schema.table_constraints
+               WHERE constraint_type = 'FOREIGN KEY'
+                 AND is_deferrable = 'NO'
+       `
+       rows, err := q.Query(sql)
        if err != nil {
                return nil, err
        }
@@ -114,23 +133,27 @@ WHERE constraint_type = 'FOREIGN KEY'
        defer rows.Close()
        for rows.Next() {
                var constraint pgConstraint
-               err = rows.Scan(&constraint.tableName, &constraint.constraintName)
-               if err != nil {
+               if err = rows.Scan(&constraint.tableName, &constraint.constraintName); err != nil {
                        return nil, err
                }
                constraints = append(constraints, constraint)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return constraints, nil
 }
 
-func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) (err error) {
        defer func() {
                // re-enable triggers after load
                var sql string
                for _, table := range h.tables {
                        sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
                }
-               db.Exec(sql)
+               if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+                       err = err2
+               }
        }()
 
        tx, err := db.Begin()
@@ -154,14 +177,16 @@ func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
        return tx.Commit()
 }
 
-func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) (err error) {
        defer func() {
                // ensure constraint being not deferrable again after load
                var sql string
                for _, constraint := range h.nonDeferrableConstraints {
                        sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
                }
-               db.Exec(sql)
+               if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+                       err = err2
+               }
        }()
 
        var sql string
@@ -176,28 +201,31 @@ func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction)
        if err != nil {
                return err
        }
+       defer tx.Rollback()
 
        if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
-               return nil
+               return err
        }
 
        if err = loadFn(tx); err != nil {
-               tx.Rollback()
                return err
        }
 
        return tx.Commit()
 }
 
-func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
        // ensure sequences being reset after load
-       defer h.resetSequences(db)
+       defer func() {
+               if err2 := h.resetSequences(db); err2 != nil && err == nil {
+                       err = err2
+               }
+       }()
 
        if h.UseAlterConstraint {
                return h.makeConstraintsDeferrable(db, loadFn)
-       } else {
-               return h.disableTriggers(db, loadFn)
        }
+       return h.disableTriggers(db, loadFn)
 }
 
 func (h *PostgreSQL) resetSequences(db *sql.DB) error {
@@ -209,3 +237,53 @@ func (h *PostgreSQL) resetSequences(db *sql.DB) error {
        }
        return nil
 }
+
+func (h *PostgreSQL) isTableModified(q queryable, tableName string) (bool, error) {
+       checksum, err := h.getChecksum(q, tableName)
+       if err != nil {
+               return false, err
+       }
+
+       oldChecksum := h.tablesChecksum[tableName]
+
+       return oldChecksum == "" || checksum != oldChecksum, nil
+}
+
+func (h *PostgreSQL) afterLoad(q queryable) error {
+       if h.tablesChecksum != nil {
+               return nil
+       }
+
+       h.tablesChecksum = make(map[string]string, len(h.tables))
+       for _, t := range h.tables {
+               checksum, err := h.getChecksum(q, t)
+               if err != nil {
+                       return err
+               }
+               h.tablesChecksum[t] = checksum
+       }
+       return nil
+}
+
+func (h *PostgreSQL) getChecksum(q queryable, tableName string) (string, error) {
+       sqlStr := fmt.Sprintf(`
+                       SELECT md5(CAST((array_agg(t.*)) AS TEXT))
+                       FROM %s AS t
+               `,
+               h.quoteKeyword(tableName),
+       )
+
+       var checksum sql.NullString
+       if err := q.QueryRow(sqlStr).Scan(&checksum); err != nil {
+               return "", err
+       }
+       return checksum.String, nil
+}
+
+func (*PostgreSQL) quoteKeyword(s string) string {
+       parts := strings.Split(s, ".")
+       for i, p := range parts {
+               parts[i] = fmt.Sprintf(`"%s"`, p)
+       }
+       return strings.Join(parts, ".")
+}
index 4d7fa4fdabf3bad74f942afbd879d45352ab33db..150f01410837f0728639ba088903eb00cf338be9 100644 (file)
@@ -14,23 +14,59 @@ func (*SQLite) paramType() int {
        return paramTypeQuestion
 }
 
-func (*SQLite) databaseName(db *sql.DB) (dbName string) {
+func (*SQLite) databaseName(q queryable) (string, error) {
        var seq int
-       var main string
-       db.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+       var main, dbName string
+       err := q.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+       if err != nil {
+               return "", err
+       }
        dbName = filepath.Base(dbName)
-       return
+       return dbName, nil
 }
 
-func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
-       tx, err := db.Begin()
+func (*SQLite) tableNames(q queryable) ([]string, error) {
+       query := `
+               SELECT name
+               FROM sqlite_master
+               WHERE type = 'table';
+       `
+       rows, err := q.Query(query)
        if err != nil {
+               return nil, err
+       }
+       defer rows.Close()
+
+       var tables []string
+       for rows.Next() {
+               var table string
+               if err = rows.Scan(&table); err != nil {
+                       return nil, err
+               }
+               tables = append(tables, table)
+       }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
+       return tables, nil
+}
+
+func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
+       defer func() {
+               if _, err2 := db.Exec("PRAGMA defer_foreign_keys = OFF"); err2 != nil && err == nil {
+                       err = err2
+               }
+       }()
+
+       if _, err = db.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
                return err
        }
 
-       if _, err = tx.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
+       tx, err := db.Begin()
+       if err != nil {
                return err
        }
+       defer tx.Rollback()
 
        if err = loadFn(tx); err != nil {
                return err
index 1399f5c400842a120c28c0b0405e1aeb2e66db5f..d2fc854797c3903e610c613e55fe070ac73acf76 100644 (file)
@@ -3,6 +3,7 @@ package testfixtures
 import (
        "database/sql"
        "fmt"
+       "strings"
 )
 
 // SQLServer is the helper for SQL Server for this package.
@@ -16,7 +17,7 @@ type SQLServer struct {
 func (h *SQLServer) init(db *sql.DB) error {
        var err error
 
-       h.tables, err = h.getTables(db)
+       h.tables, err = h.tableNames(db)
        if err != nil {
                return err
        }
@@ -28,46 +29,62 @@ func (*SQLServer) paramType() int {
        return paramTypeQuestion
 }
 
-func (*SQLServer) quoteKeyword(str string) string {
-       return fmt.Sprintf("[%s]", str)
+func (*SQLServer) quoteKeyword(s string) string {
+       parts := strings.Split(s, ".")
+       for i, p := range parts {
+               parts[i] = fmt.Sprintf(`[%s]`, p)
+       }
+       return strings.Join(parts, ".")
 }
 
-func (*SQLServer) databaseName(db *sql.DB) (dbname string) {
-       db.QueryRow("SELECT DB_NAME()").Scan(&dbname)
-       return
+func (*SQLServer) databaseName(q queryable) (string, error) {
+       var dbName string
+       err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
+       return dbName, err
 }
 
-func (*SQLServer) getTables(db *sql.DB) ([]string, error) {
-       rows, err := db.Query("SELECT table_name FROM information_schema.tables")
+func (*SQLServer) tableNames(q queryable) ([]string, error) {
+       rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables")
        if err != nil {
                return nil, err
        }
-
-       tables := make([]string, 0)
        defer rows.Close()
+
+       var tables []string
        for rows.Next() {
                var table string
-               rows.Scan(&table)
+               if err = rows.Scan(&table); err != nil {
+                       return nil, err
+               }
                tables = append(tables, table)
        }
+       if err = rows.Err(); err != nil {
+               return nil, err
+       }
        return tables, nil
 }
 
-func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool {
+func (h *SQLServer) tableHasIdentityColumn(q queryable, tableName string) bool {
        sql := `
-SELECT COUNT(*)
-FROM SYS.IDENTITY_COLUMNS
-WHERE OBJECT_NAME(OBJECT_ID) = ?
-`
+               SELECT COUNT(*)
+               FROM SYS.IDENTITY_COLUMNS
+               WHERE OBJECT_ID = OBJECT_ID(?)
+       `
        var count int
-       tx.QueryRow(sql, tableName).Scan(&count)
+       q.QueryRow(sql, h.quoteKeyword(tableName)).Scan(&count)
        return count > 0
 
 }
 
-func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error {
+func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
        if h.tableHasIdentityColumn(tx, tableName) {
-               defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+               defer func() {
+                       _, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+                       if err2 != nil && err == nil {
+                               err = err2
+                       }
+               }()
+
                _, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
                if err != nil {
                        return err
@@ -76,19 +93,19 @@ func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() e
        return fn()
 }
 
-func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
        // ensure the triggers are re-enable after all
        defer func() {
-               sql := ""
+               var sql string
                for _, table := range h.tables {
                        sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
                }
-               if _, err := db.Exec(sql); err != nil {
-                       fmt.Printf("Error on re-enabling constraints: %v\n", err)
+               if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+                       err = err2
                }
        }()
 
-       sql := ""
+       var sql string
        for _, table := range h.tables {
                sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
        }
@@ -100,11 +117,19 @@ func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction)
        if err != nil {
                return err
        }
+       defer tx.Rollback()
 
        if err = loadFn(tx); err != nil {
-               tx.Rollback()
                return err
        }
 
        return tx.Commit()
 }
+
+// splitter is a batchSplitter interface implementation. We need it for
+// SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
+// could not be executed in the same batch.
+// See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
+func (*SQLServer) splitter() []byte {
+       return []byte("GO\n")
+}
index bd95580f2adc782694d43a3bb98d6f61ad9c49a0..dfc59c1efc77403976120e773063bd0b7f882d4a 100644 (file)
@@ -2,7 +2,6 @@ package testfixtures
 
 import (
        "database/sql"
-       "errors"
        "fmt"
        "io/ioutil"
        "path"
@@ -33,22 +32,10 @@ type insertSQL struct {
 }
 
 var (
-       // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
-       ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
-
-       // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
-       ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
-
-       // ErrKeyIsNotString is returned when a record is not of type string
-       ErrKeyIsNotString = errors.New("Record map key is not string")
-
-       // ErrNotTestDatabase is returned when the database name doesn't contains "test"
-       ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
-
        dbnameRegexp = regexp.MustCompile("(?i)test")
 )
 
-// NewFolder craetes a context for all fixtures in a given folder into the database:
+// NewFolder creates a context for all fixtures in a given folder into the database:
 //     NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
 func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
        fixtures, err := fixturesFromFolder(folderName)
@@ -64,7 +51,7 @@ func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
        return c, nil
 }
 
-// NewFiles craetes a context for all specified fixtures files into database:
+// NewFiles creates a context for all specified fixtures files into database:
 //     NewFiles(db, &PostgreSQL{},
 //         "fixtures/customers.yml",
 //         "fixtures/orders.yml"
@@ -102,27 +89,55 @@ func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, e
        return c, nil
 }
 
+// DetectTestDatabase returns nil if databaseName matches regexp
+//     if err := fixtures.DetectTestDatabase(); err != nil {
+//         log.Fatal(err)
+//     }
+func (c *Context) DetectTestDatabase() error {
+       dbName, err := c.helper.databaseName(c.db)
+       if err != nil {
+               return err
+       }
+       if !dbnameRegexp.MatchString(dbName) {
+               return ErrNotTestDatabase
+       }
+       return nil
+}
+
 // Load wipes and after load all fixtures in the database.
 //     if err := fixtures.Load(); err != nil {
 //         log.Fatal(err)
 //     }
 func (c *Context) Load() error {
        if !skipDatabaseNameCheck {
-               if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) {
-                       return ErrNotTestDatabase
+               if err := c.DetectTestDatabase(); err != nil {
+                       return err
                }
        }
 
        err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
                for _, file := range c.fixturesFiles {
+                       modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
+                       if err != nil {
+                               return err
+                       }
+                       if !modified {
+                               continue
+                       }
                        if err := file.delete(tx, c.helper); err != nil {
                                return err
                        }
 
-                       err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
-                               for _, i := range file.insertSQLs {
+                       err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
+                               for j, i := range file.insertSQLs {
                                        if _, err := tx.Exec(i.sql, i.params...); err != nil {
-                                               return err
+                                               return &InsertError{
+                                                       Err:    err,
+                                                       File:   file.fileName,
+                                                       Index:  j,
+                                                       SQL:    i.sql,
+                                                       Params: i.params,
+                                               }
                                        }
                                }
                                return nil
@@ -133,7 +148,10 @@ func (c *Context) Load() error {
                }
                return nil
        })
-       return err
+       if err != nil {
+               return err
+       }
+       return c.helper.afterLoad(c.db)
 }
 
 func (c *Context) buildInsertSQLs() error {
@@ -204,25 +222,33 @@ func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{
 
                sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
 
+               // if string, try convert to SQL or time
+               // if map or array, convert to json
+               switch v := value.(type) {
+               case string:
+                       if strings.HasPrefix(v, "RAW=") {
+                               sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
+                               continue
+                       }
+
+                       if t, err := tryStrToDate(v); err == nil {
+                               value = t
+                       }
+               case []interface{}, map[interface{}]interface{}:
+                       value = recursiveToJSON(v)
+               }
+
                switch h.paramType() {
                case paramTypeDollar:
                        sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
                case paramTypeQuestion:
                        sqlValues = append(sqlValues, "?")
                case paramTypeColon:
-                       switch {
-                       case isDateTime(value):
-                               sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i))
-                       case isDate(value):
-                               sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i))
-                       case isTime(value):
-                               sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i))
-                       default:
-                               sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
-                       }
+                       sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
                }
-               i++
+
                values = append(values, value)
+               i++
        }
 
        sqlStr = fmt.Sprintf(
index 67967075004783366fb50ce153c8ffa3a525397e..8c5cba1d035f4d6c1ff6cabeb037c80371d303f9 100644 (file)
@@ -1,36 +1,34 @@
 package testfixtures
 
-import "regexp"
-
-var (
-       regexpDate     = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d")
-       regexpDateTime = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d \\d\\d:\\d\\d:\\d\\d")
-       regexpTime     = regexp.MustCompile("\\d\\d:\\d\\d:\\d\\d")
+import (
+       "errors"
+       "time"
 )
 
-func isDate(value interface{}) bool {
-       str, isStr := value.(string)
-       if !isStr {
-               return false
-       }
-
-       return regexpDate.MatchString(str)
+var timeFormats = []string{
+       "2006-01-02",
+       "2006-01-02 15:04",
+       "2006-01-02 15:04:05",
+       "20060102",
+       "20060102 15:04",
+       "20060102 15:04:05",
+       "02/01/2006",
+       "02/01/2006 15:04",
+       "02/01/2006 15:04:05",
+       "2006-01-02T15:04-07:00",
+       "2006-01-02T15:04:05-07:00",
 }
 
-func isDateTime(value interface{}) bool {
-       str, isStr := value.(string)
-       if !isStr {
-               return false
-       }
+// ErrCouldNotConvertToTime is returns when a string is not a reconizable time format
+var ErrCouldNotConvertToTime = errors.New("Could not convert string to time")
 
-       return regexpDateTime.MatchString(str)
-}
-
-func isTime(value interface{}) bool {
-       str, isStr := value.(string)
-       if !isStr {
-               return false
+func tryStrToDate(s string) (time.Time, error) {
+       for _, f := range timeFormats {
+               t, err := time.ParseInLocation(f, s, time.Local)
+               if err != nil {
+                       continue
+               }
+               return t, nil
        }
-
-       return regexpTime.MatchString(str)
+       return time.Time{}, ErrCouldNotConvertToTime
 }