version = "v2.3.2"
[[projects]]
- digest = "1:e2144032dcf8e856fb733151391669dc2d32d79ebb69e3c4151efa605f5e8a01"
+ digest = "1:9c541fc507676a69ea8aaed1af53278a5241d26ce0f192c993fec2ac5b78f795"
name = "gopkg.in/testfixtures.v2"
packages = ["."]
pruneopts = "NUT"
- revision = "b9ef14dc461bf934d8df2dfc6f1f456be5664cca"
- version = "v2.0.0"
+ revision = "fa3fb89109b0b31957a5430cef3e93e535de362b"
+ version = "v2.5.0"
[[projects]]
digest = "1:ad6f94355d292690137613735965bd3688844880fdab90eccf66321910344942"
)
type (
- DataBaseHelper Helper // Deprecated: Use Helper instead
+ // DataBaseHelper is the helper interface
+ // Deprecated: Use Helper instead
+ DataBaseHelper Helper
- PostgreSQLHelper struct { // Deprecated: Use PostgreSQL{} instead
+ // PostgreSQLHelper is the PostgreSQL helper
+ // Deprecated: Use PostgreSQL{} instead
+ PostgreSQLHelper struct {
PostgreSQL
UseAlterConstraint bool
}
- MySQLHelper struct { // Deprecated: Use MySQL{} instead
+
+ // MySQLHelper is the MySQL helper
+ // Deprecated: Use MySQL{} instead
+ MySQLHelper struct {
MySQL
}
- SQLiteHelper struct { // Deprecated: Use SQLite{} instead
+
+ // SQLiteHelper is the SQLite helper
+ // Deprecated: Use SQLite{} instead
+ SQLiteHelper struct {
SQLite
}
- SQLServerHelper struct { // Deprecated: Use SQLServer{} instead
+
+ // SQLServerHelper is the SQLServer helper
+ // Deprecated: Use SQLServer{} instead
+ SQLServerHelper struct {
SQLServer
}
- OracleHelper struct { // Deprecated: Use Oracle{} instead
+
+ // OracleHelper is the Oracle helper
+ // Deprecated: Use Oracle{} instead
+ OracleHelper struct {
Oracle
}
)
--- /dev/null
+package testfixtures
+
+import (
+ "errors"
+ "fmt"
+)
+
+var (
+ // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
+ ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
+
+ // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
+ ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
+
+ // ErrKeyIsNotString is returned when a record is not of type string
+ ErrKeyIsNotString = errors.New("Record map key is not string")
+
+ // ErrNotTestDatabase is returned when the database name doesn't contains "test"
+ ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
+)
+
+// InsertError will be returned if any error happens on database while
+// inserting the record
+type InsertError struct {
+ Err error
+ File string
+ Index int
+ SQL string
+ Params []interface{}
+}
+
+func (e *InsertError) Error() string {
+ return fmt.Sprintf(
+ "testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
+ e.Err,
+ e.File,
+ e.Index,
+ e.SQL,
+ e.Params,
+ )
+}
--- /dev/null
+package testfixtures
+
+import (
+ "database/sql"
+ "fmt"
+ "os"
+ "path"
+ "unicode/utf8"
+
+ "gopkg.in/yaml.v2"
+)
+
+// TableInfo is settings for generating a fixture for table.
+type TableInfo struct {
+ Name string // Table name
+ Where string // A condition for extracting records. If this value is empty, extracts all records.
+}
+
+func (ti *TableInfo) whereClause() string {
+ if ti.Where == "" {
+ return ""
+ }
+ return fmt.Sprintf(" WHERE %s", ti.Where)
+}
+
+// GenerateFixtures generates fixtures for the current contents of a database, and saves
+// them to the specified directory
+func GenerateFixtures(db *sql.DB, helper Helper, dir string) error {
+ tables, err := helper.tableNames(db)
+ if err != nil {
+ return err
+ }
+ for _, table := range tables {
+ filename := path.Join(dir, table+".yml")
+ if err := generateFixturesForTable(db, helper, &TableInfo{Name: table}, filename); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// GenerateFixturesForTables generates fixtures for the current contents of specified tables in a database, and saves
+// them to the specified directory
+func GenerateFixturesForTables(db *sql.DB, tables []*TableInfo, helper Helper, dir string) error {
+ for _, table := range tables {
+ filename := path.Join(dir, table.Name+".yml")
+ if err := generateFixturesForTable(db, helper, table, filename); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func generateFixturesForTable(db *sql.DB, h Helper, table *TableInfo, filename string) error {
+ query := fmt.Sprintf("SELECT * FROM %s%s", h.quoteKeyword(table.Name), table.whereClause())
+ rows, err := db.Query(query)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ columns, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+
+ fixtures := make([]interface{}, 0, 10)
+ for rows.Next() {
+ entries := make([]interface{}, len(columns))
+ entryPtrs := make([]interface{}, len(entries))
+ for i := range entries {
+ entryPtrs[i] = &entries[i]
+ }
+ if err := rows.Scan(entryPtrs...); err != nil {
+ return err
+ }
+
+ entryMap := make(map[string]interface{}, len(entries))
+ for i, column := range columns {
+ entryMap[column] = convertValue(entries[i])
+ }
+ fixtures = append(fixtures, entryMap)
+ }
+ if err = rows.Err(); err != nil {
+ return err
+ }
+
+ f, err := os.Create(filename)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ marshaled, err := yaml.Marshal(fixtures)
+ if err != nil {
+ return err
+ }
+ _, err = f.Write(marshaled)
+ return err
+}
+
+func convertValue(value interface{}) interface{} {
+ switch v := value.(type) {
+ case []byte:
+ if utf8.Valid(v) {
+ return string(v)
+ }
+ }
+ return value
+}
init(*sql.DB) error
disableReferentialIntegrity(*sql.DB, loadFunction) error
paramType() int
- databaseName(*sql.DB) string
+ databaseName(queryable) (string, error)
+ tableNames(queryable) ([]string, error)
+ isTableModified(queryable, string) (bool, error)
+ afterLoad(queryable) error
quoteKeyword(string) string
whileInsertOnTable(*sql.Tx, string, func() error) error
}
+type queryable interface {
+ Exec(string, ...interface{}) (sql.Result, error)
+ Query(string, ...interface{}) (*sql.Rows, error)
+ QueryRow(string, ...interface{}) *sql.Row
+}
+
+// batchSplitter is an interface with method which returns byte slice for
+// splitting SQL batches. This need to split sql statements and run its
+// separately.
+//
+// For Microsoft SQL Server batch splitter is "GO". For details see
+// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go
+type batchSplitter interface {
+ splitter() []byte
+}
+
+var (
+ _ Helper = &MySQL{}
+ _ Helper = &PostgreSQL{}
+ _ Helper = &SQLite{}
+ _ Helper = &Oracle{}
+ _ Helper = &SQLServer{}
+)
+
type baseHelper struct{}
-func (*baseHelper) init(_ *sql.DB) error {
+func (baseHelper) init(_ *sql.DB) error {
return nil
}
-func (*baseHelper) quoteKeyword(str string) string {
+func (baseHelper) quoteKeyword(str string) string {
return fmt.Sprintf(`"%s"`, str)
}
-func (*baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
+func (baseHelper) whileInsertOnTable(_ *sql.Tx, _ string, fn func() error) error {
return fn()
}
+
+func (baseHelper) isTableModified(_ queryable, _ string) (bool, error) {
+ return true, nil
+}
+
+func (baseHelper) afterLoad(_ queryable) error {
+ return nil
+}
--- /dev/null
+package testfixtures
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+)
+
+var (
+ _ driver.Valuer = jsonArray{}
+ _ driver.Valuer = jsonMap{}
+)
+
+type jsonArray []interface{}
+
+func (a jsonArray) Value() (driver.Value, error) {
+ return json.Marshal(a)
+}
+
+type jsonMap map[string]interface{}
+
+func (m jsonMap) Value() (driver.Value, error) {
+ return json.Marshal(m)
+}
+
+// Go refuses to convert map[interface{}]interface{} to JSON because JSON only support string keys
+// So it's necessary to recursively convert all map[interface]interface{} to map[string]interface{}
+func recursiveToJSON(v interface{}) (r interface{}) {
+ switch v := v.(type) {
+ case []interface{}:
+ for i, e := range v {
+ v[i] = recursiveToJSON(e)
+ }
+ r = jsonArray(v)
+ case map[interface{}]interface{}:
+ newMap := make(map[string]interface{}, len(v))
+ for k, e := range v {
+ newMap[k.(string)] = recursiveToJSON(e)
+ }
+ r = jsonMap(newMap)
+ default:
+ r = v
+ }
+ return
+}
// MySQL is the MySQL helper for this package
type MySQL struct {
baseHelper
+ tables []string
+ tablesChecksum map[string]int64
+}
+
+func (h *MySQL) init(db *sql.DB) error {
+ var err error
+ h.tables, err = h.tableNames(db)
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (*MySQL) paramType() int {
return fmt.Sprintf("`%s`", str)
}
-func (*MySQL) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT DATABASE()").Scan(&dbName)
- return
+func (*MySQL) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT DATABASE()").Scan(&dbName)
+ return dbName, err
}
-func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *MySQL) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = ?
+ AND table_type = 'BASE TABLE';
+ `
+ dbName, err := h.databaseName(q)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := q.Query(query, dbName)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+
+}
+
+func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// re-enable after load
- defer db.Exec("SET FOREIGN_KEY_CHECKS = 1")
+ defer func() {
+ if _, err2 := db.Exec("SET FOREIGN_KEY_CHECKS = 1"); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
tx, err := db.Begin()
if err != nil {
return err
}
+ defer tx.Rollback()
if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil {
return err
}
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
+
+func (h *MySQL) isTableModified(q queryable, tableName string) (bool, error) {
+ checksum, err := h.getChecksum(q, tableName)
+ if err != nil {
+ return true, err
+ }
+
+ oldChecksum := h.tablesChecksum[tableName]
+
+ return oldChecksum == 0 || checksum != oldChecksum, nil
+}
+
+func (h *MySQL) afterLoad(q queryable) error {
+ if h.tablesChecksum != nil {
+ return nil
+ }
+
+ h.tablesChecksum = make(map[string]int64, len(h.tables))
+ for _, t := range h.tables {
+ checksum, err := h.getChecksum(q, t)
+ if err != nil {
+ return err
+ }
+ h.tablesChecksum[t] = checksum
+ }
+ return nil
+}
+
+func (h *MySQL) getChecksum(q queryable, tableName string) (int64, error) {
+ sql := fmt.Sprintf("CHECKSUM TABLE %s", h.quoteKeyword(tableName))
+ var (
+ table string
+ checksum int64
+ )
+ if err := q.QueryRow(sql).Scan(&table, &checksum); err != nil {
+ return 0, err
+ }
+ return checksum, nil
+}
return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
}
-func (*Oracle) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
- return
+func (*Oracle) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
+ return dbName, err
}
-func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) {
- constraints := make([]oracleConstraint, 0)
- rows, err := db.Query(`
- SELECT table_name, constraint_name
- FROM user_constraints
- WHERE constraint_type = 'R'
- AND status = 'ENABLED'
- `)
+func (*Oracle) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT TABLE_NAME
+ FROM USER_TABLES
+ `
+ rows, err := q.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+
+}
+
+func (*Oracle) getEnabledConstraints(q queryable) ([]oracleConstraint, error) {
+ var constraints []oracleConstraint
+ rows, err := q.Query(`
+ SELECT table_name, constraint_name
+ FROM user_constraints
+ WHERE constraint_type = 'R'
+ AND status = 'ENABLED'
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
for rows.Next() {
var constraint oracleConstraint
rows.Scan(&constraint.tableName, &constraint.constraintName)
constraints = append(constraints, constraint)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return constraints, nil
}
-func (*Oracle) getSequences(db *sql.DB) ([]string, error) {
- sequences := make([]string, 0)
- rows, err := db.Query("SELECT sequence_name FROM user_sequences")
+func (*Oracle) getSequences(q queryable) ([]string, error) {
+ var sequences []string
+ rows, err := q.Query("SELECT sequence_name FROM user_sequences")
if err != nil {
return nil, err
}
-
defer rows.Close()
+
for rows.Next() {
var sequence string
- rows.Scan(&sequence)
+ if err = rows.Scan(&sequence); err != nil {
+ return nil, err
+ }
sequences = append(sequences, sequence)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return sequences, nil
}
-func (h *Oracle) resetSequences(db *sql.DB) error {
+func (h *Oracle) resetSequences(q queryable) error {
for _, sequence := range h.sequences {
- _, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
+ _, err := q.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
if err != nil {
return err
}
- _, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
+ _, err = q.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
if err != nil {
return err
}
return nil
}
-func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// re-enable after load
defer func() {
for _, c := range h.enabledConstraints {
- db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+ _, err2 := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
+ if err2 != nil && err == nil {
+ err = err2
+ }
}
}()
if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
import (
"database/sql"
"fmt"
+ "strings"
)
// PostgreSQL is the PG helper for this package
tables []string
sequences []string
nonDeferrableConstraints []pgConstraint
+ tablesChecksum map[string]string
}
type pgConstraint struct {
func (h *PostgreSQL) init(db *sql.DB) error {
var err error
- h.tables, err = h.getTables(db)
+ h.tables, err = h.tableNames(db)
if err != nil {
return err
}
return paramTypeDollar
}
-func (*PostgreSQL) databaseName(db *sql.DB) (dbName string) {
- db.QueryRow("SELECT current_database()").Scan(&dbName)
- return
+func (*PostgreSQL) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT current_database()").Scan(&dbName)
+ return dbName, err
}
-func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) {
+func (h *PostgreSQL) tableNames(q queryable) ([]string, error) {
var tables []string
sql := `
-SELECT table_name
-FROM information_schema.tables
-WHERE table_schema = 'public'
- AND table_type = 'BASE TABLE';
-`
- rows, err := db.Query(sql)
+ SELECT pg_namespace.nspname || '.' || pg_class.relname
+ FROM pg_class
+ INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+ WHERE pg_class.relkind = 'r'
+ AND pg_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
+ AND pg_namespace.nspname NOT LIKE 'pg_toast%';
+ `
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
-
defer rows.Close()
+
for rows.Next() {
var table string
- rows.Scan(&table)
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
tables = append(tables, table)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return tables, nil
}
-func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
- var sequences []string
+func (h *PostgreSQL) getSequences(q queryable) ([]string, error) {
+ const sql = `
+ SELECT pg_namespace.nspname || '.' || pg_class.relname AS sequence_name
+ FROM pg_class
+ INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
+ WHERE pg_class.relkind = 'S'
+ `
- sql := "SELECT relname FROM pg_class WHERE relkind = 'S'"
- rows, err := db.Query(sql)
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
-
defer rows.Close()
+
+ var sequences []string
for rows.Next() {
var sequence string
if err = rows.Scan(&sequence); err != nil {
}
sequences = append(sequences, sequence)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return sequences, nil
}
-func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) {
+func (*PostgreSQL) getNonDeferrableConstraints(q queryable) ([]pgConstraint, error) {
var constraints []pgConstraint
sql := `
-SELECT table_name, constraint_name
-FROM information_schema.table_constraints
-WHERE constraint_type = 'FOREIGN KEY'
- AND is_deferrable = 'NO'`
- rows, err := db.Query(sql)
+ SELECT table_schema || '.' || table_name, constraint_name
+ FROM information_schema.table_constraints
+ WHERE constraint_type = 'FOREIGN KEY'
+ AND is_deferrable = 'NO'
+ `
+ rows, err := q.Query(sql)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var constraint pgConstraint
- err = rows.Scan(&constraint.tableName, &constraint.constraintName)
- if err != nil {
+ if err = rows.Scan(&constraint.tableName, &constraint.constraintName); err != nil {
return nil, err
}
constraints = append(constraints, constraint)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return constraints, nil
}
-func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) (err error) {
defer func() {
// re-enable triggers after load
var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
}
- db.Exec(sql)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
+ }
}()
tx, err := db.Begin()
return tx.Commit()
}
-func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) (err error) {
defer func() {
// ensure constraint being not deferrable again after load
var sql string
for _, constraint := range h.nonDeferrableConstraints {
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
}
- db.Exec(sql)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
+ }
}()
var sql string
if err != nil {
return err
}
+ defer tx.Rollback()
if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
- return nil
+ return err
}
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
-func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// ensure sequences being reset after load
- defer h.resetSequences(db)
+ defer func() {
+ if err2 := h.resetSequences(db); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
if h.UseAlterConstraint {
return h.makeConstraintsDeferrable(db, loadFn)
- } else {
- return h.disableTriggers(db, loadFn)
}
+ return h.disableTriggers(db, loadFn)
}
func (h *PostgreSQL) resetSequences(db *sql.DB) error {
}
return nil
}
+
+func (h *PostgreSQL) isTableModified(q queryable, tableName string) (bool, error) {
+ checksum, err := h.getChecksum(q, tableName)
+ if err != nil {
+ return false, err
+ }
+
+ oldChecksum := h.tablesChecksum[tableName]
+
+ return oldChecksum == "" || checksum != oldChecksum, nil
+}
+
+func (h *PostgreSQL) afterLoad(q queryable) error {
+ if h.tablesChecksum != nil {
+ return nil
+ }
+
+ h.tablesChecksum = make(map[string]string, len(h.tables))
+ for _, t := range h.tables {
+ checksum, err := h.getChecksum(q, t)
+ if err != nil {
+ return err
+ }
+ h.tablesChecksum[t] = checksum
+ }
+ return nil
+}
+
+func (h *PostgreSQL) getChecksum(q queryable, tableName string) (string, error) {
+ sqlStr := fmt.Sprintf(`
+ SELECT md5(CAST((array_agg(t.*)) AS TEXT))
+ FROM %s AS t
+ `,
+ h.quoteKeyword(tableName),
+ )
+
+ var checksum sql.NullString
+ if err := q.QueryRow(sqlStr).Scan(&checksum); err != nil {
+ return "", err
+ }
+ return checksum.String, nil
+}
+
+func (*PostgreSQL) quoteKeyword(s string) string {
+ parts := strings.Split(s, ".")
+ for i, p := range parts {
+ parts[i] = fmt.Sprintf(`"%s"`, p)
+ }
+ return strings.Join(parts, ".")
+}
return paramTypeQuestion
}
-func (*SQLite) databaseName(db *sql.DB) (dbName string) {
+func (*SQLite) databaseName(q queryable) (string, error) {
var seq int
- var main string
- db.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+ var main, dbName string
+ err := q.QueryRow("PRAGMA database_list").Scan(&seq, &main, &dbName)
+ if err != nil {
+ return "", err
+ }
dbName = filepath.Base(dbName)
- return
+ return dbName, nil
}
-func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
- tx, err := db.Begin()
+func (*SQLite) tableNames(q queryable) ([]string, error) {
+ query := `
+ SELECT name
+ FROM sqlite_master
+ WHERE type = 'table';
+ `
+ rows, err := q.Query(query)
if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var tables []string
+ for rows.Next() {
+ var table string
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
+ tables = append(tables, table)
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return tables, nil
+}
+
+func (*SQLite) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
+ defer func() {
+ if _, err2 := db.Exec("PRAGMA defer_foreign_keys = OFF"); err2 != nil && err == nil {
+ err = err2
+ }
+ }()
+
+ if _, err = db.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
return err
}
- if _, err = tx.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
+ tx, err := db.Begin()
+ if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
return err
import (
"database/sql"
"fmt"
+ "strings"
)
// SQLServer is the helper for SQL Server for this package.
func (h *SQLServer) init(db *sql.DB) error {
var err error
- h.tables, err = h.getTables(db)
+ h.tables, err = h.tableNames(db)
if err != nil {
return err
}
return paramTypeQuestion
}
-func (*SQLServer) quoteKeyword(str string) string {
- return fmt.Sprintf("[%s]", str)
+func (*SQLServer) quoteKeyword(s string) string {
+ parts := strings.Split(s, ".")
+ for i, p := range parts {
+ parts[i] = fmt.Sprintf(`[%s]`, p)
+ }
+ return strings.Join(parts, ".")
}
-func (*SQLServer) databaseName(db *sql.DB) (dbname string) {
- db.QueryRow("SELECT DB_NAME()").Scan(&dbname)
- return
+func (*SQLServer) databaseName(q queryable) (string, error) {
+ var dbName string
+ err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
+ return dbName, err
}
-func (*SQLServer) getTables(db *sql.DB) ([]string, error) {
- rows, err := db.Query("SELECT table_name FROM information_schema.tables")
+func (*SQLServer) tableNames(q queryable) ([]string, error) {
+ rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables")
if err != nil {
return nil, err
}
-
- tables := make([]string, 0)
defer rows.Close()
+
+ var tables []string
for rows.Next() {
var table string
- rows.Scan(&table)
+ if err = rows.Scan(&table); err != nil {
+ return nil, err
+ }
tables = append(tables, table)
}
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
return tables, nil
}
-func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool {
+func (h *SQLServer) tableHasIdentityColumn(q queryable, tableName string) bool {
sql := `
-SELECT COUNT(*)
-FROM SYS.IDENTITY_COLUMNS
-WHERE OBJECT_NAME(OBJECT_ID) = ?
-`
+ SELECT COUNT(*)
+ FROM SYS.IDENTITY_COLUMNS
+ WHERE OBJECT_ID = OBJECT_ID(?)
+ `
var count int
- tx.QueryRow(sql, tableName).Scan(&count)
+ q.QueryRow(sql, h.quoteKeyword(tableName)).Scan(&count)
return count > 0
}
-func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error {
+func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
if h.tableHasIdentityColumn(tx, tableName) {
- defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+ defer func() {
+ _, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
+ if err2 != nil && err == nil {
+ err = err2
+ }
+ }()
+
_, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
if err != nil {
return err
return fn()
}
-func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
+func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
// ensure the triggers are re-enable after all
defer func() {
- sql := ""
+ var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
}
- if _, err := db.Exec(sql); err != nil {
- fmt.Printf("Error on re-enabling constraints: %v\n", err)
+ if _, err2 := db.Exec(sql); err2 != nil && err == nil {
+ err = err2
}
}()
- sql := ""
+ var sql string
for _, table := range h.tables {
sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
}
if err != nil {
return err
}
+ defer tx.Rollback()
if err = loadFn(tx); err != nil {
- tx.Rollback()
return err
}
return tx.Commit()
}
+
+// splitter is a batchSplitter interface implementation. We need it for
+// SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
+// could not be executed in the same batch.
+// See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
+func (*SQLServer) splitter() []byte {
+ return []byte("GO\n")
+}
import (
"database/sql"
- "errors"
"fmt"
"io/ioutil"
"path"
}
var (
- // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
- ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
-
- // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
- ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
-
- // ErrKeyIsNotString is returned when a record is not of type string
- ErrKeyIsNotString = errors.New("Record map key is not string")
-
- // ErrNotTestDatabase is returned when the database name doesn't contains "test"
- ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
-
dbnameRegexp = regexp.MustCompile("(?i)test")
)
-// NewFolder craetes a context for all fixtures in a given folder into the database:
+// NewFolder creates a context for all fixtures in a given folder into the database:
// NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
fixtures, err := fixturesFromFolder(folderName)
return c, nil
}
-// NewFiles craetes a context for all specified fixtures files into database:
+// NewFiles creates a context for all specified fixtures files into database:
// NewFiles(db, &PostgreSQL{},
// "fixtures/customers.yml",
// "fixtures/orders.yml"
return c, nil
}
+// DetectTestDatabase returns nil if databaseName matches regexp
+// if err := fixtures.DetectTestDatabase(); err != nil {
+// log.Fatal(err)
+// }
+func (c *Context) DetectTestDatabase() error {
+ dbName, err := c.helper.databaseName(c.db)
+ if err != nil {
+ return err
+ }
+ if !dbnameRegexp.MatchString(dbName) {
+ return ErrNotTestDatabase
+ }
+ return nil
+}
+
// Load wipes and after load all fixtures in the database.
// if err := fixtures.Load(); err != nil {
// log.Fatal(err)
// }
func (c *Context) Load() error {
if !skipDatabaseNameCheck {
- if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) {
- return ErrNotTestDatabase
+ if err := c.DetectTestDatabase(); err != nil {
+ return err
}
}
err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
for _, file := range c.fixturesFiles {
+ modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
+ if err != nil {
+ return err
+ }
+ if !modified {
+ continue
+ }
if err := file.delete(tx, c.helper); err != nil {
return err
}
- err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
- for _, i := range file.insertSQLs {
+ err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
+ for j, i := range file.insertSQLs {
if _, err := tx.Exec(i.sql, i.params...); err != nil {
- return err
+ return &InsertError{
+ Err: err,
+ File: file.fileName,
+ Index: j,
+ SQL: i.sql,
+ Params: i.params,
+ }
}
}
return nil
}
return nil
})
- return err
+ if err != nil {
+ return err
+ }
+ return c.helper.afterLoad(c.db)
}
func (c *Context) buildInsertSQLs() error {
sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
+ // if string, try convert to SQL or time
+ // if map or array, convert to json
+ switch v := value.(type) {
+ case string:
+ if strings.HasPrefix(v, "RAW=") {
+ sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
+ continue
+ }
+
+ if t, err := tryStrToDate(v); err == nil {
+ value = t
+ }
+ case []interface{}, map[interface{}]interface{}:
+ value = recursiveToJSON(v)
+ }
+
switch h.paramType() {
case paramTypeDollar:
sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
case paramTypeQuestion:
sqlValues = append(sqlValues, "?")
case paramTypeColon:
- switch {
- case isDateTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i))
- case isDate(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i))
- case isTime(value):
- sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i))
- default:
- sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
- }
+ sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
}
- i++
+
values = append(values, value)
+ i++
}
sqlStr = fmt.Sprintf(
package testfixtures
-import "regexp"
-
-var (
- regexpDate = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d")
- regexpDateTime = regexp.MustCompile("\\d\\d\\d\\d-\\d\\d-\\d\\d \\d\\d:\\d\\d:\\d\\d")
- regexpTime = regexp.MustCompile("\\d\\d:\\d\\d:\\d\\d")
+import (
+ "errors"
+ "time"
)
-func isDate(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
- }
-
- return regexpDate.MatchString(str)
+var timeFormats = []string{
+ "2006-01-02",
+ "2006-01-02 15:04",
+ "2006-01-02 15:04:05",
+ "20060102",
+ "20060102 15:04",
+ "20060102 15:04:05",
+ "02/01/2006",
+ "02/01/2006 15:04",
+ "02/01/2006 15:04:05",
+ "2006-01-02T15:04-07:00",
+ "2006-01-02T15:04:05-07:00",
}
-func isDateTime(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
- }
+// ErrCouldNotConvertToTime is returns when a string is not a reconizable time format
+var ErrCouldNotConvertToTime = errors.New("Could not convert string to time")
- return regexpDateTime.MatchString(str)
-}
-
-func isTime(value interface{}) bool {
- str, isStr := value.(string)
- if !isStr {
- return false
+func tryStrToDate(s string) (time.Time, error) {
+ for _, f := range timeFormats {
+ t, err := time.ParseInLocation(f, s, time.Local)
+ if err != nil {
+ continue
+ }
+ return t, nil
}
-
- return regexpTime.MatchString(str)
+ return time.Time{}, ErrCouldNotConvertToTime
}