You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

struct_codec.go 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. // Copyright (C) MongoDB, Inc. 2017-present.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"); you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  6. package bsoncodec
  7. import (
  8. "errors"
  9. "fmt"
  10. "reflect"
  11. "sync"
  12. "go.mongodb.org/mongo-driver/bson/bsonrw"
  13. "go.mongodb.org/mongo-driver/bson/bsontype"
  14. )
  15. var defaultStructCodec = &StructCodec{
  16. cache: make(map[reflect.Type]*structDescription),
  17. parser: DefaultStructTagParser,
  18. }
  // Zeroer is implemented by types that can report their own zero state.
// For fields marked omitEmpty, the struct codec's default emptiness check
// defers to IsZero. Struct types that do not implement Zeroer, or whose
// IsZero returns false, are treated as non-zero.
type Zeroer interface {
	// IsZero reports whether the receiver should be considered zero.
	IsZero() bool
}
  25. // StructCodec is the Codec used for struct values.
  26. type StructCodec struct {
  27. cache map[reflect.Type]*structDescription
  28. l sync.RWMutex
  29. parser StructTagParser
  30. }
  31. var _ ValueEncoder = &StructCodec{}
  32. var _ ValueDecoder = &StructCodec{}
  33. // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
  34. func NewStructCodec(p StructTagParser) (*StructCodec, error) {
  35. if p == nil {
  36. return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
  37. }
  38. return &StructCodec{
  39. cache: make(map[reflect.Type]*structDescription),
  40. parser: p,
  41. }, nil
  42. }
  43. // EncodeValue handles encoding generic struct types.
  44. func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
  45. if !val.IsValid() || val.Kind() != reflect.Struct {
  46. return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
  47. }
  48. sd, err := sc.describeStruct(r.Registry, val.Type())
  49. if err != nil {
  50. return err
  51. }
  52. dw, err := vw.WriteDocument()
  53. if err != nil {
  54. return err
  55. }
  56. var rv reflect.Value
  57. for _, desc := range sd.fl {
  58. if desc.inline == nil {
  59. rv = val.Field(desc.idx)
  60. } else {
  61. rv = val.FieldByIndex(desc.inline)
  62. }
  63. if desc.encoder == nil {
  64. return ErrNoEncoder{Type: rv.Type()}
  65. }
  66. encoder := desc.encoder
  67. iszero := sc.isZero
  68. if iz, ok := encoder.(CodecZeroer); ok {
  69. iszero = iz.IsTypeZero
  70. }
  71. if desc.omitEmpty && iszero(rv.Interface()) {
  72. continue
  73. }
  74. vw2, err := dw.WriteDocumentElement(desc.name)
  75. if err != nil {
  76. return err
  77. }
  78. ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
  79. err = encoder.EncodeValue(ectx, vw2, rv)
  80. if err != nil {
  81. return err
  82. }
  83. }
  84. if sd.inlineMap >= 0 {
  85. rv := val.Field(sd.inlineMap)
  86. collisionFn := func(key string) bool {
  87. _, exists := sd.fm[key]
  88. return exists
  89. }
  90. return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
  91. }
  92. return dw.WriteDocumentEnd()
  93. }
  94. // DecodeValue implements the Codec interface.
  95. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
  96. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
  97. func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
  98. if !val.CanSet() || val.Kind() != reflect.Struct {
  99. return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
  100. }
  101. switch vr.Type() {
  102. case bsontype.Type(0), bsontype.EmbeddedDocument:
  103. default:
  104. return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
  105. }
  106. sd, err := sc.describeStruct(r.Registry, val.Type())
  107. if err != nil {
  108. return err
  109. }
  110. var decoder ValueDecoder
  111. var inlineMap reflect.Value
  112. if sd.inlineMap >= 0 {
  113. inlineMap = val.Field(sd.inlineMap)
  114. if inlineMap.IsNil() {
  115. inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
  116. }
  117. decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
  118. if err != nil {
  119. return err
  120. }
  121. }
  122. dr, err := vr.ReadDocument()
  123. if err != nil {
  124. return err
  125. }
  126. for {
  127. name, vr, err := dr.ReadElement()
  128. if err == bsonrw.ErrEOD {
  129. break
  130. }
  131. if err != nil {
  132. return err
  133. }
  134. fd, exists := sd.fm[name]
  135. if !exists {
  136. if sd.inlineMap < 0 {
  137. // The encoding/json package requires a flag to return on error for non-existent fields.
  138. // This functionality seems appropriate for the struct codec.
  139. err = vr.Skip()
  140. if err != nil {
  141. return err
  142. }
  143. continue
  144. }
  145. elem := reflect.New(inlineMap.Type().Elem()).Elem()
  146. err = decoder.DecodeValue(r, vr, elem)
  147. if err != nil {
  148. return err
  149. }
  150. inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
  151. continue
  152. }
  153. var field reflect.Value
  154. if fd.inline == nil {
  155. field = val.Field(fd.idx)
  156. } else {
  157. field = val.FieldByIndex(fd.inline)
  158. }
  159. if !field.CanSet() { // Being settable is a super set of being addressable.
  160. return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
  161. }
  162. if field.Kind() == reflect.Ptr && field.IsNil() {
  163. field.Set(reflect.New(field.Type().Elem()))
  164. }
  165. field = field.Addr()
  166. dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate}
  167. if fd.decoder == nil {
  168. return ErrNoDecoder{Type: field.Elem().Type()}
  169. }
  170. if decoder, ok := fd.decoder.(ValueDecoder); ok {
  171. err = decoder.DecodeValue(dctx, vr, field.Elem())
  172. if err != nil {
  173. return err
  174. }
  175. continue
  176. }
  177. err = fd.decoder.DecodeValue(dctx, vr, field)
  178. if err != nil {
  179. return err
  180. }
  181. }
  182. return nil
  183. }
  184. func (sc *StructCodec) isZero(i interface{}) bool {
  185. v := reflect.ValueOf(i)
  186. // check the value validity
  187. if !v.IsValid() {
  188. return true
  189. }
  190. if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
  191. return z.IsZero()
  192. }
  193. switch v.Kind() {
  194. case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
  195. return v.Len() == 0
  196. case reflect.Bool:
  197. return !v.Bool()
  198. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  199. return v.Int() == 0
  200. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  201. return v.Uint() == 0
  202. case reflect.Float32, reflect.Float64:
  203. return v.Float() == 0
  204. case reflect.Interface, reflect.Ptr:
  205. return v.IsNil()
  206. }
  207. return false
  208. }
  209. type structDescription struct {
  210. fm map[string]fieldDescription
  211. fl []fieldDescription
  212. inlineMap int
  213. }
  214. type fieldDescription struct {
  215. name string
  216. idx int
  217. omitEmpty bool
  218. minSize bool
  219. truncate bool
  220. inline []int
  221. encoder ValueEncoder
  222. decoder ValueDecoder
  223. }
  224. func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
  225. // We need to analyze the struct, including getting the tags, collecting
  226. // information about inlining, and create a map of the field name to the field.
  227. sc.l.RLock()
  228. ds, exists := sc.cache[t]
  229. sc.l.RUnlock()
  230. if exists {
  231. return ds, nil
  232. }
  233. numFields := t.NumField()
  234. sd := &structDescription{
  235. fm: make(map[string]fieldDescription, numFields),
  236. fl: make([]fieldDescription, 0, numFields),
  237. inlineMap: -1,
  238. }
  239. for i := 0; i < numFields; i++ {
  240. sf := t.Field(i)
  241. if sf.PkgPath != "" {
  242. // unexported, ignore
  243. continue
  244. }
  245. encoder, err := r.LookupEncoder(sf.Type)
  246. if err != nil {
  247. encoder = nil
  248. }
  249. decoder, err := r.LookupDecoder(sf.Type)
  250. if err != nil {
  251. decoder = nil
  252. }
  253. description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
  254. stags, err := sc.parser.ParseStructTags(sf)
  255. if err != nil {
  256. return nil, err
  257. }
  258. if stags.Skip {
  259. continue
  260. }
  261. description.name = stags.Name
  262. description.omitEmpty = stags.OmitEmpty
  263. description.minSize = stags.MinSize
  264. description.truncate = stags.Truncate
  265. if stags.Inline {
  266. switch sf.Type.Kind() {
  267. case reflect.Map:
  268. if sd.inlineMap >= 0 {
  269. return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
  270. }
  271. if sf.Type.Key() != tString {
  272. return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
  273. }
  274. sd.inlineMap = description.idx
  275. case reflect.Struct:
  276. inlinesf, err := sc.describeStruct(r, sf.Type)
  277. if err != nil {
  278. return nil, err
  279. }
  280. for _, fd := range inlinesf.fl {
  281. if _, exists := sd.fm[fd.name]; exists {
  282. return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
  283. }
  284. if fd.inline == nil {
  285. fd.inline = []int{i, fd.idx}
  286. } else {
  287. fd.inline = append([]int{i}, fd.inline...)
  288. }
  289. sd.fm[fd.name] = fd
  290. sd.fl = append(sd.fl, fd)
  291. }
  292. default:
  293. return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
  294. }
  295. continue
  296. }
  297. if _, exists := sd.fm[description.name]; exists {
  298. return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
  299. }
  300. sd.fm[description.name] = description
  301. sd.fl = append(sd.fl, description)
  302. }
  303. sc.l.Lock()
  304. sc.cache[t] = sd
  305. sc.l.Unlock()
  306. return sd, nil
  307. }