diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go | 73 |
1 files changed, 67 insertions, 6 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go index 64e5e21fbd..d3890af954 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go @@ -4,6 +4,7 @@ package mssql import ( "bytes" + "database/sql" "encoding/binary" "errors" "fmt" @@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) tvpVal := field.Interface() + if tvp.verifyStandardTypeOnNull(buf, tvpVal) { + continue + } valOf := reflect.ValueOf(tvpVal) elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { @@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { defaultValues = append(defaultValues, v.Interface()) continue } - defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface()) + defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface())) } if columnCount-len(tvpFieldIndexes) == columnCount { @@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) { } splitVal := strings.Split(tvpName, ".") if len(splitVal) > 2 { - return "", "", errors.New("wrong tvp name") + return "", "", ErrorObjectName } + const ( + openSquareBrackets = "[" + closeSquareBrackets = "]" + ) if len(splitVal) == 2 { res := make([]string, 2) for key, value := range splitVal { - tmp := strings.Replace(value, "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(value, openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) res[key] = tmp } return res[0], res[1], nil } - tmp := strings.Replace(splitVal[0], "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) return "", tmp, nil } @@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) { func getCountSQLSeparators(str string) int { return strings.Count(str, sqlSeparator) } + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) createZeroType(fieldVal interface{}) interface{} { + const ( + defaultBool = false + defaultFloat64 = float64(0) + defaultInt64 = int64(0) + defaultString = "" + ) + + switch fieldVal.(type) { + case sql.NullBool: + return defaultBool + case sql.NullFloat64: + return defaultFloat64 + case sql.NullInt64: + return defaultInt64 + case sql.NullString: + return defaultString + } + return fieldVal +} + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool { + const ( + defaultNull = uint8(0) + ) + + switch val := tvpVal.(type) { + case sql.NullBool: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullFloat64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullInt64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullString: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) + return true + } + } + return false +} |