You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

db.go 5.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "context"
  7. "database/sql"
  8. "database/sql/driver"
  9. "fmt"
  10. "reflect"
  11. "regexp"
  12. "sync"
  13. )
  14. var (
  15. DefaultCacheSize = 200
  16. )
  17. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  18. vv := reflect.ValueOf(mp)
  19. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  20. return "", []interface{}{}, ErrNoMapPointer
  21. }
  22. args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
  23. var err error
  24. query = re.ReplaceAllStringFunc(query, func(src string) string {
  25. v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
  26. if !v.IsValid() {
  27. err = fmt.Errorf("map key %s is missing", src[1:])
  28. } else {
  29. args = append(args, v.Interface())
  30. }
  31. return "?"
  32. })
  33. return query, args, err
  34. }
  35. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  36. vv := reflect.ValueOf(st)
  37. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  38. return "", []interface{}{}, ErrNoStructPointer
  39. }
  40. args := make([]interface{}, 0)
  41. var err error
  42. query = re.ReplaceAllStringFunc(query, func(src string) string {
  43. fv := vv.Elem().FieldByName(src[1:]).Interface()
  44. if v, ok := fv.(driver.Valuer); ok {
  45. var value driver.Value
  46. value, err = v.Value()
  47. if err != nil {
  48. return "?"
  49. }
  50. args = append(args, value)
  51. } else {
  52. args = append(args, fv)
  53. }
  54. return "?"
  55. })
  56. if err != nil {
  57. return "", []interface{}{}, err
  58. }
  59. return query, args, nil
  60. }
  61. type cacheStruct struct {
  62. value reflect.Value
  63. idx int
  64. }
  65. // DB is a wrap of sql.DB with extra contents
  66. type DB struct {
  67. *sql.DB
  68. Mapper IMapper
  69. reflectCache map[reflect.Type]*cacheStruct
  70. reflectCacheMutex sync.RWMutex
  71. }
  72. // Open opens a database
  73. func Open(driverName, dataSourceName string) (*DB, error) {
  74. db, err := sql.Open(driverName, dataSourceName)
  75. if err != nil {
  76. return nil, err
  77. }
  78. return &DB{
  79. DB: db,
  80. Mapper: NewCacheMapper(&SnakeMapper{}),
  81. reflectCache: make(map[reflect.Type]*cacheStruct),
  82. }, nil
  83. }
  84. // FromDB creates a DB from a sql.DB
  85. func FromDB(db *sql.DB) *DB {
  86. return &DB{
  87. DB: db,
  88. Mapper: NewCacheMapper(&SnakeMapper{}),
  89. reflectCache: make(map[reflect.Type]*cacheStruct),
  90. }
  91. }
  92. func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
  93. db.reflectCacheMutex.Lock()
  94. defer db.reflectCacheMutex.Unlock()
  95. cs, ok := db.reflectCache[typ]
  96. if !ok || cs.idx+1 > DefaultCacheSize-1 {
  97. cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
  98. db.reflectCache[typ] = cs
  99. } else {
  100. cs.idx = cs.idx + 1
  101. }
  102. return cs.value.Index(cs.idx).Addr()
  103. }
  104. // QueryContext overwrites sql.DB.QueryContext
  105. func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
  106. rows, err := db.DB.QueryContext(ctx, query, args...)
  107. if err != nil {
  108. if rows != nil {
  109. rows.Close()
  110. }
  111. return nil, err
  112. }
  113. return &Rows{rows, db}, nil
  114. }
  115. // Query overwrites sql.DB.Query
  116. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  117. return db.QueryContext(context.Background(), query, args...)
  118. }
  119. func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
  120. query, args, err := MapToSlice(query, mp)
  121. if err != nil {
  122. return nil, err
  123. }
  124. return db.QueryContext(ctx, query, args...)
  125. }
  126. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  127. return db.QueryMapContext(context.Background(), query, mp)
  128. }
  129. func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
  130. query, args, err := StructToSlice(query, st)
  131. if err != nil {
  132. return nil, err
  133. }
  134. return db.QueryContext(ctx, query, args...)
  135. }
  136. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  137. return db.QueryStructContext(context.Background(), query, st)
  138. }
  139. func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
  140. rows, err := db.QueryContext(ctx, query, args...)
  141. if err != nil {
  142. return &Row{nil, err}
  143. }
  144. return &Row{rows, nil}
  145. }
  146. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  147. return db.QueryRowContext(context.Background(), query, args...)
  148. }
  149. func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
  150. query, args, err := MapToSlice(query, mp)
  151. if err != nil {
  152. return &Row{nil, err}
  153. }
  154. return db.QueryRowContext(ctx, query, args...)
  155. }
  156. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  157. return db.QueryRowMapContext(context.Background(), query, mp)
  158. }
  159. func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
  160. query, args, err := StructToSlice(query, st)
  161. if err != nil {
  162. return &Row{nil, err}
  163. }
  164. return db.QueryRowContext(ctx, query, args...)
  165. }
  166. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  167. return db.QueryRowStructContext(context.Background(), query, st)
  168. }
  169. var (
  170. re = regexp.MustCompile(`[?](\w+)`)
  171. )
  172. // insert into (name) values (?)
  173. // insert into (name) values (?name)
  174. func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
  175. query, args, err := MapToSlice(query, mp)
  176. if err != nil {
  177. return nil, err
  178. }
  179. return db.DB.ExecContext(ctx, query, args...)
  180. }
  181. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  182. return db.ExecMapContext(context.Background(), query, mp)
  183. }
  184. func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
  185. query, args, err := StructToSlice(query, st)
  186. if err != nil {
  187. return nil, err
  188. }
  189. return db.DB.ExecContext(ctx, query, args...)
  190. }
  191. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  192. return db.ExecStructContext(context.Background(), query, st)
  193. }