You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

sqlserver.go 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. )
  7. type sqlserver struct {
  8. baseHelper
  9. paramTypeCache int
  10. tables []string
  11. }
  // init detects which query placeholder style the driver accepts and caches
  // the list of tables found in the database. An error from listing the
  // tables is returned; the placeholder probe itself never fails init.
  func (h *sqlserver) init(db *sql.DB) error {
  	// NOTE(@andreynering): The SQL Server lib (github.com/denisenkom/go-mssqldb)
  	// supports both the "?" style (when using the deprecated "mssql" driver)
  	// and the "@p1" style (when using the new "sqlserver" driver).
  	//
  	// Since we don't have a way to know which driver it's been used,
  	// this is a small hack to detect the allowed param style.
  	h.paramTypeCache = paramTypeAtSign
  	var probe int
  	if db.QueryRow("SELECT ?", 1).Scan(&probe) == nil && probe == 1 {
  		h.paramTypeCache = paramTypeQuestion
  	}

  	tables, err := h.tableNames(db)
  	h.tables = tables
  	return err
  }
  32. func (h *sqlserver) paramType() int {
  33. return h.paramTypeCache
  34. }
  // quoteKeyword quotes a possibly schema-qualified identifier for SQL Server
  // by wrapping each dot-separated part in square brackets. For example,
  // "dbo.users" becomes "[dbo].[users]".
  //
  // A "]" inside a part is escaped as "]]", the same escaping QUOTENAME
  // applies, so it cannot terminate the bracketed identifier early.
  //
  // NOTE(review): the input is split on every ".", so a dot that belongs to
  // an identifier name itself is not supported.
  func (*sqlserver) quoteKeyword(s string) string {
  	parts := strings.Split(s, ".")
  	for i, p := range parts {
  		parts[i] = "[" + strings.ReplaceAll(p, "]", "]]") + "]"
  	}
  	return strings.Join(parts, ".")
  }
  42. func (*sqlserver) databaseName(q queryable) (string, error) {
  43. var dbName string
  44. err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
  45. return dbName, err
  46. }
  47. func (*sqlserver) tableNames(q queryable) ([]string, error) {
  48. rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables WHERE table_name <> 'spt_values'")
  49. if err != nil {
  50. return nil, err
  51. }
  52. defer rows.Close()
  53. var tables []string
  54. for rows.Next() {
  55. var table string
  56. if err = rows.Scan(&table); err != nil {
  57. return nil, err
  58. }
  59. tables = append(tables, table)
  60. }
  61. if err = rows.Err(); err != nil {
  62. return nil, err
  63. }
  64. return tables, nil
  65. }
  // tableHasIdentityColumn reports whether tableName has at least one entry
  // in SYS.IDENTITY_COLUMNS, i.e. whether it has an IDENTITY column.
  //
  // The name is bracket-quoted with quoteKeyword, the same quoting used for
  // SET IDENTITY_INSERT in whileInsertOnTable. Any single quote is then
  // doubled, so the OBJECT_ID string literal stays well-formed. The raw name
  // is never spliced into the SQL text, even when it contains quotes, spaces
  // or other special characters.
  func (h *sqlserver) tableHasIdentityColumn(q queryable, tableName string) (bool, error) {
  	literal := strings.ReplaceAll(h.quoteKeyword(tableName), "'", "''")
  	query := fmt.Sprintf(`
  		SELECT COUNT(*)
  		FROM SYS.IDENTITY_COLUMNS
  		WHERE OBJECT_ID = OBJECT_ID('%s')
  	`, literal)

  	var count int
  	if err := q.QueryRow(query).Scan(&count); err != nil {
  		return false, err
  	}
  	return count > 0, nil
  }
  // whileInsertOnTable runs fn (presumably the fixture inserts for tableName)
  // inside tx.
  //
  // If the table has an IDENTITY column, IDENTITY_INSERT is switched ON
  // before fn runs and switched OFF again afterwards. An error from switching
  // it OFF is reported only when nothing else failed, so it never masks an
  // error returned by fn.
  func (h *sqlserver) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
  	hasIdentityColumn, err := h.tableHasIdentityColumn(tx, tableName)
  	if err != nil {
  		return err
  	}
  	if !hasIdentityColumn {
  		return fn()
  	}

  	quoted := h.quoteKeyword(tableName)
  	// Assign to the named result (no :=) so it is not shadowed.
  	if _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", quoted)); err != nil {
  		return fmt.Errorf("testfixtures: could not enable identity insert: %w", err)
  	}
  	// Registered only after ON succeeded: otherwise there is nothing to undo.
  	defer func() {
  		_, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", quoted))
  		if err2 != nil && err == nil {
  			err = fmt.Errorf("testfixtures: could not disable identity insert: %w", err2)
  		}
  	}()
  	return fn()
  }
  // disableReferentialIntegrity runs loadFn in a new transaction while the
  // constraints of every table in h.tables are disabled with
  // "ALTER TABLE ... NOCHECK CONSTRAINT ALL".
  //
  // The constraints are re-enabled in a deferred call with
  // "WITH CHECK CHECK CONSTRAINT ALL", which also re-validates existing rows,
  // so this happens on every return path. An error from re-enabling is
  // reported only when nothing else failed.
  func (h *sqlserver) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
  	// alterAll builds a single batch applying clause to every known table.
  	alterAll := func(clause string) string {
  		var b strings.Builder
  		for _, table := range h.tables {
  			b.WriteString("ALTER TABLE ")
  			b.WriteString(h.quoteKeyword(table))
  			b.WriteString(clause)
  		}
  		return b.String()
  	}

  	// Ensure the constraints are re-enabled after all.
  	defer func() {
  		if len(h.tables) == 0 {
  			return // nothing was disabled; avoid sending an empty batch
  		}
  		if _, err2 := db.Exec(alterAll(" WITH CHECK CHECK CONSTRAINT ALL;")); err2 != nil && err == nil {
  			err = err2
  		}
  	}()

  	if len(h.tables) > 0 {
  		if _, err = db.Exec(alterAll(" NOCHECK CONSTRAINT ALL;")); err != nil {
  			return err
  		}
  	}

  	tx, err := db.Begin()
  	if err != nil {
  		return err
  	}
  	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
  	// error is deliberately ignored.
  	defer tx.Rollback()

  	if err = loadFn(tx); err != nil {
  		return err
  	}
  	return tx.Commit()
  }
  125. // splitter is a batchSplitter interface implementation. We need it for
  126. // SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
  127. // could not be executed in the same batch.
  128. // See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
  129. func (*sqlserver) splitter() []byte {
  130. return []byte("GO\n")
  131. }