You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

db.go 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "context"
  7. "database/sql"
  8. "database/sql/driver"
  9. "fmt"
  10. "reflect"
  11. "regexp"
  12. "sync"
  13. "time"
  14. "xorm.io/xorm/log"
  15. "xorm.io/xorm/names"
  16. )
  // DefaultCacheSize is the number of elements DB's per-type reflection cache
  // allocates in one batch (see reflectNew). It defaults to 200.
  var DefaultCacheSize = 200
  21. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  22. vv := reflect.ValueOf(mp)
  23. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  24. return "", []interface{}{}, ErrNoMapPointer
  25. }
  26. args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
  27. var err error
  28. query = re.ReplaceAllStringFunc(query, func(src string) string {
  29. v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
  30. if !v.IsValid() {
  31. err = fmt.Errorf("map key %s is missing", src[1:])
  32. } else {
  33. args = append(args, v.Interface())
  34. }
  35. return "?"
  36. })
  37. return query, args, err
  38. }
  39. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  40. vv := reflect.ValueOf(st)
  41. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  42. return "", []interface{}{}, ErrNoStructPointer
  43. }
  44. args := make([]interface{}, 0)
  45. var err error
  46. query = re.ReplaceAllStringFunc(query, func(src string) string {
  47. fv := vv.Elem().FieldByName(src[1:]).Interface()
  48. if v, ok := fv.(driver.Valuer); ok {
  49. var value driver.Value
  50. value, err = v.Value()
  51. if err != nil {
  52. return "?"
  53. }
  54. args = append(args, value)
  55. } else {
  56. args = append(args, fv)
  57. }
  58. return "?"
  59. })
  60. if err != nil {
  61. return "", []interface{}{}, err
  62. }
  63. return query, args, nil
  64. }
  // cacheStruct is one batch of preallocated values for a single reflect.Type,
  // handed out one element at a time by reflectNew. Field order matters: it is
  // built with a positional composite literal.
  type cacheStruct struct {
  	// value is the slice (created with reflect.MakeSlice) whose elements are
  	// handed out.
  	value reflect.Value
  	// idx is the index of the element handed out most recently.
  	idx int
  }
  69. var (
  70. _ QueryExecuter = &DB{}
  71. )
  72. // DB is a wrap of sql.DB with extra contents
  73. type DB struct {
  74. *sql.DB
  75. Mapper names.Mapper
  76. reflectCache map[reflect.Type]*cacheStruct
  77. reflectCacheMutex sync.RWMutex
  78. Logger log.ContextLogger
  79. }
  80. // Open opens a database
  81. func Open(driverName, dataSourceName string) (*DB, error) {
  82. db, err := sql.Open(driverName, dataSourceName)
  83. if err != nil {
  84. return nil, err
  85. }
  86. return &DB{
  87. DB: db,
  88. Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
  89. reflectCache: make(map[reflect.Type]*cacheStruct),
  90. }, nil
  91. }
  92. // FromDB creates a DB from a sql.DB
  93. func FromDB(db *sql.DB) *DB {
  94. return &DB{
  95. DB: db,
  96. Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
  97. reflectCache: make(map[reflect.Type]*cacheStruct),
  98. }
  99. }
  100. // NeedLogSQL returns true if need to log SQL
  101. func (db *DB) NeedLogSQL(ctx context.Context) bool {
  102. if db.Logger == nil {
  103. return false
  104. }
  105. v := ctx.Value("__xorm_show_sql")
  106. if showSQL, ok := v.(bool); ok {
  107. return showSQL
  108. }
  109. return db.Logger.IsShowSQL()
  110. }
  111. func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
  112. db.reflectCacheMutex.Lock()
  113. defer db.reflectCacheMutex.Unlock()
  114. cs, ok := db.reflectCache[typ]
  115. if !ok || cs.idx+1 > DefaultCacheSize-1 {
  116. cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
  117. db.reflectCache[typ] = cs
  118. } else {
  119. cs.idx = cs.idx + 1
  120. }
  121. return cs.value.Index(cs.idx).Addr()
  122. }
  123. // QueryContext overwrites sql.DB.QueryContext
  124. func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
  125. start := time.Now()
  126. showSQL := db.NeedLogSQL(ctx)
  127. if showSQL {
  128. db.Logger.BeforeSQL(log.LogContext{
  129. Ctx: ctx,
  130. SQL: query,
  131. Args: args,
  132. })
  133. }
  134. rows, err := db.DB.QueryContext(ctx, query, args...)
  135. if showSQL {
  136. db.Logger.AfterSQL(log.LogContext{
  137. Ctx: ctx,
  138. SQL: query,
  139. Args: args,
  140. ExecuteTime: time.Now().Sub(start),
  141. Err: err,
  142. })
  143. }
  144. if err != nil {
  145. if rows != nil {
  146. rows.Close()
  147. }
  148. return nil, err
  149. }
  150. return &Rows{rows, db}, nil
  151. }
  152. // Query overwrites sql.DB.Query
  153. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  154. return db.QueryContext(context.Background(), query, args...)
  155. }
  156. // QueryMapContext executes query with parameters via map and context
  157. func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
  158. query, args, err := MapToSlice(query, mp)
  159. if err != nil {
  160. return nil, err
  161. }
  162. return db.QueryContext(ctx, query, args...)
  163. }
  164. // QueryMap executes query with parameters via map
  165. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  166. return db.QueryMapContext(context.Background(), query, mp)
  167. }
  168. func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
  169. query, args, err := StructToSlice(query, st)
  170. if err != nil {
  171. return nil, err
  172. }
  173. return db.QueryContext(ctx, query, args...)
  174. }
  175. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  176. return db.QueryStructContext(context.Background(), query, st)
  177. }
  178. func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
  179. rows, err := db.QueryContext(ctx, query, args...)
  180. if err != nil {
  181. return &Row{nil, err}
  182. }
  183. return &Row{rows, nil}
  184. }
  185. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  186. return db.QueryRowContext(context.Background(), query, args...)
  187. }
  188. func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
  189. query, args, err := MapToSlice(query, mp)
  190. if err != nil {
  191. return &Row{nil, err}
  192. }
  193. return db.QueryRowContext(ctx, query, args...)
  194. }
  195. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  196. return db.QueryRowMapContext(context.Background(), query, mp)
  197. }
  198. func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
  199. query, args, err := StructToSlice(query, st)
  200. if err != nil {
  201. return &Row{nil, err}
  202. }
  203. return db.QueryRowContext(ctx, query, args...)
  204. }
  205. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  206. return db.QueryRowStructContext(context.Background(), query, st)
  207. }
  // re matches a named placeholder: a literal '?' followed by one or more word
  // characters, e.g. "?name". The name (without '?') is capture group 1.
  var re = regexp.MustCompile(`[?](\w+)`)
  211. // ExecMapContext exec map with context.Context
  212. // insert into (name) values (?)
  213. // insert into (name) values (?name)
  214. func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
  215. query, args, err := MapToSlice(query, mp)
  216. if err != nil {
  217. return nil, err
  218. }
  219. return db.ExecContext(ctx, query, args...)
  220. }
  221. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  222. return db.ExecMapContext(context.Background(), query, mp)
  223. }
  224. func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
  225. query, args, err := StructToSlice(query, st)
  226. if err != nil {
  227. return nil, err
  228. }
  229. return db.ExecContext(ctx, query, args...)
  230. }
  231. func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
  232. start := time.Now()
  233. showSQL := db.NeedLogSQL(ctx)
  234. if showSQL {
  235. db.Logger.BeforeSQL(log.LogContext{
  236. Ctx: ctx,
  237. SQL: query,
  238. Args: args,
  239. })
  240. }
  241. res, err := db.DB.ExecContext(ctx, query, args...)
  242. if showSQL {
  243. db.Logger.AfterSQL(log.LogContext{
  244. Ctx: ctx,
  245. SQL: query,
  246. Args: args,
  247. ExecuteTime: time.Now().Sub(start),
  248. Err: err,
  249. })
  250. }
  251. return res, err
  252. }
  253. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  254. return db.ExecStructContext(context.Background(), query, st)
  255. }