You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dialect.go 9.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package dialects
  5. import (
  6. "context"
  7. "fmt"
  8. "strings"
  9. "time"
  10. "xorm.io/xorm/core"
  11. "xorm.io/xorm/schemas"
  12. )
  13. // URI represents an uri to visit database
  14. type URI struct {
  15. DBType schemas.DBType
  16. Proto string
  17. Host string
  18. Port string
  19. DBName string
  20. User string
  21. Passwd string
  22. Charset string
  23. Laddr string
  24. Raddr string
  25. Timeout time.Duration
  26. Schema string
  27. }
  28. // SetSchema set schema
  29. func (uri *URI) SetSchema(schema string) {
  30. // hack me
  31. if uri.DBType == schemas.POSTGRES {
  32. uri.Schema = strings.TrimSpace(schema)
  33. }
  34. }
  35. // Dialect represents a kind of database
  36. type Dialect interface {
  37. Init(*URI) error
  38. URI() *URI
  39. Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
  40. SQLType(*schemas.Column) string
  41. Alias(string) string // return what a sql type's alias of
  42. ColumnTypeKind(string) int // database column type kind
  43. IsReserved(string) bool
  44. Quoter() schemas.Quoter
  45. SetQuotePolicy(quotePolicy QuotePolicy)
  46. AutoIncrStr() string
  47. GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error)
  48. IndexCheckSQL(tableName, idxName string) (string, []interface{})
  49. CreateIndexSQL(tableName string, index *schemas.Index) string
  50. DropIndexSQL(tableName string, index *schemas.Index) string
  51. GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
  52. IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
  53. CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
  54. DropTableSQL(tableName string) (string, bool)
  55. GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
  56. IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
  57. AddColumnSQL(tableName string, col *schemas.Column) string
  58. ModifyColumnSQL(tableName string, col *schemas.Column) string
  59. ForUpdateSQL(query string) string
  60. Filters() []Filter
  61. SetParams(params map[string]string)
  62. }
  63. // Base represents a basic dialect and all real dialects could embed this struct
  64. type Base struct {
  65. dialect Dialect
  66. uri *URI
  67. quoter schemas.Quoter
  68. }
  69. // Alias returned col itself
  70. func (db *Base) Alias(col string) string {
  71. return col
  72. }
  73. // Quoter returns the current database Quoter
  74. func (db *Base) Quoter() schemas.Quoter {
  75. return db.quoter
  76. }
  77. // Init initialize the dialect
  78. func (db *Base) Init(dialect Dialect, uri *URI) error {
  79. db.dialect, db.uri = dialect, uri
  80. return nil
  81. }
  82. // URI returns the uri of database
  83. func (db *Base) URI() *URI {
  84. return db.uri
  85. }
  86. // CreateTableSQL implements Dialect
  87. func (db *Base) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
  88. if tableName == "" {
  89. tableName = table.Name
  90. }
  91. quoter := db.dialect.Quoter()
  92. var b strings.Builder
  93. b.WriteString("CREATE TABLE IF NOT EXISTS ")
  94. quoter.QuoteTo(&b, tableName)
  95. b.WriteString(" (")
  96. for i, colName := range table.ColumnsSeq() {
  97. col := table.GetColumn(colName)
  98. s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1)
  99. b.WriteString(s)
  100. if i != len(table.ColumnsSeq())-1 {
  101. b.WriteString(", ")
  102. }
  103. }
  104. if len(table.PrimaryKeys) > 1 {
  105. b.WriteString(", PRIMARY KEY (")
  106. b.WriteString(quoter.Join(table.PrimaryKeys, ","))
  107. b.WriteString(")")
  108. }
  109. b.WriteString(")")
  110. return []string{b.String()}, false
  111. }
  112. // DropTableSQL returns drop table SQL
  113. func (db *Base) DropTableSQL(tableName string) (string, bool) {
  114. quote := db.dialect.Quoter().Quote
  115. return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
  116. }
  117. // HasRecords returns true if the SQL has records returned
  118. func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
  119. rows, err := queryer.QueryContext(ctx, query, args...)
  120. if err != nil {
  121. return false, err
  122. }
  123. defer rows.Close()
  124. if rows.Next() {
  125. return true, nil
  126. }
  127. return false, rows.Err()
  128. }
  129. // IsColumnExist returns true if the column of the table exist
  130. func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
  131. quote := db.dialect.Quoter().Quote
  132. query := fmt.Sprintf(
  133. "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
  134. quote("COLUMN_NAME"),
  135. quote("INFORMATION_SCHEMA"),
  136. quote("COLUMNS"),
  137. quote("TABLE_SCHEMA"),
  138. quote("TABLE_NAME"),
  139. quote("COLUMN_NAME"),
  140. )
  141. return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName)
  142. }
  143. // AddColumnSQL returns a SQL to add a column
  144. func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
  145. s, _ := ColumnString(db.dialect, col, true)
  146. return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s)
  147. }
  148. // CreateIndexSQL returns a SQL to create index
  149. func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
  150. quoter := db.dialect.Quoter()
  151. var unique string
  152. var idxName string
  153. if index.Type == schemas.UniqueType {
  154. unique = " UNIQUE"
  155. }
  156. idxName = index.XName(tableName)
  157. return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
  158. quoter.Quote(idxName), quoter.Quote(tableName),
  159. quoter.Join(index.Cols, ","))
  160. }
  161. // DropIndexSQL returns a SQL to drop index
  162. func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
  163. quote := db.dialect.Quoter().Quote
  164. var name string
  165. if index.IsRegular {
  166. name = index.XName(tableName)
  167. } else {
  168. name = index.Name
  169. }
  170. return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
  171. }
  172. // ModifyColumnSQL returns a SQL to modify SQL
  173. func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
  174. s, _ := ColumnString(db.dialect, col, false)
  175. return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", tableName, s)
  176. }
  177. // ForUpdateSQL returns for updateSQL
  178. func (db *Base) ForUpdateSQL(query string) string {
  179. return query + " FOR UPDATE"
  180. }
  181. // SetParams set params
  182. func (db *Base) SetParams(params map[string]string) {
  183. }
  184. var (
  185. dialects = map[string]func() Dialect{}
  186. )
  187. // RegisterDialect register database dialect
  188. func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
  189. if dialectFunc == nil {
  190. panic("core: Register dialect is nil")
  191. }
  192. dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
  193. }
  194. // QueryDialect query if registered database dialect
  195. func QueryDialect(dbName schemas.DBType) Dialect {
  196. if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
  197. return d()
  198. }
  199. return nil
  200. }
  201. func regDrvsNDialects() bool {
  202. providedDrvsNDialects := map[string]struct {
  203. dbType schemas.DBType
  204. getDriver func() Driver
  205. getDialect func() Dialect
  206. }{
  207. "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
  208. "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
  209. "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
  210. "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
  211. "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
  212. "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
  213. "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
  214. "sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
  215. "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
  216. "godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }},
  217. }
  218. for driverName, v := range providedDrvsNDialects {
  219. if driver := QueryDriver(driverName); driver == nil {
  220. RegisterDriver(driverName, v.getDriver())
  221. RegisterDialect(v.dbType, v.getDialect)
  222. }
  223. }
  224. return true
  225. }
  226. func init() {
  227. regDrvsNDialects()
  228. }
  229. // ColumnString generate column description string according dialect
  230. func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) {
  231. bd := strings.Builder{}
  232. if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
  233. return "", err
  234. }
  235. if err := bd.WriteByte(' '); err != nil {
  236. return "", err
  237. }
  238. if _, err := bd.WriteString(dialect.SQLType(col)); err != nil {
  239. return "", err
  240. }
  241. if err := bd.WriteByte(' '); err != nil {
  242. return "", err
  243. }
  244. if includePrimaryKey && col.IsPrimaryKey {
  245. if _, err := bd.WriteString("PRIMARY KEY "); err != nil {
  246. return "", err
  247. }
  248. if col.IsAutoIncrement {
  249. if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
  250. return "", err
  251. }
  252. if err := bd.WriteByte(' '); err != nil {
  253. return "", err
  254. }
  255. }
  256. }
  257. if col.Default != "" {
  258. if _, err := bd.WriteString("DEFAULT "); err != nil {
  259. return "", err
  260. }
  261. if _, err := bd.WriteString(col.Default); err != nil {
  262. return "", err
  263. }
  264. if err := bd.WriteByte(' '); err != nil {
  265. return "", err
  266. }
  267. }
  268. if col.Nullable {
  269. if _, err := bd.WriteString("NULL "); err != nil {
  270. return "", err
  271. }
  272. } else {
  273. if _, err := bd.WriteString("NOT NULL "); err != nil {
  274. return "", err
  275. }
  276. }
  277. return bd.String(), nil
  278. }