You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

stmt.go 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "context"
  7. "database/sql"
  8. "errors"
  9. "reflect"
  10. "time"
  11. "xorm.io/xorm/log"
  12. )
// Stmt represents a prepared statement. It embeds *sql.Stmt and additionally
// keeps the DB it was prepared on, the named parameters discovered in the
// original query, and the query text that was sent to the driver.
type Stmt struct {
	*sql.Stmt
	// db is the DB that prepared this statement; it is consulted for SQL
	// logging and handed to the Rows returned by queries.
	db *DB
	// names maps a parameter name to the positional index assigned to it in
	// PrepareContext.
	names map[string]int
	// query is the query text after named parameters were replaced by "?".
	query string
}
  20. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  21. names := make(map[string]int)
  22. var i int
  23. query = re.ReplaceAllStringFunc(query, func(src string) string {
  24. names[src[1:]] = i
  25. i++
  26. return "?"
  27. })
  28. start := time.Now()
  29. showSQL := db.NeedLogSQL(ctx)
  30. if showSQL {
  31. db.Logger.BeforeSQL(log.LogContext{
  32. Ctx: ctx,
  33. SQL: "PREPARE",
  34. })
  35. }
  36. stmt, err := db.DB.PrepareContext(ctx, query)
  37. if showSQL {
  38. db.Logger.AfterSQL(log.LogContext{
  39. Ctx: ctx,
  40. SQL: "PREPARE",
  41. ExecuteTime: time.Now().Sub(start),
  42. Err: err,
  43. })
  44. }
  45. if err != nil {
  46. return nil, err
  47. }
  48. return &Stmt{stmt, db, names, query}, nil
  49. }
  50. func (db *DB) Prepare(query string) (*Stmt, error) {
  51. return db.PrepareContext(context.Background(), query)
  52. }
  53. func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
  54. vv := reflect.ValueOf(mp)
  55. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  56. return nil, errors.New("mp should be a map's pointer")
  57. }
  58. args := make([]interface{}, len(s.names))
  59. for k, i := range s.names {
  60. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  61. }
  62. return s.ExecContext(ctx, args...)
  63. }
  64. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  65. return s.ExecMapContext(context.Background(), mp)
  66. }
  67. func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
  68. vv := reflect.ValueOf(st)
  69. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  70. return nil, errors.New("mp should be a map's pointer")
  71. }
  72. args := make([]interface{}, len(s.names))
  73. for k, i := range s.names {
  74. args[i] = vv.Elem().FieldByName(k).Interface()
  75. }
  76. return s.ExecContext(ctx, args...)
  77. }
  78. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  79. return s.ExecStructContext(context.Background(), st)
  80. }
  81. func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
  82. start := time.Now()
  83. showSQL := s.db.NeedLogSQL(ctx)
  84. if showSQL {
  85. s.db.Logger.BeforeSQL(log.LogContext{
  86. Ctx: ctx,
  87. SQL: s.query,
  88. Args: args,
  89. })
  90. }
  91. res, err := s.Stmt.ExecContext(ctx, args)
  92. if showSQL {
  93. s.db.Logger.AfterSQL(log.LogContext{
  94. Ctx: ctx,
  95. SQL: s.query,
  96. Args: args,
  97. ExecuteTime: time.Now().Sub(start),
  98. Err: err,
  99. })
  100. }
  101. return res, err
  102. }
  103. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
  104. start := time.Now()
  105. showSQL := s.db.NeedLogSQL(ctx)
  106. if showSQL {
  107. s.db.Logger.BeforeSQL(log.LogContext{
  108. Ctx: ctx,
  109. SQL: s.query,
  110. Args: args,
  111. })
  112. }
  113. rows, err := s.Stmt.QueryContext(ctx, args...)
  114. if showSQL {
  115. s.db.Logger.AfterSQL(log.LogContext{
  116. Ctx: ctx,
  117. SQL: s.query,
  118. Args: args,
  119. ExecuteTime: time.Now().Sub(start),
  120. Err: err,
  121. })
  122. }
  123. if err != nil {
  124. return nil, err
  125. }
  126. return &Rows{rows, s.db}, nil
  127. }
  128. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  129. return s.QueryContext(context.Background(), args...)
  130. }
  131. func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
  132. vv := reflect.ValueOf(mp)
  133. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  134. return nil, errors.New("mp should be a map's pointer")
  135. }
  136. args := make([]interface{}, len(s.names))
  137. for k, i := range s.names {
  138. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  139. }
  140. return s.QueryContext(ctx, args...)
  141. }
  142. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  143. return s.QueryMapContext(context.Background(), mp)
  144. }
  145. func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
  146. vv := reflect.ValueOf(st)
  147. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  148. return nil, errors.New("mp should be a map's pointer")
  149. }
  150. args := make([]interface{}, len(s.names))
  151. for k, i := range s.names {
  152. args[i] = vv.Elem().FieldByName(k).Interface()
  153. }
  154. return s.Query(args...)
  155. }
  156. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  157. return s.QueryStructContext(context.Background(), st)
  158. }
  159. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
  160. rows, err := s.QueryContext(ctx, args...)
  161. return &Row{rows, err}
  162. }
  163. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  164. return s.QueryRowContext(context.Background(), args...)
  165. }
  166. func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
  167. vv := reflect.ValueOf(mp)
  168. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  169. return &Row{nil, errors.New("mp should be a map's pointer")}
  170. }
  171. args := make([]interface{}, len(s.names))
  172. for k, i := range s.names {
  173. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  174. }
  175. return s.QueryRowContext(ctx, args...)
  176. }
  177. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  178. return s.QueryRowMapContext(context.Background(), mp)
  179. }
  180. func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
  181. vv := reflect.ValueOf(st)
  182. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  183. return &Row{nil, errors.New("st should be a struct's pointer")}
  184. }
  185. args := make([]interface{}, len(s.names))
  186. for k, i := range s.names {
  187. args[i] = vv.Elem().FieldByName(k).Interface()
  188. }
  189. return s.QueryRowContext(ctx, args...)
  190. }
  191. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  192. return s.QueryRowStructContext(context.Background(), st)
  193. }