You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

session_get.go 9.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql"
  7. "database/sql/driver"
  8. "errors"
  9. "fmt"
  10. "math/big"
  11. "reflect"
  12. "strconv"
  13. "time"
  14. "xorm.io/xorm/caches"
  15. "xorm.io/xorm/convert"
  16. "xorm.io/xorm/core"
  17. "xorm.io/xorm/internal/utils"
  18. "xorm.io/xorm/schemas"
  19. )
  // ErrObjectIsNil is the error Get returns when the bean it is given is a nil
  // pointer.
  var ErrObjectIsNil = errors.New("object should not be nil")
  24. // Get retrieve one record from database, bean's non-empty fields
  25. // will be as conditions
  26. func (session *Session) Get(bean interface{}) (bool, error) {
  27. if session.isAutoClose {
  28. defer session.Close()
  29. }
  30. return session.get(bean)
  31. }
  32. func isPtrOfTime(v interface{}) bool {
  33. if _, ok := v.(*time.Time); ok {
  34. return true
  35. }
  36. el := reflect.ValueOf(v).Elem()
  37. if el.Kind() != reflect.Struct {
  38. return false
  39. }
  40. return el.Type().ConvertibleTo(schemas.TimeType)
  41. }
  42. func (session *Session) get(bean interface{}) (bool, error) {
  43. defer session.resetStatement()
  44. if session.statement.LastError != nil {
  45. return false, session.statement.LastError
  46. }
  47. beanValue := reflect.ValueOf(bean)
  48. if beanValue.Kind() != reflect.Ptr {
  49. return false, errors.New("needs a pointer to a value")
  50. } else if beanValue.Elem().Kind() == reflect.Ptr {
  51. return false, errors.New("a pointer to a pointer is not allowed")
  52. } else if beanValue.IsNil() {
  53. return false, ErrObjectIsNil
  54. }
  55. if beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(bean) {
  56. if err := session.statement.SetRefBean(bean); err != nil {
  57. return false, err
  58. }
  59. }
  60. var sqlStr string
  61. var args []interface{}
  62. var err error
  63. if session.statement.RawSQL == "" {
  64. if len(session.statement.TableName()) <= 0 {
  65. return false, ErrTableNotFound
  66. }
  67. session.statement.Limit(1)
  68. sqlStr, args, err = session.statement.GenGetSQL(bean)
  69. if err != nil {
  70. return false, err
  71. }
  72. } else {
  73. sqlStr = session.statement.GenRawSQL()
  74. args = session.statement.RawParams
  75. }
  76. table := session.statement.RefTable
  77. if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
  78. if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
  79. !session.statement.GetUnscoped() {
  80. has, err := session.cacheGet(bean, sqlStr, args...)
  81. if err != ErrCacheFailed {
  82. return has, err
  83. }
  84. }
  85. }
  86. context := session.statement.Context
  87. if context != nil {
  88. res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
  89. if res != nil {
  90. session.engine.logger.Debugf("hit context cache: %s", sqlStr)
  91. structValue := reflect.Indirect(reflect.ValueOf(bean))
  92. structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
  93. session.lastSQL = ""
  94. session.lastSQLArgs = nil
  95. return true, nil
  96. }
  97. }
  98. has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
  99. if err != nil || !has {
  100. return has, err
  101. }
  102. if context != nil {
  103. context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean)
  104. }
  105. return true, nil
  106. }
  107. var (
  108. valuerTypePlaceHolder driver.Valuer
  109. valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem()
  110. scannerTypePlaceHolder sql.Scanner
  111. scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem()
  112. conversionTypePlaceHolder convert.Conversion
  113. conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
  114. )
  115. func isScannableStruct(bean interface{}, typeLen int) bool {
  116. switch bean.(type) {
  117. case *time.Time:
  118. return false
  119. case sql.Scanner:
  120. return false
  121. case convert.Conversion:
  122. return typeLen > 1
  123. case *big.Float:
  124. return false
  125. }
  126. return true
  127. }
  128. func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
  129. rows, err := session.queryRows(sqlStr, args...)
  130. if err != nil {
  131. return false, err
  132. }
  133. defer rows.Close()
  134. if !rows.Next() {
  135. return false, rows.Err()
  136. }
  137. // WARN: Alougth rows return true, but we may also return error.
  138. types, err := rows.ColumnTypes()
  139. if err != nil {
  140. return true, err
  141. }
  142. fields, err := rows.Columns()
  143. if err != nil {
  144. return true, err
  145. }
  146. switch beanKind {
  147. case reflect.Struct:
  148. if !isScannableStruct(bean, len(types)) {
  149. break
  150. }
  151. return session.getStruct(rows, types, fields, table, bean)
  152. case reflect.Slice:
  153. return session.getSlice(rows, types, fields, bean)
  154. case reflect.Map:
  155. return session.getMap(rows, types, fields, bean)
  156. }
  157. return session.getVars(rows, types, fields, bean)
  158. }
  159. func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
  160. switch t := bean.(type) {
  161. case *[]string:
  162. res, err := session.engine.scanStringInterface(rows, fields, types)
  163. if err != nil {
  164. return true, err
  165. }
  166. var needAppend = len(*t) == 0 // both support slice is empty or has been initlized
  167. for i, r := range res {
  168. if needAppend {
  169. *t = append(*t, r.(*sql.NullString).String)
  170. } else {
  171. (*t)[i] = r.(*sql.NullString).String
  172. }
  173. }
  174. return true, nil
  175. case *[]interface{}:
  176. scanResults, err := session.engine.scanInterfaces(rows, fields, types)
  177. if err != nil {
  178. return true, err
  179. }
  180. var needAppend = len(*t) == 0
  181. for ii := range fields {
  182. s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
  183. if err != nil {
  184. return true, err
  185. }
  186. if needAppend {
  187. *t = append(*t, s)
  188. } else {
  189. (*t)[ii] = s
  190. }
  191. }
  192. return true, nil
  193. default:
  194. return true, fmt.Errorf("unspoorted slice type: %t", t)
  195. }
  196. }
  197. func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
  198. switch t := bean.(type) {
  199. case *map[string]string:
  200. scanResults, err := session.engine.scanStringInterface(rows, fields, types)
  201. if err != nil {
  202. return true, err
  203. }
  204. for ii, key := range fields {
  205. (*t)[key] = scanResults[ii].(*sql.NullString).String
  206. }
  207. return true, nil
  208. case *map[string]interface{}:
  209. scanResults, err := session.engine.scanInterfaces(rows, fields, types)
  210. if err != nil {
  211. return true, err
  212. }
  213. for ii, key := range fields {
  214. s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
  215. if err != nil {
  216. return true, err
  217. }
  218. (*t)[key] = s
  219. }
  220. return true, nil
  221. default:
  222. return true, fmt.Errorf("unspoorted map type: %t", t)
  223. }
  224. }
  225. func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) {
  226. if len(beans) != len(types) {
  227. return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
  228. }
  229. err := session.engine.scan(rows, fields, types, beans...)
  230. return true, err
  231. }
  232. func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
  233. scanResults, err := session.row2Slice(rows, fields, types, bean)
  234. if err != nil {
  235. return false, err
  236. }
  237. // close it before convert data
  238. rows.Close()
  239. dataStruct := utils.ReflectValue(bean)
  240. _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
  241. if err != nil {
  242. return true, err
  243. }
  244. return true, session.executeProcessors()
  245. }
  246. func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
  247. // if has no reftable, then don't use cache currently
  248. if !session.canCache() {
  249. return false, ErrCacheFailed
  250. }
  251. for _, filter := range session.engine.dialect.Filters() {
  252. sqlStr = filter.Do(sqlStr)
  253. }
  254. newsql := session.statement.ConvertIDSQL(sqlStr)
  255. if newsql == "" {
  256. return false, ErrCacheFailed
  257. }
  258. tableName := session.statement.TableName()
  259. cacher := session.engine.cacherMgr.GetCacher(tableName)
  260. session.engine.logger.Debugf("[cache] Get SQL: %s, %v", newsql, args)
  261. table := session.statement.RefTable
  262. ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
  263. if err != nil {
  264. var res = make([]string, len(table.PrimaryKeys))
  265. rows, err := session.NoCache().queryRows(newsql, args...)
  266. if err != nil {
  267. return false, err
  268. }
  269. defer rows.Close()
  270. if rows.Next() {
  271. err = rows.ScanSlice(&res)
  272. if err != nil {
  273. return true, err
  274. }
  275. } else {
  276. if rows.Err() != nil {
  277. return false, rows.Err()
  278. }
  279. return false, ErrCacheFailed
  280. }
  281. var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
  282. for i, col := range table.PKColumns() {
  283. if col.SQLType.IsText() {
  284. pk[i] = res[i]
  285. } else if col.SQLType.IsNumeric() {
  286. n, err := strconv.ParseInt(res[i], 10, 64)
  287. if err != nil {
  288. return false, err
  289. }
  290. pk[i] = n
  291. } else {
  292. return false, errors.New("unsupported")
  293. }
  294. }
  295. ids = []schemas.PK{pk}
  296. session.engine.logger.Debugf("[cache] cache ids: %s, %v", newsql, ids)
  297. err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
  298. if err != nil {
  299. return false, err
  300. }
  301. } else {
  302. session.engine.logger.Debugf("[cache] cache hit: %s, %v", newsql, ids)
  303. }
  304. if len(ids) > 0 {
  305. structValue := reflect.Indirect(reflect.ValueOf(bean))
  306. id := ids[0]
  307. session.engine.logger.Debugf("[cache] get bean: %s, %v", tableName, id)
  308. sid, err := id.ToString()
  309. if err != nil {
  310. return false, err
  311. }
  312. cacheBean := cacher.GetBean(tableName, sid)
  313. if cacheBean == nil {
  314. cacheBean = bean
  315. has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...)
  316. if err != nil || !has {
  317. return has, err
  318. }
  319. session.engine.logger.Debugf("[cache] cache bean: %s, %v, %v", tableName, id, cacheBean)
  320. cacher.PutBean(tableName, sid, cacheBean)
  321. } else {
  322. session.engine.logger.Debugf("[cache] cache hit: %s, %v, %v", tableName, id, cacheBean)
  323. has = true
  324. }
  325. structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))
  326. return has, nil
  327. }
  328. return false, nil
  329. }