diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tds.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/tds.go | 201 |
1 files changed, 107 insertions, 94 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go index 54ac6dbad5..a924e90109 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tds.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go @@ -50,16 +50,16 @@ func parseInstances(msg []byte) map[string]map[string]string { return results } -func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) { - dialer := &net.Dialer{ - Timeout: 5 * time.Second, - } - conn, err := dialer.DialContext(ctx, "udp", address+":1434") +func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) { + maxTime := 5 * time.Second + ctx, cancel := context.WithTimeout(ctx, maxTime) + defer cancel() + conn, err := d.DialContext(ctx, "udp", address+":1434") if err != nil { return nil, err } defer conn.Close() - conn.SetDeadline(time.Now().Add(5 * time.Second)) + conn.SetDeadline(time.Now().Add(maxTime)) _, err = conn.Write([]byte{3}) if err != nil { return nil, err @@ -152,19 +152,19 @@ type columnStruct struct { ti typeInfo } -type KeySlice []uint8 +type keySlice []uint8 -func (p KeySlice) Len() int { return len(p) } -func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] } -func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p keySlice) Len() int { return len(p) } +func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } +func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packPrelogin) + w.BeginPacket(packPrelogin, false) offset := uint16(5*len(fields) + 1) - keys := make(KeySlice, 0, len(fields)) + keys := make(keySlice, 0, len(fields)) for k, _ := range fields { keys = append(keys, k) } @@ -352,7 +352,7 @@ func manglePassword(password string) []byte { // http://msdn.microsoft.com/en-us/library/dd304019.aspx func sendLogin(w *tdsBuffer, login login) error { - w.BeginPacket(packLogin7) + w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) password := manglePassword(login.Password) @@ -633,8 +633,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { return nil } -func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) { - buf.BeginPacket(packSQLBatch) +func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { + buf.BeginPacket(packSQLBatch, resetSession) if err = writeAllHeaders(buf, headers); err != nil { return @@ -650,33 +650,34 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx func sendAttention(buf *tdsBuffer) error { - buf.BeginPacket(packAttention) + buf.BeginPacket(packAttention, false) return buf.FinishPacket() } type connectParams struct { - logFlags uint64 - port uint64 - host string - instance string - database string - user string - password string - dial_timeout time.Duration - conn_timeout time.Duration - keepAlive time.Duration - encrypt bool - disableEncryption bool - trustServerCertificate bool - certificate string - hostInCertificate string - serverSPN string - workstation string - appname string - typeFlags uint8 - failOverPartner string - failOverPort uint64 - packetSize uint16 + logFlags uint64 + port uint64 + host string + instance string + database string + user string + password string + dial_timeout time.Duration + conn_timeout time.Duration + keepAlive time.Duration + encrypt bool + disableEncryption bool + trustServerCertificate bool + certificate string + hostInCertificate string + hostInCertificateProvided bool + serverSPN string + workstation string + appname string + typeFlags uint8 + failOverPartner string + failOverPort uint64 + packetSize uint16 } func splitConnectionString(dsn string) (res map[string]string) { @@ -938,13 +939,13 @@ func parseConnectParams(dsn string) (connectParams, error) { strlog, ok := params["log"] if ok { var err error - p.logFlags, err = strconv.ParseUint(strlog, 10, 0) + p.logFlags, err = strconv.ParseUint(strlog, 10, 64) if err != nil { return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) } } server := params["server"] - parts := strings.SplitN(server, "\\", 2) + parts := strings.SplitN(server, `\`, 2) p.host = parts[0] if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { p.host = "localhost" @@ -960,7 +961,7 @@ func parseConnectParams(dsn string) (connectParams, error) { strport, ok := params["port"] if ok { var err error - p.port, err = strconv.ParseUint(strport, 0, 16) + p.port, err = strconv.ParseUint(strport, 10, 16) if err != nil { f := "Invalid tcp port '%v': %v" return p, fmt.Errorf(f, strport, err.Error()) @@ -992,20 +993,20 @@ func parseConnectParams(dsn string) (connectParams, error) { } // https://msdn.microsoft.com/en-us/library/dd341108.aspx - p.dial_timeout = 15 * time.Second - p.conn_timeout = 30 * time.Second - strconntimeout, ok := params["connection timeout"] - if ok { - timeout, err := strconv.ParseUint(strconntimeout, 0, 16) + // + // Do not set a connection timeout. Use Context to manage such things. + // Default to zero, but still allow it to be set. + if strconntimeout, ok := params["connection timeout"]; ok { + timeout, err := strconv.ParseUint(strconntimeout, 10, 64) if err != nil { f := "Invalid connection timeout '%v': %v" return p, fmt.Errorf(f, strconntimeout, err.Error()) } p.conn_timeout = time.Duration(timeout) * time.Second } - strdialtimeout, ok := params["dial timeout"] - if ok { - timeout, err := strconv.ParseUint(strdialtimeout, 0, 16) + p.dial_timeout = 15 * time.Second + if strdialtimeout, ok := params["dial timeout"]; ok { + timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) if err != nil { f := "Invalid dial timeout '%v': %v" return p, fmt.Errorf(f, strdialtimeout, err.Error()) @@ -1016,9 +1017,8 @@ func parseConnectParams(dsn string) (connectParams, error) { // default keep alive should be 30 seconds according to spec: // https://msdn.microsoft.com/en-us/library/dd341108.aspx p.keepAlive = 30 * time.Second - if keepAlive, ok := params["keepalive"]; ok { - timeout, err := strconv.ParseUint(keepAlive, 0, 16) + timeout, err := strconv.ParseUint(keepAlive, 10, 64) if err != nil { f := "Invalid keepAlive value '%s': %s" return p, fmt.Errorf(f, keepAlive, err.Error()) @@ -1051,8 +1051,11 @@ func parseConnectParams(dsn string) (connectParams, error) { } p.certificate = params["certificate"] p.hostInCertificate, ok = params["hostnameincertificate"] - if !ok { + if ok { + p.hostInCertificateProvided = true + } else { p.hostInCertificate = p.host + p.hostInCertificateProvided = false } serverSPN, ok := params["serverspn"] @@ -1112,7 +1115,7 @@ type auth interface { // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a // list of IP addresses. So if there is more than one, try them all and // use the first one that allows a connection. -func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) { +func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { var ips []net.IP ips, err = net.LookupIP(p.host) if err != nil { @@ -1123,9 +1126,9 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er ips = []net.IP{ip} } if len(ips) == 1 { - d := createDialer(&p) + d := c.getDialer(&p) addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port))) - conn, err = d.Dial(ctx, addr) + conn, err = d.DialContext(ctx, "tcp", addr) } else { //Try Dials in parallel to avoid waiting for timeouts. @@ -1134,9 +1137,9 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er portStr := strconv.Itoa(int(p.port)) for _, ip := range ips { go func(ip net.IP) { - d := createDialer(&p) + d := c.getDialer(&p) addr := net.JoinHostPort(ip.String(), portStr) - conn, err := d.Dial(ctx, addr) + conn, err := d.DialContext(ctx, "tcp", addr) if err == nil { connChan <- conn } else { @@ -1174,12 +1177,18 @@ func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err er return conn, err } -func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) { - res = nil +func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { + dialCtx := ctx + if p.dial_timeout > 0 { + var cancel func() + dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout) + defer cancel() + } // if instance is specified use instance resolution service if p.instance != "" { p.instance = strings.ToUpper(p.instance) - instances, err := getInstances(ctx, p.host) + d := c.getDialer(&p) + instances, err := getInstances(dialCtx, d, p.host) if err != nil { f := "Unable to get instances from Sql Server Browser on host %v: %v" return nil, fmt.Errorf(f, p.host, err.Error()) @@ -1197,12 +1206,12 @@ func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tds } initiate_connection: - conn, err := dialConnection(ctx, p) + conn, err := dialConnection(dialCtx, c, p) if err != nil { return nil, err } - toconn := NewTimeoutConn(conn, p.conn_timeout) + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) sess := tdsSession{ @@ -1313,42 +1322,43 @@ initiate_connection: } // processing login response - var sspi_msg []byte -continue_login: - tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), &sess, tokchan, nil) success := false - for tok := range tokchan { - switch token := tok.(type) { - case sspiMsg: - sspi_msg, err = auth.NextBytes(token) - if err != nil { - return nil, err - } - case loginAckStruct: - success = true - sess.loginAck = token - case error: - return nil, fmt.Errorf("Login error: %s", token.Error()) - case doneStruct: - if token.isError() { - return nil, fmt.Errorf("Login error: %s", token.getError()) + for { + tokchan := make(chan tokenStruct, 5) + go processResponse(context.Background(), &sess, tokchan, nil) + for tok := range tokchan { + switch token := tok.(type) { + case sspiMsg: + sspi_msg, err := auth.NextBytes(token) + if err != nil { + return nil, err + } + if sspi_msg != nil && len(sspi_msg) > 0 { + outbuf.BeginPacket(packSSPIMessage, false) + _, err = outbuf.Write(sspi_msg) + if err != nil { + return nil, err + } + err = outbuf.FinishPacket() + if err != nil { + return nil, err + } + sspi_msg = nil + } + case loginAckStruct: + success = true + sess.loginAck = token + case error: + return nil, fmt.Errorf("Login error: %s", token.Error()) + case doneStruct: + if token.isError() { + return nil, fmt.Errorf("Login error: %s", token.getError()) + } + goto loginEnd } } } - if sspi_msg != nil { - outbuf.BeginPacket(packSSPIMessage) - _, err = outbuf.Write(sspi_msg) - if err != nil { - return nil, err - } - err = outbuf.FinishPacket() - if err != nil { - return nil, err - } - sspi_msg = nil - goto continue_login - } +loginEnd: if !success { return nil, fmt.Errorf("Login failed") } @@ -1356,6 +1366,9 @@ continue_login: toconn.Close() p.host = sess.routedServer p.port = uint64(sess.routedPort) + if !p.hostInCertificateProvided { + p.hostInCertificate = sess.routedServer + } goto initiate_connection } return &sess, nil |