aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/google.golang.org/protobuf/internal/impl/validate.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/protobuf/internal/impl/validate.go')
-rw-r--r--vendor/google.golang.org/protobuf/internal/impl/validate.go575
1 files changed, 575 insertions, 0 deletions
diff --git a/vendor/google.golang.org/protobuf/internal/impl/validate.go b/vendor/google.golang.org/protobuf/internal/impl/validate.go
new file mode 100644
index 0000000000..39d62fd068
--- /dev/null
+++ b/vendor/google.golang.org/protobuf/internal/impl/validate.go
@@ -0,0 +1,575 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "fmt"
+ "math"
+ "math/bits"
+ "reflect"
+ "unicode/utf8"
+
+ "google.golang.org/protobuf/encoding/protowire"
+ "google.golang.org/protobuf/internal/encoding/messageset"
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/internal/strs"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+ preg "google.golang.org/protobuf/reflect/protoregistry"
+ piface "google.golang.org/protobuf/runtime/protoiface"
+)
+
+// ValidationStatus is the result of validating the wire-format encoding of a message.
+type ValidationStatus int
+
+const (
+ // ValidationUnknown indicates that unmarshaling the message might succeed or fail.
+ // The validator was unable to render a judgement.
+ //
+ // The only causes of this status are an aberrant message type appearing somewhere
+ // in the message or a failure in the extension resolver.
+ ValidationUnknown ValidationStatus = iota + 1
+
+ // ValidationInvalid indicates that unmarshaling the message will fail.
+ ValidationInvalid
+
+ // ValidationValid indicates that unmarshaling the message will succeed.
+ ValidationValid
+)
+
+func (v ValidationStatus) String() string {
+ switch v {
+ case ValidationUnknown:
+ return "ValidationUnknown"
+ case ValidationInvalid:
+ return "ValidationInvalid"
+ case ValidationValid:
+ return "ValidationValid"
+ default:
+ return fmt.Sprintf("ValidationStatus(%d)", int(v))
+ }
+}
+
+// Validate determines whether the contents of the buffer are a valid wire encoding
+// of the message type.
+//
+// This function is exposed for testing.
+func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
+ mi, ok := mt.(*MessageInfo)
+ if !ok {
+ return out, ValidationUnknown
+ }
+ if in.Resolver == nil {
+ in.Resolver = preg.GlobalTypes
+ }
+ o, st := mi.validate(in.Buf, 0, unmarshalOptions{
+ flags: in.Flags,
+ resolver: in.Resolver,
+ })
+ if o.initialized {
+ out.Flags |= piface.UnmarshalInitialized
+ }
+ return out, st
+}
+
+type validationInfo struct {
+ mi *MessageInfo
+ typ validationType
+ keyType, valType validationType
+
+ // For non-required fields, requiredBit is 0.
+ //
+ // For required fields, requiredBit's nth bit is set, where n is a
+ // unique index in the range [0, MessageInfo.numRequiredFields).
+ //
+ // If there are more than 64 required fields, requiredBit is 0.
+ requiredBit uint64
+}
+
+type validationType uint8
+
+const (
+ validationTypeOther validationType = iota
+ validationTypeMessage
+ validationTypeGroup
+ validationTypeMap
+ validationTypeRepeatedVarint
+ validationTypeRepeatedFixed32
+ validationTypeRepeatedFixed64
+ validationTypeVarint
+ validationTypeFixed32
+ validationTypeFixed64
+ validationTypeBytes
+ validationTypeUTF8String
+ validationTypeMessageSetItem
+)
+
+func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
+ var vi validationInfo
+ switch {
+ case fd.ContainingOneof() != nil:
+ switch fd.Kind() {
+ case pref.MessageKind:
+ vi.typ = validationTypeMessage
+ if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
+ vi.mi = getMessageInfo(ot.Field(0).Type)
+ }
+ case pref.GroupKind:
+ vi.typ = validationTypeGroup
+ if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
+ vi.mi = getMessageInfo(ot.Field(0).Type)
+ }
+ case pref.StringKind:
+ if strs.EnforceUTF8(fd) {
+ vi.typ = validationTypeUTF8String
+ }
+ }
+ default:
+ vi = newValidationInfo(fd, ft)
+ }
+ if fd.Cardinality() == pref.Required {
+ // Avoid overflow. The required field check is done with a 64-bit mask, with
+ // any message containing more than 64 required fields always reported as
+ // potentially uninitialized, so it is not important to get a precise count
+ // of the required fields past 64.
+ if mi.numRequiredFields < math.MaxUint8 {
+ mi.numRequiredFields++
+ vi.requiredBit = 1 << (mi.numRequiredFields - 1)
+ }
+ }
+ return vi
+}
+
+func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
+ var vi validationInfo
+ switch {
+ case fd.IsList():
+ switch fd.Kind() {
+ case pref.MessageKind:
+ vi.typ = validationTypeMessage
+ if ft.Kind() == reflect.Slice {
+ vi.mi = getMessageInfo(ft.Elem())
+ }
+ case pref.GroupKind:
+ vi.typ = validationTypeGroup
+ if ft.Kind() == reflect.Slice {
+ vi.mi = getMessageInfo(ft.Elem())
+ }
+ case pref.StringKind:
+ vi.typ = validationTypeBytes
+ if strs.EnforceUTF8(fd) {
+ vi.typ = validationTypeUTF8String
+ }
+ default:
+ switch wireTypes[fd.Kind()] {
+ case protowire.VarintType:
+ vi.typ = validationTypeRepeatedVarint
+ case protowire.Fixed32Type:
+ vi.typ = validationTypeRepeatedFixed32
+ case protowire.Fixed64Type:
+ vi.typ = validationTypeRepeatedFixed64
+ }
+ }
+ case fd.IsMap():
+ vi.typ = validationTypeMap
+ switch fd.MapKey().Kind() {
+ case pref.StringKind:
+ if strs.EnforceUTF8(fd) {
+ vi.keyType = validationTypeUTF8String
+ }
+ }
+ switch fd.MapValue().Kind() {
+ case pref.MessageKind:
+ vi.valType = validationTypeMessage
+ if ft.Kind() == reflect.Map {
+ vi.mi = getMessageInfo(ft.Elem())
+ }
+ case pref.StringKind:
+ if strs.EnforceUTF8(fd) {
+ vi.valType = validationTypeUTF8String
+ }
+ }
+ default:
+ switch fd.Kind() {
+ case pref.MessageKind:
+ vi.typ = validationTypeMessage
+ if !fd.IsWeak() {
+ vi.mi = getMessageInfo(ft)
+ }
+ case pref.GroupKind:
+ vi.typ = validationTypeGroup
+ vi.mi = getMessageInfo(ft)
+ case pref.StringKind:
+ vi.typ = validationTypeBytes
+ if strs.EnforceUTF8(fd) {
+ vi.typ = validationTypeUTF8String
+ }
+ default:
+ switch wireTypes[fd.Kind()] {
+ case protowire.VarintType:
+ vi.typ = validationTypeVarint
+ case protowire.Fixed32Type:
+ vi.typ = validationTypeFixed32
+ case protowire.Fixed64Type:
+ vi.typ = validationTypeFixed64
+ case protowire.BytesType:
+ vi.typ = validationTypeBytes
+ }
+ }
+ }
+ return vi
+}
+
+func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
+ mi.init()
+ type validationState struct {
+ typ validationType
+ keyType, valType validationType
+ endGroup protowire.Number
+ mi *MessageInfo
+ tail []byte
+ requiredMask uint64
+ }
+
+ // Pre-allocate some slots to avoid repeated slice reallocation.
+ states := make([]validationState, 0, 16)
+ states = append(states, validationState{
+ typ: validationTypeMessage,
+ mi: mi,
+ })
+ if groupTag > 0 {
+ states[0].typ = validationTypeGroup
+ states[0].endGroup = groupTag
+ }
+ initialized := true
+ start := len(b)
+State:
+ for len(states) > 0 {
+ st := &states[len(states)-1]
+ for len(b) > 0 {
+ // Parse the tag (field number and wire type).
+ var tag uint64
+ if b[0] < 0x80 {
+ tag = uint64(b[0])
+ b = b[1:]
+ } else if len(b) >= 2 && b[1] < 128 {
+ tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
+ b = b[2:]
+ } else {
+ var n int
+ tag, n = protowire.ConsumeVarint(b)
+ if n < 0 {
+ return out, ValidationInvalid
+ }
+ b = b[n:]
+ }
+ var num protowire.Number
+ if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
+ return out, ValidationInvalid
+ } else {
+ num = protowire.Number(n)
+ }
+ wtyp := protowire.Type(tag & 7)
+
+ if wtyp == protowire.EndGroupType {
+ if st.endGroup == num {
+ goto PopState
+ }
+ return out, ValidationInvalid
+ }
+ var vi validationInfo
+ switch {
+ case st.typ == validationTypeMap:
+ switch num {
+ case 1:
+ vi.typ = st.keyType
+ case 2:
+ vi.typ = st.valType
+ vi.mi = st.mi
+ vi.requiredBit = 1
+ }
+ case flags.ProtoLegacy && st.mi.isMessageSet:
+ switch num {
+ case messageset.FieldItem:
+ vi.typ = validationTypeMessageSetItem
+ }
+ default:
+ var f *coderFieldInfo
+ if int(num) < len(st.mi.denseCoderFields) {
+ f = st.mi.denseCoderFields[num]
+ } else {
+ f = st.mi.coderFields[num]
+ }
+ if f != nil {
+ vi = f.validation
+ if vi.typ == validationTypeMessage && vi.mi == nil {
+ // Probable weak field.
+ //
+ // TODO: Consider storing the results of this lookup somewhere
+ // rather than recomputing it on every validation.
+ fd := st.mi.Desc.Fields().ByNumber(num)
+ if fd == nil || !fd.IsWeak() {
+ break
+ }
+ messageName := fd.Message().FullName()
+ messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
+ switch err {
+ case nil:
+ vi.mi, _ = messageType.(*MessageInfo)
+ case preg.NotFound:
+ vi.typ = validationTypeBytes
+ default:
+ return out, ValidationUnknown
+ }
+ }
+ break
+ }
+ // Possible extension field.
+ //
+ // TODO: We should return ValidationUnknown when:
+ // 1. The resolver is not frozen. (More extensions may be added to it.)
+ // 2. The resolver returns preg.NotFound.
+ // In this case, a type added to the resolver in the future could cause
+ // unmarshaling to begin failing. Supporting this requires some way to
+ // determine if the resolver is frozen.
+ xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
+ if err != nil && err != preg.NotFound {
+ return out, ValidationUnknown
+ }
+ if err == nil {
+ vi = getExtensionFieldInfo(xt).validation
+ }
+ }
+ if vi.requiredBit != 0 {
+ // Check that the field has a compatible wire type.
+ // We only need to consider non-repeated field types,
+ // since repeated fields (and maps) can never be required.
+ ok := false
+ switch vi.typ {
+ case validationTypeVarint:
+ ok = wtyp == protowire.VarintType
+ case validationTypeFixed32:
+ ok = wtyp == protowire.Fixed32Type
+ case validationTypeFixed64:
+ ok = wtyp == protowire.Fixed64Type
+ case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
+ ok = wtyp == protowire.BytesType
+ case validationTypeGroup:
+ ok = wtyp == protowire.StartGroupType
+ }
+ if ok {
+ st.requiredMask |= vi.requiredBit
+ }
+ }
+
+ switch wtyp {
+ case protowire.VarintType:
+ if len(b) >= 10 {
+ switch {
+ case b[0] < 0x80:
+ b = b[1:]
+ case b[1] < 0x80:
+ b = b[2:]
+ case b[2] < 0x80:
+ b = b[3:]
+ case b[3] < 0x80:
+ b = b[4:]
+ case b[4] < 0x80:
+ b = b[5:]
+ case b[5] < 0x80:
+ b = b[6:]
+ case b[6] < 0x80:
+ b = b[7:]
+ case b[7] < 0x80:
+ b = b[8:]
+ case b[8] < 0x80:
+ b = b[9:]
+ case b[9] < 0x80 && b[9] < 2:
+ b = b[10:]
+ default:
+ return out, ValidationInvalid
+ }
+ } else {
+ switch {
+ case len(b) > 0 && b[0] < 0x80:
+ b = b[1:]
+ case len(b) > 1 && b[1] < 0x80:
+ b = b[2:]
+ case len(b) > 2 && b[2] < 0x80:
+ b = b[3:]
+ case len(b) > 3 && b[3] < 0x80:
+ b = b[4:]
+ case len(b) > 4 && b[4] < 0x80:
+ b = b[5:]
+ case len(b) > 5 && b[5] < 0x80:
+ b = b[6:]
+ case len(b) > 6 && b[6] < 0x80:
+ b = b[7:]
+ case len(b) > 7 && b[7] < 0x80:
+ b = b[8:]
+ case len(b) > 8 && b[8] < 0x80:
+ b = b[9:]
+ case len(b) > 9 && b[9] < 2:
+ b = b[10:]
+ default:
+ return out, ValidationInvalid
+ }
+ }
+ continue State
+ case protowire.BytesType:
+ var size uint64
+ if len(b) >= 1 && b[0] < 0x80 {
+ size = uint64(b[0])
+ b = b[1:]
+ } else if len(b) >= 2 && b[1] < 128 {
+ size = uint64(b[0]&0x7f) + uint64(b[1])<<7
+ b = b[2:]
+ } else {
+ var n int
+ size, n = protowire.ConsumeVarint(b)
+ if n < 0 {
+ return out, ValidationInvalid
+ }
+ b = b[n:]
+ }
+ if size > uint64(len(b)) {
+ return out, ValidationInvalid
+ }
+ v := b[:size]
+ b = b[size:]
+ switch vi.typ {
+ case validationTypeMessage:
+ if vi.mi == nil {
+ return out, ValidationUnknown
+ }
+ vi.mi.init()
+ fallthrough
+ case validationTypeMap:
+ if vi.mi != nil {
+ vi.mi.init()
+ }
+ states = append(states, validationState{
+ typ: vi.typ,
+ keyType: vi.keyType,
+ valType: vi.valType,
+ mi: vi.mi,
+ tail: b,
+ })
+ b = v
+ continue State
+ case validationTypeRepeatedVarint:
+ // Packed field.
+ for len(v) > 0 {
+ _, n := protowire.ConsumeVarint(v)
+ if n < 0 {
+ return out, ValidationInvalid
+ }
+ v = v[n:]
+ }
+ case validationTypeRepeatedFixed32:
+ // Packed field.
+ if len(v)%4 != 0 {
+ return out, ValidationInvalid
+ }
+ case validationTypeRepeatedFixed64:
+ // Packed field.
+ if len(v)%8 != 0 {
+ return out, ValidationInvalid
+ }
+ case validationTypeUTF8String:
+ if !utf8.Valid(v) {
+ return out, ValidationInvalid
+ }
+ }
+ case protowire.Fixed32Type:
+ if len(b) < 4 {
+ return out, ValidationInvalid
+ }
+ b = b[4:]
+ case protowire.Fixed64Type:
+ if len(b) < 8 {
+ return out, ValidationInvalid
+ }
+ b = b[8:]
+ case protowire.StartGroupType:
+ switch {
+ case vi.typ == validationTypeGroup:
+ if vi.mi == nil {
+ return out, ValidationUnknown
+ }
+ vi.mi.init()
+ states = append(states, validationState{
+ typ: validationTypeGroup,
+ mi: vi.mi,
+ endGroup: num,
+ })
+ continue State
+ case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
+ typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
+ if err != nil {
+ return out, ValidationInvalid
+ }
+ xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+ switch {
+ case err == preg.NotFound:
+ b = b[n:]
+ case err != nil:
+ return out, ValidationUnknown
+ default:
+ xvi := getExtensionFieldInfo(xt).validation
+ if xvi.mi != nil {
+ xvi.mi.init()
+ }
+ states = append(states, validationState{
+ typ: xvi.typ,
+ mi: xvi.mi,
+ tail: b[n:],
+ })
+ b = v
+ continue State
+ }
+ default:
+ n := protowire.ConsumeFieldValue(num, wtyp, b)
+ if n < 0 {
+ return out, ValidationInvalid
+ }
+ b = b[n:]
+ }
+ default:
+ return out, ValidationInvalid
+ }
+ }
+ if st.endGroup != 0 {
+ return out, ValidationInvalid
+ }
+ if len(b) != 0 {
+ return out, ValidationInvalid
+ }
+ b = st.tail
+ PopState:
+ numRequiredFields := 0
+ switch st.typ {
+ case validationTypeMessage, validationTypeGroup:
+ numRequiredFields = int(st.mi.numRequiredFields)
+ case validationTypeMap:
+ // If this is a map field with a message value that contains
+ // required fields, require that the value be present.
+ if st.mi != nil && st.mi.numRequiredFields > 0 {
+ numRequiredFields = 1
+ }
+ }
+ // If there are more than 64 required fields, this check will
+ // always fail and we will report that the message is potentially
+ // uninitialized.
+ if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
+ initialized = false
+ }
+ states = states[:len(states)-1]
+ }
+ out.n = start - len(b)
+ if initialized {
+ out.initialized = true
+ }
+ return out, ValidationValid
+}