You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mysql.go 2.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. package testfixtures
  import (
	"database/sql"
	"fmt"
	"strings"
)
  6. type mySQL struct {
  7. baseHelper
  8. tables []string
  9. tablesChecksum map[string]int64
  10. }
  // init discovers the base tables of the connected database and stores them
// in h.tables. The result of tableNames is assigned unconditionally; because
// tableNames returns a nil slice on every error path, a failure leaves
// h.tables nil.
func (h *mySQL) init(db *sql.DB) error {
	var err error
	h.tables, err = h.tableNames(db)
	return err
}
  19. func (*mySQL) paramType() int {
  20. return paramTypeQuestion
  21. }
  // quoteKeyword wraps an identifier, such as a table name, in MySQL backticks
// so it can be embedded in generated SQL (e.g. CHECKSUM TABLE in getChecksum).
//
// A backtick inside the name is doubled, which is how MySQL escapes it within
// a quoted identifier. Without this, a name containing "`" would terminate the
// quoting early and corrupt (or inject into) the surrounding statement.
func (*mySQL) quoteKeyword(str string) string {
	// strings.Replace with n = -1 (rather than ReplaceAll) keeps this usable on
	// older Go toolchains.
	return "`" + strings.Replace(str, "`", "``", -1) + "`"
}
  // databaseName returns the name of the connection's current default database
// as reported by SELECT DATABASE().
//
// MySQL's DATABASE() yields NULL when no default database is selected.
// Scanning that directly into a string fails with a generic "converting NULL
// to string" error. The NULL case is therefore detected explicitly and
// reported with a clearer message. Any query/scan error is returned unchanged,
// and the name is "" whenever the error is non-nil.
func (*mySQL) databaseName(q queryable) (string, error) {
	var name sql.NullString
	if err := q.QueryRow("SELECT DATABASE()").Scan(&name); err != nil {
		return "", err
	}
	if !name.Valid {
		return "", fmt.Errorf("testfixtures: no database selected: SELECT DATABASE() returned NULL")
	}
	return name.String, nil
}
  30. func (h *mySQL) tableNames(q queryable) ([]string, error) {
  31. query := `
  32. SELECT table_name
  33. FROM information_schema.tables
  34. WHERE table_schema = ?
  35. AND table_type = 'BASE TABLE';
  36. `
  37. dbName, err := h.databaseName(q)
  38. if err != nil {
  39. return nil, err
  40. }
  41. rows, err := q.Query(query, dbName)
  42. if err != nil {
  43. return nil, err
  44. }
  45. defer rows.Close()
  46. var tables []string
  47. for rows.Next() {
  48. var table string
  49. if err = rows.Scan(&table); err != nil {
  50. return nil, err
  51. }
  52. tables = append(tables, table)
  53. }
  54. if err = rows.Err(); err != nil {
  55. return nil, err
  56. }
  57. return tables, nil
  58. }
  // disableReferentialIntegrity runs loadFn inside one transaction with MySQL
// FOREIGN_KEY_CHECKS set to 0. Presumably this lets fixtures load regardless
// of foreign-key ordering.
//
// Once the checks have been disabled, they are set back to 1 after loadFn
// returns, whatever its outcome. An error from loadFn takes precedence over an
// error from re-enabling the checks. The transaction is committed only when
// both succeed; otherwise the deferred Rollback undoes it.
func (h *mySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
	tx, beginErr := db.Begin()
	if beginErr != nil {
		return beginErr
	}
	// After a successful Commit this Rollback just reports that the
	// transaction is already done; that result is deliberately ignored.
	defer tx.Rollback()

	if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil {
		return err
	}

	loadErr := loadFn(tx)
	_, restoreErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1")
	switch {
	case loadErr != nil:
		return loadErr
	case restoreErr != nil:
		return restoreErr
	}
	return tx.Commit()
}
  78. func (h *mySQL) isTableModified(q queryable, tableName string) (bool, error) {
  79. checksum, err := h.getChecksum(q, tableName)
  80. if err != nil {
  81. return true, err
  82. }
  83. oldChecksum := h.tablesChecksum[tableName]
  84. return oldChecksum == 0 || checksum != oldChecksum, nil
  85. }
  86. func (h *mySQL) afterLoad(q queryable) error {
  87. if h.tablesChecksum != nil {
  88. return nil
  89. }
  90. h.tablesChecksum = make(map[string]int64, len(h.tables))
  91. for _, t := range h.tables {
  92. checksum, err := h.getChecksum(q, t)
  93. if err != nil {
  94. return err
  95. }
  96. h.tablesChecksum[t] = checksum
  97. }
  98. return nil
  99. }
  // getChecksum returns the CHECKSUM TABLE value for tableName. The table name
// is quoted with quoteKeyword.
//
// CHECKSUM TABLE produces one row of (Table, Checksum). MySQL reports a NULL
// checksum for a table that does not exist. Rather than surfacing the generic
// Scan error for converting NULL to int64, that case becomes an explicit error
// naming the table. Query/scan errors are returned unchanged, with 0 as the
// checksum.
func (h *mySQL) getChecksum(q queryable, tableName string) (int64, error) {
	// Named stmt (not sql) so the database/sql package stays visible here.
	stmt := "CHECKSUM TABLE " + h.quoteKeyword(tableName)
	var (
		table    string
		checksum sql.NullInt64
	)
	if err := q.QueryRow(stmt).Scan(&table, &checksum); err != nil {
		return 0, err
	}
	if !checksum.Valid {
		return 0, fmt.Errorf("testfixtures: CHECKSUM TABLE returned NULL for table %q (does the table exist?)", tableName)
	}
	return checksum.Int64, nil
}