You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dialect.go 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package dialects
  5. import (
  6. "context"
  7. "fmt"
  8. "strings"
  9. "time"
  10. "xorm.io/xorm/core"
  11. "xorm.io/xorm/schemas"
  12. )
  13. // URI represents an uri to visit database
  14. type URI struct {
  15. DBType schemas.DBType
  16. Proto string
  17. Host string
  18. Port string
  19. DBName string
  20. User string
  21. Passwd string
  22. Charset string
  23. Laddr string
  24. Raddr string
  25. Timeout time.Duration
  26. Schema string
  27. }
  28. // SetSchema set schema
  29. func (uri *URI) SetSchema(schema string) {
  30. // hack me
  31. if uri.DBType == schemas.POSTGRES {
  32. uri.Schema = strings.TrimSpace(schema)
  33. }
  34. }
  35. // Dialect represents a kind of database
  36. type Dialect interface {
  37. Init(*URI) error
  38. URI() *URI
  39. SQLType(*schemas.Column) string
  40. FormatBytes(b []byte) string
  41. IsReserved(string) bool
  42. Quoter() schemas.Quoter
  43. SetQuotePolicy(quotePolicy QuotePolicy)
  44. AutoIncrStr() string
  45. GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error)
  46. IndexCheckSQL(tableName, idxName string) (string, []interface{})
  47. CreateIndexSQL(tableName string, index *schemas.Index) string
  48. DropIndexSQL(tableName string, index *schemas.Index) string
  49. GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
  50. IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
  51. CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
  52. DropTableSQL(tableName string) (string, bool)
  53. GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
  54. IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
  55. AddColumnSQL(tableName string, col *schemas.Column) string
  56. ModifyColumnSQL(tableName string, col *schemas.Column) string
  57. ForUpdateSQL(query string) string
  58. Filters() []Filter
  59. SetParams(params map[string]string)
  60. }
  61. // Base represents a basic dialect and all real dialects could embed this struct
  62. type Base struct {
  63. dialect Dialect
  64. uri *URI
  65. quoter schemas.Quoter
  66. }
  67. func (b *Base) Quoter() schemas.Quoter {
  68. return b.quoter
  69. }
  70. func (b *Base) Init(dialect Dialect, uri *URI) error {
  71. b.dialect, b.uri = dialect, uri
  72. return nil
  73. }
  74. func (b *Base) URI() *URI {
  75. return b.uri
  76. }
  77. func (b *Base) DBType() schemas.DBType {
  78. return b.uri.DBType
  79. }
  80. // String generate column description string according dialect
  81. func (b *Base) String(col *schemas.Column) string {
  82. sql := b.dialect.Quoter().Quote(col.Name) + " "
  83. sql += b.dialect.SQLType(col) + " "
  84. if col.IsPrimaryKey {
  85. sql += "PRIMARY KEY "
  86. if col.IsAutoIncrement {
  87. sql += b.dialect.AutoIncrStr() + " "
  88. }
  89. }
  90. if col.Default != "" {
  91. sql += "DEFAULT " + col.Default + " "
  92. }
  93. if col.Nullable {
  94. sql += "NULL "
  95. } else {
  96. sql += "NOT NULL "
  97. }
  98. return sql
  99. }
  100. // StringNoPk generate column description string according dialect without primary keys
  101. func (b *Base) StringNoPk(col *schemas.Column) string {
  102. sql := b.dialect.Quoter().Quote(col.Name) + " "
  103. sql += b.dialect.SQLType(col) + " "
  104. if col.Default != "" {
  105. sql += "DEFAULT " + col.Default + " "
  106. }
  107. if col.Nullable {
  108. sql += "NULL "
  109. } else {
  110. sql += "NOT NULL "
  111. }
  112. return sql
  113. }
  114. func (b *Base) FormatBytes(bs []byte) string {
  115. return fmt.Sprintf("0x%x", bs)
  116. }
  117. func (db *Base) DropTableSQL(tableName string) (string, bool) {
  118. quote := db.dialect.Quoter().Quote
  119. return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
  120. }
  121. func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
  122. rows, err := queryer.QueryContext(ctx, query, args...)
  123. if err != nil {
  124. return false, err
  125. }
  126. defer rows.Close()
  127. if rows.Next() {
  128. return true, nil
  129. }
  130. return false, nil
  131. }
  132. func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
  133. quote := db.dialect.Quoter().Quote
  134. query := fmt.Sprintf(
  135. "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
  136. quote("COLUMN_NAME"),
  137. quote("INFORMATION_SCHEMA"),
  138. quote("COLUMNS"),
  139. quote("TABLE_SCHEMA"),
  140. quote("TABLE_NAME"),
  141. quote("COLUMN_NAME"),
  142. )
  143. return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName)
  144. }
  145. func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
  146. return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName),
  147. db.String(col))
  148. }
  149. func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
  150. quoter := db.dialect.Quoter()
  151. var unique string
  152. var idxName string
  153. if index.Type == schemas.UniqueType {
  154. unique = " UNIQUE"
  155. }
  156. idxName = index.XName(tableName)
  157. return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
  158. quoter.Quote(idxName), quoter.Quote(tableName),
  159. quoter.Join(index.Cols, ","))
  160. }
  161. func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
  162. quote := db.dialect.Quoter().Quote
  163. var name string
  164. if index.IsRegular {
  165. name = index.XName(tableName)
  166. } else {
  167. name = index.Name
  168. }
  169. return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
  170. }
  171. func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
  172. return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
  173. }
  174. func (b *Base) ForUpdateSQL(query string) string {
  175. return query + " FOR UPDATE"
  176. }
  177. func (b *Base) SetParams(params map[string]string) {
  178. }
  179. var (
  180. dialects = map[string]func() Dialect{}
  181. )
  182. // RegisterDialect register database dialect
  183. func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
  184. if dialectFunc == nil {
  185. panic("core: Register dialect is nil")
  186. }
  187. dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
  188. }
  189. // QueryDialect query if registered database dialect
  190. func QueryDialect(dbName schemas.DBType) Dialect {
  191. if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
  192. return d()
  193. }
  194. return nil
  195. }
  196. func regDrvsNDialects() bool {
  197. providedDrvsNDialects := map[string]struct {
  198. dbType schemas.DBType
  199. getDriver func() Driver
  200. getDialect func() Dialect
  201. }{
  202. "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
  203. "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
  204. "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
  205. "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
  206. "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
  207. "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
  208. "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
  209. "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
  210. "goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }},
  211. }
  212. for driverName, v := range providedDrvsNDialects {
  213. if driver := QueryDriver(driverName); driver == nil {
  214. RegisterDriver(driverName, v.getDriver())
  215. RegisterDialect(v.dbType, v.getDialect)
  216. }
  217. }
  218. return true
  219. }
  220. func init() {
  221. regDrvsNDialects()
  222. }