summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/tds.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tds.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/tds.go201
1 files changed, 107 insertions, 94 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go
index 54ac6dbad5..a924e90109 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/tds.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go
@@ -50,16 +50,16 @@ func parseInstances(msg []byte) map[string]map[string]string {
return results
}
-func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
- dialer := &net.Dialer{
- Timeout: 5 * time.Second,
- }
- conn, err := dialer.DialContext(ctx, "udp", address+":1434")
+func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
+ maxTime := 5 * time.Second
+ ctx, cancel := context.WithTimeout(ctx, maxTime)
+ defer cancel()
+ conn, err := d.DialContext(ctx, "udp", address+":1434")
if err != nil {
return nil, err
}
defer conn.Close()
- conn.SetDeadline(time.Now().Add(5 * time.Second))
+ conn.SetDeadline(time.Now().Add(maxTime))
_, err = conn.Write([]byte{3})
if err != nil {
return nil, err
@@ -152,19 +152,19 @@ type columnStruct struct {
ti typeInfo
}
-type KeySlice []uint8
+type keySlice []uint8
-func (p KeySlice) Len() int { return len(p) }
-func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] }
-func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+func (p keySlice) Len() int { return len(p) }
+func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
+func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
var err error
- w.BeginPacket(packPrelogin)
+ w.BeginPacket(packPrelogin, false)
offset := uint16(5*len(fields) + 1)
- keys := make(KeySlice, 0, len(fields))
+ keys := make(keySlice, 0, len(fields))
for k, _ := range fields {
keys = append(keys, k)
}
@@ -352,7 +352,7 @@ func manglePassword(password string) []byte {
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
func sendLogin(w *tdsBuffer, login login) error {
- w.BeginPacket(packLogin7)
+ w.BeginPacket(packLogin7, false)
hostname := str2ucs2(login.HostName)
username := str2ucs2(login.UserName)
password := manglePassword(login.Password)
@@ -633,8 +633,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
return nil
}
-func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
- buf.BeginPacket(packSQLBatch)
+func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
+ buf.BeginPacket(packSQLBatch, resetSession)
if err = writeAllHeaders(buf, headers); err != nil {
return
@@ -650,33 +650,34 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
- buf.BeginPacket(packAttention)
+ buf.BeginPacket(packAttention, false)
return buf.FinishPacket()
}
type connectParams struct {
- logFlags uint64
- port uint64
- host string
- instance string
- database string
- user string
- password string
- dial_timeout time.Duration
- conn_timeout time.Duration
- keepAlive time.Duration
- encrypt bool
- disableEncryption bool
- trustServerCertificate bool
- certificate string
- hostInCertificate string
- serverSPN string
- workstation string
- appname string
- typeFlags uint8
- failOverPartner string
- failOverPort uint64
- packetSize uint16
+ logFlags uint64
+ port uint64
+ host string
+ instance string
+ database string
+ user string
+ password string
+ dial_timeout time.Duration
+ conn_timeout time.Duration
+ keepAlive time.Duration
+ encrypt bool
+ disableEncryption bool
+ trustServerCertificate bool
+ certificate string
+ hostInCertificate string
+ hostInCertificateProvided bool
+ serverSPN string
+ workstation string
+ appname string
+ typeFlags uint8
+ failOverPartner string
+ failOverPort uint64
+ packetSize uint16
}
func splitConnectionString(dsn string) (res map[string]string) {
@@ -938,13 +939,13 @@ func parseConnectParams(dsn string) (connectParams, error) {
strlog, ok := params["log"]
if ok {
var err error
- p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
+ p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
if err != nil {
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
}
}
server := params["server"]
- parts := strings.SplitN(server, "\\", 2)
+ parts := strings.SplitN(server, `\`, 2)
p.host = parts[0]
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
p.host = "localhost"
@@ -960,7 +961,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
strport, ok := params["port"]
if ok {
var err error
- p.port, err = strconv.ParseUint(strport, 0, 16)
+ p.port, err = strconv.ParseUint(strport, 10, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
return p, fmt.Errorf(f, strport, err.Error())
@@ -992,20 +993,20 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
- p.dial_timeout = 15 * time.Second
- p.conn_timeout = 30 * time.Second
- strconntimeout, ok := params["connection timeout"]
- if ok {
- timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
+ //
+ // Do not set a connection timeout. Use Context to manage such things.
+ // Default to zero, but still allow it to be set.
+ if strconntimeout, ok := params["connection timeout"]; ok {
+ timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "Invalid connection timeout '%v': %v"
return p, fmt.Errorf(f, strconntimeout, err.Error())
}
p.conn_timeout = time.Duration(timeout) * time.Second
}
- strdialtimeout, ok := params["dial timeout"]
- if ok {
- timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
+ p.dial_timeout = 15 * time.Second
+ if strdialtimeout, ok := params["dial timeout"]; ok {
+ timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "Invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
@@ -1016,9 +1017,8 @@ func parseConnectParams(dsn string) (connectParams, error) {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.keepAlive = 30 * time.Second
-
if keepAlive, ok := params["keepalive"]; ok {
- timeout, err := strconv.ParseUint(keepAlive, 0, 16)
+ timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
return p, fmt.Errorf(f, keepAlive, err.Error())
@@ -1051,8 +1051,11 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
p.certificate = params["certificate"]
p.hostInCertificate, ok = params["hostnameincertificate"]
- if !ok {
+ if ok {
+ p.hostInCertificateProvided = true
+ } else {
p.hostInCertificate = p.host
+ p.hostInCertificateProvided = false
}
serverSPN, ok := params["serverspn"]
@@ -1112,7 +1115,7 @@ type auth interface {
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
-func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
+func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
@@ -1123,9 +1126,9 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er
ips = []net.IP{ip}
}
if len(ips) == 1 {
- d := createDialer(&p)
+ d := c.getDialer(&p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
- conn, err = d.Dial(ctx, addr)
+ conn, err = d.DialContext(ctx, "tcp", addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
@@ -1134,9 +1137,9 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er
portStr := strconv.Itoa(int(p.port))
for _, ip := range ips {
go func(ip net.IP) {
- d := createDialer(&p)
+ d := c.getDialer(&p)
addr := net.JoinHostPort(ip.String(), portStr)
- conn, err := d.Dial(ctx, addr)
+ conn, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
connChan <- conn
} else {
@@ -1174,12 +1177,18 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er
return conn, err
}
-func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
- res = nil
+func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
+ dialCtx := ctx
+ if p.dial_timeout > 0 {
+ var cancel func()
+ dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
+ defer cancel()
+ }
// if instance is specified use instance resolution service
if p.instance != "" {
p.instance = strings.ToUpper(p.instance)
- instances, err := getInstances(ctx, p.host)
+ d := c.getDialer(&p)
+ instances, err := getInstances(dialCtx, d, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
@@ -1197,12 +1206,12 @@ func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tds
}
initiate_connection:
- conn, err := dialConnection(ctx, p)
+ conn, err := dialConnection(dialCtx, c, p)
if err != nil {
return nil, err
}
- toconn := NewTimeoutConn(conn, p.conn_timeout)
+ toconn := newTimeoutConn(conn, p.conn_timeout)
outbuf := newTdsBuffer(p.packetSize, toconn)
sess := tdsSession{
@@ -1313,42 +1322,43 @@ initiate_connection:
}
// processing login response
- var sspi_msg []byte
-continue_login:
- tokchan := make(chan tokenStruct, 5)
- go processResponse(context.Background(), &sess, tokchan, nil)
success := false
- for tok := range tokchan {
- switch token := tok.(type) {
- case sspiMsg:
- sspi_msg, err = auth.NextBytes(token)
- if err != nil {
- return nil, err
- }
- case loginAckStruct:
- success = true
- sess.loginAck = token
- case error:
- return nil, fmt.Errorf("Login error: %s", token.Error())
- case doneStruct:
- if token.isError() {
- return nil, fmt.Errorf("Login error: %s", token.getError())
+ for {
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(context.Background(), &sess, tokchan, nil)
+ for tok := range tokchan {
+ switch token := tok.(type) {
+ case sspiMsg:
+ sspi_msg, err := auth.NextBytes(token)
+ if err != nil {
+ return nil, err
+ }
+ if sspi_msg != nil && len(sspi_msg) > 0 {
+ outbuf.BeginPacket(packSSPIMessage, false)
+ _, err = outbuf.Write(sspi_msg)
+ if err != nil {
+ return nil, err
+ }
+ err = outbuf.FinishPacket()
+ if err != nil {
+ return nil, err
+ }
+ sspi_msg = nil
+ }
+ case loginAckStruct:
+ success = true
+ sess.loginAck = token
+ case error:
+ return nil, fmt.Errorf("Login error: %s", token.Error())
+ case doneStruct:
+ if token.isError() {
+ return nil, fmt.Errorf("Login error: %s", token.getError())
+ }
+ goto loginEnd
}
}
}
- if sspi_msg != nil {
- outbuf.BeginPacket(packSSPIMessage)
- _, err = outbuf.Write(sspi_msg)
- if err != nil {
- return nil, err
- }
- err = outbuf.FinishPacket()
- if err != nil {
- return nil, err
- }
- sspi_msg = nil
- goto continue_login
- }
+loginEnd:
if !success {
return nil, fmt.Errorf("Login failed")
}
@@ -1356,6 +1366,9 @@ continue_login:
toconn.Close()
p.host = sess.routedServer
p.port = uint64(sess.routedPort)
+ if !p.hostInCertificateProvided {
+ p.hostInCertificate = sess.routedServer
+ }
goto initiate_connection
}
return &sess, nil