diff options
Diffstat (limited to 'vendor/gopkg.in/testfixtures.v2/oracle.go')
-rw-r--r-- | vendor/gopkg.in/testfixtures.v2/oracle.go | 83 |
1 files changed, 61 insertions, 22 deletions
diff --git a/vendor/gopkg.in/testfixtures.v2/oracle.go b/vendor/gopkg.in/testfixtures.v2/oracle.go index 59600ebfc5..af5c92ddb3 100644 --- a/vendor/gopkg.in/testfixtures.v2/oracle.go +++ b/vendor/gopkg.in/testfixtures.v2/oracle.go @@ -43,54 +43,90 @@ func (*Oracle) quoteKeyword(str string) string { return fmt.Sprintf("\"%s\"", strings.ToUpper(str)) } -func (*Oracle) databaseName(db *sql.DB) (dbName string) { - db.QueryRow("SELECT user FROM DUAL").Scan(&dbName) - return +func (*Oracle) databaseName(q queryable) (string, error) { + var dbName string + err := q.QueryRow("SELECT user FROM DUAL").Scan(&dbName) + return dbName, err } -func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) { - constraints := make([]oracleConstraint, 0) - rows, err := db.Query(` - SELECT table_name, constraint_name - FROM user_constraints - WHERE constraint_type = 'R' - AND status = 'ENABLED' - `) +func (*Oracle) tableNames(q queryable) ([]string, error) { + query := ` + SELECT TABLE_NAME + FROM USER_TABLES + ` + rows, err := q.Query(query) if err != nil { return nil, err } defer rows.Close() + + var tables []string + for rows.Next() { + var table string + if err = rows.Scan(&table); err != nil { + return nil, err + } + tables = append(tables, table) + } + if err = rows.Err(); err != nil { + return nil, err + } + return tables, nil + +} + +func (*Oracle) getEnabledConstraints(q queryable) ([]oracleConstraint, error) { + var constraints []oracleConstraint + rows, err := q.Query(` + SELECT table_name, constraint_name + FROM user_constraints + WHERE constraint_type = 'R' + AND status = 'ENABLED' + `) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { var constraint oracleConstraint rows.Scan(&constraint.tableName, &constraint.constraintName) constraints = append(constraints, constraint) } + if err = rows.Err(); err != nil { + return nil, err + } return constraints, nil } -func (*Oracle) getSequences(db *sql.DB) ([]string, error) { - sequences := make([]string, 0) - rows, err := db.Query("SELECT sequence_name FROM user_sequences") +func (*Oracle) getSequences(q queryable) ([]string, error) { + var sequences []string + rows, err := q.Query("SELECT sequence_name FROM user_sequences") if err != nil { return nil, err } - defer rows.Close() + for rows.Next() { var sequence string - rows.Scan(&sequence) + if err = rows.Scan(&sequence); err != nil { + return nil, err + } sequences = append(sequences, sequence) } + if err = rows.Err(); err != nil { + return nil, err + } return sequences, nil } -func (h *Oracle) resetSequences(db *sql.DB) error { +func (h *Oracle) resetSequences(q queryable) error { for _, sequence := range h.sequences { - _, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence))) + _, err := q.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence))) if err != nil { return err } - _, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo)) + _, err = q.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo)) if err != nil { return err } @@ -98,11 +134,14 @@ func (h *Oracle) resetSequences(db *sql.DB) error { return nil } -func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error { +func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) { // re-enable after load defer func() { for _, c := range h.enabledConstraints { - db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName))) + _, err2 := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName))) + if err2 != nil && err == nil { + err = err2 + } } }() @@ -118,9 +157,9 @@ func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) er if err != nil { return err } + defer tx.Rollback() if err = loadFn(tx); err != nil { - tx.Rollback() return err } |