123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- // Copyright (c) Faye Amacker. All rights reserved.
- // Licensed under the MIT License. See LICENSE in the project root for license information.
-
- package cbor
-
- import (
- "bytes"
- "errors"
- "reflect"
- "sort"
- "strconv"
- "strings"
- "sync"
- )
-
// Per-type caches keyed by reflect.Type. Each value is derived from the type
// via reflection on a cache miss and stored so later lookups for the same type
// can reuse it.
var (
	// typeInfoCache holds *typeInfo values built by newTypeInfo.
	typeInfoCache sync.Map // map[reflect.Type]*typeInfo

	// encodeFuncCache holds encodeFunc values from getEncodeFuncInternal.
	encodeFuncCache sync.Map // map[reflect.Type]encodeFunc

	// decodingStructTypeCache holds *decodingStructType values built by
	// getDecodingStructType.
	decodingStructTypeCache sync.Map // map[reflect.Type]*decodingStructType

	// encodingStructTypeCache holds *encodingStructType values built by
	// getEncodingStructType and getEncodingStructToArrayType.
	encodingStructTypeCache sync.Map // map[reflect.Type]*encodingStructType
)
-
// specialType classifies a (pointer-stripped) type that needs handling beyond
// its reflect.Kind. newTypeInfo assigns it; presumably the decoder consults it
// to pick a decoding strategy — confirm against decode callers.
type specialType int

const (
	// specialTypeNone: none of the cases below apply.
	specialTypeNone specialType = iota
	// specialTypeUnmarshalerIface: a pointer to the type implements typeUnmarshaler.
	specialTypeUnmarshalerIface
	// specialTypeEmptyIface: the type is an interface with no methods.
	specialTypeEmptyIface
	// specialTypeTag: the type is typeTag.
	specialTypeTag
	// specialTypeTime: the type is typeTime.
	specialTypeTime
)
-
// typeInfo holds reflection-derived facts about a Go type. It is built by
// newTypeInfo and shared through typeInfoCache.
type typeInfo struct {
	// elemTypeInfo describes the element type of an array, slice, or map
	// (after stripping pointers from the outer type); nil for other kinds.
	elemTypeInfo *typeInfo
	// keyTypeInfo describes the key type of a map (after stripping pointers
	// from the outer type); nil for other kinds.
	keyTypeInfo *typeInfo
	// typ and kind describe the type exactly as given, which may be a pointer.
	typ  reflect.Type
	kind reflect.Kind
	// nonPtrType and nonPtrKind describe the type after following every level
	// of pointer indirection.
	nonPtrType reflect.Type
	nonPtrKind reflect.Kind
	// spclType records whether nonPtrType needs special handling.
	spclType specialType
}
-
- func newTypeInfo(t reflect.Type) *typeInfo {
- tInfo := typeInfo{typ: t, kind: t.Kind()}
-
- for t.Kind() == reflect.Ptr {
- t = t.Elem()
- }
-
- k := t.Kind()
-
- tInfo.nonPtrType = t
- tInfo.nonPtrKind = k
-
- if k == reflect.Interface && t.NumMethod() == 0 {
- tInfo.spclType = specialTypeEmptyIface
- } else if t == typeTag {
- tInfo.spclType = specialTypeTag
- } else if t == typeTime {
- tInfo.spclType = specialTypeTime
- } else if reflect.PtrTo(t).Implements(typeUnmarshaler) {
- tInfo.spclType = specialTypeUnmarshalerIface
- }
-
- switch k {
- case reflect.Array, reflect.Slice:
- tInfo.elemTypeInfo = getTypeInfo(t.Elem())
- case reflect.Map:
- tInfo.keyTypeInfo = getTypeInfo(t.Key())
- tInfo.elemTypeInfo = getTypeInfo(t.Elem())
- }
-
- return &tInfo
- }
-
// decodingStructType is the cached decoding metadata for a struct type, built
// by getDecodingStructType.
type decodingStructType struct {
	// fields are the struct's fields as returned by getFields. nameAsInt (for
	// keyasint fields) and typInfo are filled in up to the first error.
	fields fields
	// err is non-nil when a keyasint field name does not parse as an int.
	// NOTE(review): presumably decoders check this before using fields — confirm.
	err error
	// toArray is true when the struct options include "toarray".
	toArray bool
}
-
- func getDecodingStructType(t reflect.Type) *decodingStructType {
- if v, _ := decodingStructTypeCache.Load(t); v != nil {
- return v.(*decodingStructType)
- }
-
- flds, structOptions := getFields(t)
-
- toArray := hasToArrayOption(structOptions)
-
- var err error
- for i := 0; i < len(flds); i++ {
- if flds[i].keyAsInt {
- nameAsInt, numErr := strconv.Atoi(flds[i].name)
- if numErr != nil {
- err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
- break
- }
- flds[i].nameAsInt = int64(nameAsInt)
- }
-
- flds[i].typInfo = getTypeInfo(flds[i].typ)
- }
-
- structType := &decodingStructType{fields: flds, err: err, toArray: toArray}
- decodingStructTypeCache.Store(t, structType)
- return structType
- }
-
// encodingStructType is the cached encoding metadata for a struct type, built
// by getEncodingStructType or getEncodingStructToArrayType.
type encodingStructType struct {
	// fields are the struct's fields in the order returned by getFields, with
	// their encodeFunc filled in. For non-toarray structs, the encoded map key
	// (cborName) is filled in as well.
	fields fields
	// bytewiseFields and lengthFirstFields are copies of fields sorted by
	// encoded key for the deterministic sort modes. lengthFirstFields may be
	// the same slice as bytewiseFields. Neither is set for toarray structs.
	bytewiseFields    fields
	lengthFirstFields fields
	// err is set, with every other field left at its zero value, when a field
	// has no encodeFunc or an unparsable keyasint name.
	err error
	// toArray is true when the struct options include "toarray" (presumably
	// the struct is then encoded as a CBOR array).
	toArray bool
	// omitEmpty is true when at least one field has the omitempty option.
	// It is computed only for non-toarray structs.
	omitEmpty bool
	// hasAnonymousField is true when at least one field's index path has more
	// than one element, i.e. the field comes from an embedded struct.
	hasAnonymousField bool
}
-
- func (st *encodingStructType) getFields(em *encMode) fields {
- if em.sort == SortNone {
- return st.fields
- }
- if em.sort == SortLengthFirst {
- return st.lengthFirstFields
- }
- return st.bytewiseFields
- }
-
- type bytewiseFieldSorter struct {
- fields fields
- }
-
- func (x *bytewiseFieldSorter) Len() int {
- return len(x.fields)
- }
-
- func (x *bytewiseFieldSorter) Swap(i, j int) {
- x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
- }
-
- func (x *bytewiseFieldSorter) Less(i, j int) bool {
- return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
- }
-
- type lengthFirstFieldSorter struct {
- fields fields
- }
-
- func (x *lengthFirstFieldSorter) Len() int {
- return len(x.fields)
- }
-
- func (x *lengthFirstFieldSorter) Swap(i, j int) {
- x.fields[i], x.fields[j] = x.fields[j], x.fields[i]
- }
-
- func (x *lengthFirstFieldSorter) Less(i, j int) bool {
- if len(x.fields[i].cborName) != len(x.fields[j].cborName) {
- return len(x.fields[i].cborName) < len(x.fields[j].cborName)
- }
- return bytes.Compare(x.fields[i].cborName, x.fields[j].cborName) <= 0
- }
-
- func getEncodingStructType(t reflect.Type) *encodingStructType {
- if v, _ := encodingStructTypeCache.Load(t); v != nil {
- return v.(*encodingStructType)
- }
-
- flds, structOptions := getFields(t)
-
- if hasToArrayOption(structOptions) {
- return getEncodingStructToArrayType(t, flds)
- }
-
- var err error
- var omitEmpty bool
- var hasAnonymousField bool
- var hasKeyAsInt bool
- var hasKeyAsStr bool
- e := getEncodeState()
- for i := 0; i < len(flds); i++ {
- // Get field's encodeFunc
- flds[i].ef = getEncodeFunc(flds[i].typ)
- if flds[i].ef == nil {
- err = &UnsupportedTypeError{t}
- break
- }
-
- // Encode field name
- if flds[i].keyAsInt {
- nameAsInt, numErr := strconv.Atoi(flds[i].name)
- if numErr != nil {
- err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
- break
- }
- flds[i].nameAsInt = int64(nameAsInt)
- if nameAsInt >= 0 {
- encodeHead(e, byte(cborTypePositiveInt), uint64(nameAsInt))
- } else {
- n := nameAsInt*(-1) - 1
- encodeHead(e, byte(cborTypeNegativeInt), uint64(n))
- }
- flds[i].cborName = make([]byte, e.Len())
- copy(flds[i].cborName, e.Bytes())
- e.Reset()
-
- hasKeyAsInt = true
- } else {
- encodeHead(e, byte(cborTypeTextString), uint64(len(flds[i].name)))
- flds[i].cborName = make([]byte, e.Len()+len(flds[i].name))
- n := copy(flds[i].cborName, e.Bytes())
- copy(flds[i].cborName[n:], flds[i].name)
- e.Reset()
-
- hasKeyAsStr = true
- }
-
- // Check if field is from embedded struct
- if len(flds[i].idx) > 1 {
- hasAnonymousField = true
- }
-
- // Check if field can be omitted when empty
- if flds[i].omitEmpty {
- omitEmpty = true
- }
- }
- putEncodeState(e)
-
- if err != nil {
- structType := &encodingStructType{err: err}
- encodingStructTypeCache.Store(t, structType)
- return structType
- }
-
- // Sort fields by canonical order
- bytewiseFields := make(fields, len(flds))
- copy(bytewiseFields, flds)
- sort.Sort(&bytewiseFieldSorter{bytewiseFields})
-
- lengthFirstFields := bytewiseFields
- if hasKeyAsInt && hasKeyAsStr {
- lengthFirstFields = make(fields, len(flds))
- copy(lengthFirstFields, flds)
- sort.Sort(&lengthFirstFieldSorter{lengthFirstFields})
- }
-
- structType := &encodingStructType{
- fields: flds,
- bytewiseFields: bytewiseFields,
- lengthFirstFields: lengthFirstFields,
- omitEmpty: omitEmpty,
- hasAnonymousField: hasAnonymousField,
- }
- encodingStructTypeCache.Store(t, structType)
- return structType
- }
-
- func getEncodingStructToArrayType(t reflect.Type, flds fields) *encodingStructType {
- var hasAnonymousField bool
- for i := 0; i < len(flds); i++ {
- // Get field's encodeFunc
- flds[i].ef = getEncodeFunc(flds[i].typ)
- if flds[i].ef == nil {
- structType := &encodingStructType{err: &UnsupportedTypeError{t}}
- encodingStructTypeCache.Store(t, structType)
- return structType
- }
-
- // Check if field is from embedded struct
- if len(flds[i].idx) > 1 {
- hasAnonymousField = true
- }
- }
-
- structType := &encodingStructType{
- fields: flds,
- toArray: true,
- hasAnonymousField: hasAnonymousField,
- }
- encodingStructTypeCache.Store(t, structType)
- return structType
- }
-
- func getEncodeFunc(t reflect.Type) encodeFunc {
- if v, _ := encodeFuncCache.Load(t); v != nil {
- return v.(encodeFunc)
- }
- f := getEncodeFuncInternal(t)
- encodeFuncCache.Store(t, f)
- return f
- }
-
- func getTypeInfo(t reflect.Type) *typeInfo {
- if v, _ := typeInfoCache.Load(t); v != nil {
- return v.(*typeInfo)
- }
- tInfo := newTypeInfo(t)
- typeInfoCache.Store(t, tInfo)
- return tInfo
- }
-
// hasToArrayOption reports whether the struct tag value tag carries the
// "toarray" option.
//
// The first comma-separated element is the name and is never treated as an
// option. Every later element is compared exactly, so an option that merely
// starts with "toarray" does not match. A genuine "toarray" is still found when
// such an option precedes it (e.g. ",toarrayx,toarray").
func hasToArrayOption(tag string) bool {
	const opt = "toarray"

	comma := strings.IndexByte(tag, ',')
	if comma < 0 {
		return false
	}
	rest := tag[comma+1:]
	for {
		comma = strings.IndexByte(rest, ',')
		if comma < 0 {
			return rest == opt
		}
		if rest[:comma] == opt {
			return true
		}
		rest = rest[comma+1:]
	}
}
|