summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go73
1 files changed, 67 insertions, 6 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
index 64e5e21fbd..d3890af954 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
@@ -4,6 +4,7 @@ package mssql
import (
"bytes"
+ "database/sql"
"encoding/binary"
"errors"
"fmt"
@@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
field := refStr.Field(fieldIdx)
tvpVal := field.Interface()
+ if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
+ continue
+ }
valOf := reflect.ValueOf(tvpVal)
elemKind := field.Kind()
if elemKind == reflect.Ptr && valOf.IsNil() {
@@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
defaultValues = append(defaultValues, v.Interface())
continue
}
- defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
+ defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
}
if columnCount-len(tvpFieldIndexes) == columnCount {
@@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) {
}
splitVal := strings.Split(tvpName, ".")
if len(splitVal) > 2 {
- return "", "", errors.New("wrong tvp name")
+ return "", "", ErrorObjectName
}
+ const (
+ openSquareBrackets = "["
+ closeSquareBrackets = "]"
+ )
if len(splitVal) == 2 {
res := make([]string, 2)
for key, value := range splitVal {
- tmp := strings.Replace(value, "[", "", -1)
- tmp = strings.Replace(tmp, "]", "", -1)
+ tmp := strings.Replace(value, openSquareBrackets, "", -1)
+ tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
res[key] = tmp
}
return res[0], res[1], nil
}
- tmp := strings.Replace(splitVal[0], "[", "", -1)
- tmp = strings.Replace(tmp, "]", "", -1)
+ tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1)
+ tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
return "", tmp, nil
}
@@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) {
func getCountSQLSeparators(str string) int {
return strings.Count(str, sqlSeparator)
}
+
+// verify types https://golang.org/pkg/database/sql/
+func (tvp TVP) createZeroType(fieldVal interface{}) interface{} {
+ const (
+ defaultBool = false
+ defaultFloat64 = float64(0)
+ defaultInt64 = int64(0)
+ defaultString = ""
+ )
+
+ switch fieldVal.(type) {
+ case sql.NullBool:
+ return defaultBool
+ case sql.NullFloat64:
+ return defaultFloat64
+ case sql.NullInt64:
+ return defaultInt64
+ case sql.NullString:
+ return defaultString
+ }
+ return fieldVal
+}
+
+// verify types https://golang.org/pkg/database/sql/
+func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool {
+ const (
+ defaultNull = uint8(0)
+ )
+
+ switch val := tvpVal.(type) {
+ case sql.NullBool:
+ if !val.Valid {
+ binary.Write(buf, binary.LittleEndian, defaultNull)
+ return true
+ }
+ case sql.NullFloat64:
+ if !val.Valid {
+ binary.Write(buf, binary.LittleEndian, defaultNull)
+ return true
+ }
+ case sql.NullInt64:
+ if !val.Valid {
+ binary.Write(buf, binary.LittleEndian, defaultNull)
+ return true
+ }
+ case sql.NullString:
+ if !val.Valid {
+ binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
+ return true
+ }
+ }
+ return false
+}