You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

sqlserver.go 3.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. )
  7. // SQLServer is the helper for SQL Server for this package.
  8. // SQL Server >= 2008 is required.
  9. type SQLServer struct {
  10. baseHelper
  11. tables []string
  12. }
  13. func (h *SQLServer) init(db *sql.DB) error {
  14. var err error
  15. h.tables, err = h.tableNames(db)
  16. if err != nil {
  17. return err
  18. }
  19. return nil
  20. }
  21. func (*SQLServer) paramType() int {
  22. return paramTypeQuestion
  23. }
  24. func (*SQLServer) quoteKeyword(s string) string {
  25. parts := strings.Split(s, ".")
  26. for i, p := range parts {
  27. parts[i] = fmt.Sprintf(`[%s]`, p)
  28. }
  29. return strings.Join(parts, ".")
  30. }
  31. func (*SQLServer) databaseName(q queryable) (string, error) {
  32. var dbName string
  33. err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
  34. return dbName, err
  35. }
  36. func (*SQLServer) tableNames(q queryable) ([]string, error) {
  37. rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables")
  38. if err != nil {
  39. return nil, err
  40. }
  41. defer rows.Close()
  42. var tables []string
  43. for rows.Next() {
  44. var table string
  45. if err = rows.Scan(&table); err != nil {
  46. return nil, err
  47. }
  48. tables = append(tables, table)
  49. }
  50. if err = rows.Err(); err != nil {
  51. return nil, err
  52. }
  53. return tables, nil
  54. }
  55. func (h *SQLServer) tableHasIdentityColumn(q queryable, tableName string) bool {
  56. sql := `
  57. SELECT COUNT(*)
  58. FROM SYS.IDENTITY_COLUMNS
  59. WHERE OBJECT_ID = OBJECT_ID(?)
  60. `
  61. var count int
  62. q.QueryRow(sql, h.quoteKeyword(tableName)).Scan(&count)
  63. return count > 0
  64. }
  65. func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
  66. if h.tableHasIdentityColumn(tx, tableName) {
  67. defer func() {
  68. _, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
  69. if err2 != nil && err == nil {
  70. err = err2
  71. }
  72. }()
  73. _, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
  74. if err != nil {
  75. return err
  76. }
  77. }
  78. return fn()
  79. }
  80. func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
  81. // ensure the triggers are re-enable after all
  82. defer func() {
  83. var sql string
  84. for _, table := range h.tables {
  85. sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
  86. }
  87. if _, err2 := db.Exec(sql); err2 != nil && err == nil {
  88. err = err2
  89. }
  90. }()
  91. var sql string
  92. for _, table := range h.tables {
  93. sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
  94. }
  95. if _, err := db.Exec(sql); err != nil {
  96. return err
  97. }
  98. tx, err := db.Begin()
  99. if err != nil {
  100. return err
  101. }
  102. defer tx.Rollback()
  103. if err = loadFn(tx); err != nil {
  104. return err
  105. }
  106. return tx.Commit()
  107. }
  108. // splitter is a batchSplitter interface implementation. We need it for
  109. // SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
  110. // could not be executed in the same batch.
  111. // See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
  112. func (*SQLServer) splitter() []byte {
  113. return []byte("GO\n")
  114. }