You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

cache.go 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. // Copyright (c) Faye Amacker. All rights reserved.
  2. // Licensed under the MIT License. See LICENSE in the project root for license information.
  3. package cbor
  4. import (
  5. "bytes"
  6. "errors"
  7. "reflect"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. )
  // The caches below are filled lazily by the getter functions in this file.
// Each is a sync.Map keyed by reflect.Type; the stored value type is noted
// on each variable.

// decodingStructTypeCache maps reflect.Type to *decodingStructType.
var decodingStructTypeCache sync.Map

// encodingStructTypeCache maps reflect.Type to *encodingStructType.
var encodingStructTypeCache sync.Map

// encodeFuncCache maps reflect.Type to encodeFunc.
var encodeFuncCache sync.Map

// typeInfoCache maps reflect.Type to *typeInfo.
var typeInfoCache sync.Map
  // specialType classifies a pointer-stripped type that needs special-case
// handling. newTypeInfo assigns it; the zero value means no special case.
type specialType int

// specialType values. The explicit values match the original iota sequence.
const (
	specialTypeNone             specialType = 0
	specialTypeUnmarshalerIface specialType = 1
	specialTypeEmptyIface       specialType = 2
	specialTypeTag              specialType = 3
	specialTypeTime             specialType = 4
)
  27. type typeInfo struct {
  28. elemTypeInfo *typeInfo
  29. keyTypeInfo *typeInfo
  30. typ reflect.Type
  31. kind reflect.Kind
  32. nonPtrType reflect.Type
  33. nonPtrKind reflect.Kind
  34. spclType specialType
  35. }
  36. func newTypeInfo(t reflect.Type) *typeInfo {
  37. tInfo := typeInfo{typ: t, kind: t.Kind()}
  38. for t.Kind() == reflect.Ptr {
  39. t = t.Elem()
  40. }
  41. k := t.Kind()
  42. tInfo.nonPtrType = t
  43. tInfo.nonPtrKind = k
  44. if k == reflect.Interface && t.NumMethod() == 0 {
  45. tInfo.spclType = specialTypeEmptyIface
  46. } else if t == typeTag {
  47. tInfo.spclType = specialTypeTag
  48. } else if t == typeTime {
  49. tInfo.spclType = specialTypeTime
  50. } else if reflect.PtrTo(t).Implements(typeUnmarshaler) {
  51. tInfo.spclType = specialTypeUnmarshalerIface
  52. }
  53. switch k {
  54. case reflect.Array, reflect.Slice:
  55. tInfo.elemTypeInfo = getTypeInfo(t.Elem())
  56. case reflect.Map:
  57. tInfo.keyTypeInfo = getTypeInfo(t.Key())
  58. tInfo.elemTypeInfo = getTypeInfo(t.Elem())
  59. }
  60. return &tInfo
  61. }
  62. type decodingStructType struct {
  63. fields fields
  64. err error
  65. toArray bool
  66. }
  67. func getDecodingStructType(t reflect.Type) *decodingStructType {
  68. if v, _ := decodingStructTypeCache.Load(t); v != nil {
  69. return v.(*decodingStructType)
  70. }
  71. flds, structOptions := getFields(t)
  72. toArray := hasToArrayOption(structOptions)
  73. var err error
  74. for i := 0; i < len(flds); i++ {
  75. if flds[i].keyAsInt {
  76. nameAsInt, numErr := strconv.Atoi(flds[i].name)
  77. if numErr != nil {
  78. err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
  79. break
  80. }
  81. flds[i].nameAsInt = int64(nameAsInt)
  82. }
  83. flds[i].typInfo = getTypeInfo(flds[i].typ)
  84. }
  85. structType := &decodingStructType{fields: flds, err: err, toArray: toArray}
  86. decodingStructTypeCache.Store(t, structType)
  87. return structType
  88. }
  89. type encodingStructType struct {
  90. fields fields
  91. bytewiseFields fields
  92. lengthFirstFields fields
  93. err error
  94. toArray bool
  95. omitEmpty bool
  96. hasAnonymousField bool
  97. }
  98. func (st *encodingStructType) getFields(em *encMode) fields {
  99. if em.sort == SortNone {
  100. return st.fields
  101. }
  102. if em.sort == SortLengthFirst {
  103. return st.lengthFirstFields
  104. }
  105. return st.bytewiseFields
  106. }
  107. type bytewiseFieldSorter struct {
  108. fields fields
  109. }
  110. func (x *bytewiseFieldSorter) Len() int {
  111. return len(x.fields)
  112. }
  113. func (x *bytewiseFieldSorter) Swap(i, j int) {
  114. x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
  115. }
  116. func (x *bytewiseFieldSorter) Less(i, j int) bool {
  117. return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
  118. }
  119. type lengthFirstFieldSorter struct {
  120. fields fields
  121. }
  122. func (x *lengthFirstFieldSorter) Len() int {
  123. return len(x.fields)
  124. }
  125. func (x *lengthFirstFieldSorter) Swap(i, j int) {
  126. x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
  127. }
  128. func (x *lengthFirstFieldSorter) Less(i, j int) bool {
  129. if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
  130. return len(x.fields[i].cborName) < len(x.fields[j].cborName)
  131. }
  132. return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
  133. }
  134. func getEncodingStructType(t reflect.Type) *encodingStructType {
  135. if v, _ := encodingStructTypeCache.Load(t); v != nil {
  136. return v.(*encodingStructType)
  137. }
  138. flds, structOptions := getFields(t)
  139. if hasToArrayOption(structOptions) {
  140. return getEncodingStructToArrayType(t, flds)
  141. }
  142. var err error
  143. var omitEmpty bool
  144. var hasAnonymousField bool
  145. var hasKeyAsInt bool
  146. var hasKeyAsStr bool
  147. e := getEncodeState()
  148. for i := 0; i < len(flds); i++ {
  149. // Get field's encodeFunc
  150. flds[i].ef = getEncodeFunc(flds[i].typ)
  151. if flds[i].ef == nil {
  152. err = &UnsupportedTypeError{t}
  153. break
  154. }
  155. // Encode field name
  156. if flds[i].keyAsInt {
  157. nameAsInt, numErr := strconv.Atoi(flds[i].name)
  158. if numErr != nil {
  159. err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
  160. break
  161. }
  162. flds[i].nameAsInt = int64(nameAsInt)
  163. if nameAsInt >= 0 {
  164. encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
  165. } else {
  166. n := nameAsInt*(-1) - 1
  167. encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
  168. }
  169. flds[i].cborName = make([]byte, e.Len())
  170. copy(flds[i].cborName, e.Bytes())
  171. e.Reset()
  172. hasKeyAsInt = true
  173. } else {
  174. encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
  175. flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
  176. n := copy(flds[i].cborName, e.Bytes())
  177. copy(flds[i].cborName[n:], flds[i].name)
  178. e.Reset()
  179. hasKeyAsStr = true
  180. }
  181. // Check if field is from embedded struct
  182. if len(flds[i].idx) > 1 {
  183. hasAnonymousField = true
  184. }
  185. // Check if field can be omitted when empty
  186. if flds[i].omitEmpty {
  187. omitEmpty = true
  188. }
  189. }
  190. putEncodeState(e)
  191. if err != nil {
  192. structType := &encodingStructType{err: err}
  193. encodingStructTypeCache.Store(t, structType)
  194. return structType
  195. }
  196. // Sort fields by canonical order
  197. bytewiseFields := make(fields, len(flds))
  198. copy(bytewiseFields, flds)
  199. sort.Sort(&bytewiseFieldSorter{bytewiseFields})
  200. lengthFirstFields := bytewiseFields
  201. if hasKeyAsInt && hasKeyAsStr {
  202. lengthFirstFields = make(fields, len(flds))
  203. copy(lengthFirstFields, flds)
  204. sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
  205. }
  206. structType := &encodingStructType{
  207. fields: flds,
  208. bytewiseFields: bytewiseFields,
  209. lengthFirstFields: lengthFirstFields,
  210. omitEmpty: omitEmpty,
  211. hasAnonymousField: hasAnonymousField,
  212. }
  213. encodingStructTypeCache.Store(t, structType)
  214. return structType
  215. }
  216. func getEncodingStructToArrayType(t reflect.Type, flds fields) *encodingStructType {
  217. var hasAnonymousField bool
  218. for i := 0; i < len(flds); i++ {
  219. // Get field's encodeFunc
  220. flds[i].ef = getEncodeFunc(flds[i].typ)
  221. if flds[i].ef == nil {
  222. structType := &encodingStructType{err: &UnsupportedTypeError{t}}
  223. encodingStructTypeCache.Store(t, structType)
  224. return structType
  225. }
  226. // Check if field is from embedded struct
  227. if len(flds[i].idx) > 1 {
  228. hasAnonymousField = true
  229. }
  230. }
  231. structType := &encodingStructType{
  232. fields: flds,
  233. toArray: true,
  234. hasAnonymousField: hasAnonymousField,
  235. }
  236. encodingStructTypeCache.Store(t, structType)
  237. return structType
  238. }
  239. func getEncodeFunc(t reflect.Type) encodeFunc {
  240. if v, _ := encodeFuncCache.Load(t); v != nil {
  241. return v.(encodeFunc)
  242. }
  243. f := getEncodeFuncInternal(t)
  244. encodeFuncCache.Store(t, f)
  245. return f
  246. }
  247. func getTypeInfo(t reflect.Type) *typeInfo {
  248. if v, _ := typeInfoCache.Load(t); v != nil {
  249. return v.(*typeInfo)
  250. }
  251. tInfo := newTypeInfo(t)
  252. typeInfoCache.Store(t, tInfo)
  253. return tInfo
  254. }
  // hasToArrayOption reports whether the struct tag string tag lists "toarray"
// as one of its comma-separated options. The leading segment (before the
// first comma) is never treated as an option.
//
// Every option is checked, not just the first substring match. For example,
// ",toarrayx,toarray" is correctly recognized.
func hasToArrayOption(tag string) bool {
	const opt = "toarray"

	i := strings.IndexByte(tag, ',')
	if i < 0 {
		return false
	}
	rest := tag[i+1:]
	for {
		seg := rest
		j := strings.IndexByte(rest, ',')
		if j >= 0 {
			seg = rest[:j]
		}
		if seg == opt {
			return true
		}
		if j < 0 {
			return false
		}
		rest = rest[j+1:]
	}
}