diff options
author | Antoine GIRARD <sapk@users.noreply.github.com> | 2019-10-02 02:32:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-10-02 02:32:12 +0200 |
commit | 149758c912842bedda86b5087cffd59ce0682e58 (patch) | |
tree | d8064fede0150c4e9bedae985f400458489a9db6 /vendor/github.com/denisenkom/go-mssqldb/tds.go | |
parent | 3a7e3dbfb40b892bf2b90e3d6bf30a028eae478a (diff) | |
download | gitea-149758c912842bedda86b5087cffd59ce0682e58.tar.gz gitea-149758c912842bedda86b5087cffd59ce0682e58.zip |
Update to github.com/lafriks/xormstore@v1.3.0 (#8317)
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tds.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/tds.go | 483 |
1 files changed, 11 insertions, 472 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go index a924e90109..5a9f53b705 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tds.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go @@ -10,13 +10,9 @@ import ( "io" "io/ioutil" "net" - "net/url" - "os" "sort" "strconv" "strings" - "time" - "unicode" "unicode/utf16" "unicode/utf8" ) @@ -51,15 +47,13 @@ func parseInstances(msg []byte) map[string]map[string]string { } func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) { - maxTime := 5 * time.Second - ctx, cancel := context.WithTimeout(ctx, maxTime) - defer cancel() conn, err := d.DialContext(ctx, "udp", address+":1434") if err != nil { return nil, err } defer conn.Close() - conn.SetDeadline(time.Now().Add(maxTime)) + deadline, _ := ctx.Deadline() + conn.SetDeadline(deadline) _, err = conn.Write([]byte{3}) if err != nil { return nil, err @@ -474,10 +468,9 @@ func readUcs2(r io.Reader, numchars int) (res string, err error) { } func readUsVarChar(r io.Reader) (res string, err error) { - var numchars uint16 - err = binary.Read(r, binary.LittleEndian, &numchars) + numchars, err := readUshort(r) if err != nil { - return "", err + return } return readUcs2(r, int(numchars)) } @@ -497,8 +490,7 @@ func writeUsVarChar(w io.Writer, s string) (err error) { } func readBVarChar(r io.Reader) (res string, err error) { - var numchars uint8 - err = binary.Read(r, binary.LittleEndian, &numchars) + numchars, err := readByte(r) if err != nil { return "", err } @@ -525,8 +517,7 @@ func writeBVarChar(w io.Writer, s string) (err error) { } func readBVarByte(r io.Reader) (res []byte, err error) { - var length uint8 - err = binary.Read(r, binary.LittleEndian, &length) + length, err := readByte(r) if err != nil { return } @@ -654,458 +645,6 @@ func sendAttention(buf *tdsBuffer) error { return buf.FinishPacket() } -type connectParams struct { - logFlags uint64 - port uint64 - host string - instance string - database string - user string - password string - dial_timeout time.Duration - conn_timeout time.Duration - keepAlive time.Duration - encrypt bool - disableEncryption bool - trustServerCertificate bool - certificate string - hostInCertificate string - hostInCertificateProvided bool - serverSPN string - workstation string - appname string - typeFlags uint8 - failOverPartner string - failOverPort uint64 - packetSize uint16 -} - -func splitConnectionString(dsn string) (res map[string]string) { - res = map[string]string{} - parts := strings.Split(dsn, ";") - for _, part := range parts { - if len(part) == 0 { - continue - } - lst := strings.SplitN(part, "=", 2) - name := strings.TrimSpace(strings.ToLower(lst[0])) - if len(name) == 0 { - continue - } - var value string = "" - if len(lst) > 1 { - value = strings.TrimSpace(lst[1]) - } - res[name] = value - } - return res -} - -// Splits a URL in the ODBC format -func splitConnectionStringOdbc(dsn string) (map[string]string, error) { - res := map[string]string{} - - type parserState int - const ( - // Before the start of a key - parserStateBeforeKey parserState = iota - - // Inside a key - parserStateKey - - // Beginning of a value. May be bare or braced - parserStateBeginValue - - // Inside a bare value - parserStateBareValue - - // Inside a braced value - parserStateBracedValue - - // A closing brace inside a braced value. - // May be the end of the value or an escaped closing brace, depending on the next character - parserStateBracedValueClosingBrace - - // After a value. Next character should be a semicolon or whitespace. - parserStateEndValue - ) - - var state = parserStateBeforeKey - - var key string - var value string - - for i, c := range dsn { - switch state { - case parserStateBeforeKey: - switch { - case c == '=': - return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i) - case !unicode.IsSpace(c) && c != ';': - state = parserStateKey - key += string(c) - } - - case parserStateKey: - switch c { - case '=': - key = normalizeOdbcKey(key) - if len(key) == 0 { - return res, fmt.Errorf("Unexpected end of key at index %d.", i) - } - - state = parserStateBeginValue - - case ';': - // Key without value - key = normalizeOdbcKey(key) - if len(key) == 0 { - return res, fmt.Errorf("Unexpected end of key at index %d.", i) - } - - res[key] = value - key = "" - value = "" - state = parserStateBeforeKey - - default: - key += string(c) - } - - case parserStateBeginValue: - switch { - case c == '{': - state = parserStateBracedValue - case c == ';': - // Empty value - res[key] = value - key = "" - state = parserStateBeforeKey - case unicode.IsSpace(c): - // Ignore whitespace - default: - state = parserStateBareValue - value += string(c) - } - - case parserStateBareValue: - if c == ';' { - res[key] = strings.TrimRightFunc(value, unicode.IsSpace) - key = "" - value = "" - state = parserStateBeforeKey - } else { - value += string(c) - } - - case parserStateBracedValue: - if c == '}' { - state = parserStateBracedValueClosingBrace - } else { - value += string(c) - } - - case parserStateBracedValueClosingBrace: - if c == '}' { - // Escaped closing brace - value += string(c) - state = parserStateBracedValue - continue - } - - // End of braced value - res[key] = value - key = "" - value = "" - - // This character is the first character past the end, - // so it needs to be parsed like the parserStateEndValue state. - state = parserStateEndValue - switch { - case c == ';': - state = parserStateBeforeKey - case unicode.IsSpace(c): - // Ignore whitespace - default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) - } - - case parserStateEndValue: - switch { - case c == ';': - state = parserStateBeforeKey - case unicode.IsSpace(c): - // Ignore whitespace - default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) - } - } - } - - switch state { - case parserStateBeforeKey: // Okay - case parserStateKey: // Unfinished key. Treat as key without value. - key = normalizeOdbcKey(key) - if len(key) == 0 { - return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn)) - } - res[key] = value - case parserStateBeginValue: // Empty value - res[key] = value - case parserStateBareValue: - res[key] = strings.TrimRightFunc(value, unicode.IsSpace) - case parserStateBracedValue: - return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn)) - case parserStateBracedValueClosingBrace: // End of braced value - res[key] = value - case parserStateEndValue: // Okay - } - - return res, nil -} - -// Normalizes the given string as an ODBC-format key -func normalizeOdbcKey(s string) string { - return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) -} - -// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value -func splitConnectionStringURL(dsn string) (map[string]string, error) { - res := map[string]string{} - - u, err := url.Parse(dsn) - if err != nil { - return res, err - } - - if u.Scheme != "sqlserver" { - return res, fmt.Errorf("scheme %s is not recognized", u.Scheme) - } - - if u.User != nil { - res["user id"] = u.User.Username() - p, exists := u.User.Password() - if exists { - res["password"] = p - } - } - - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - host = u.Host - } - - if len(u.Path) > 0 { - res["server"] = host + "\\" + u.Path[1:] - } else { - res["server"] = host - } - - if len(port) > 0 { - res["port"] = port - } - - query := u.Query() - for k, v := range query { - if len(v) > 1 { - return res, fmt.Errorf("key %s provided more than once", k) - } - res[strings.ToLower(k)] = v[0] - } - - return res, nil -} - -func parseConnectParams(dsn string) (connectParams, error) { - var p connectParams - - var params map[string]string - if strings.HasPrefix(dsn, "odbc:") { - parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):]) - if err != nil { - return p, err - } - params = parameters - } else if strings.HasPrefix(dsn, "sqlserver://") { - parameters, err := splitConnectionStringURL(dsn) - if err != nil { - return p, err - } - params = parameters - } else { - params = splitConnectionString(dsn) - } - - strlog, ok := params["log"] - if ok { - var err error - p.logFlags, err = strconv.ParseUint(strlog, 10, 64) - if err != nil { - return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) - } - } - server := params["server"] - parts := strings.SplitN(server, `\`, 2) - p.host = parts[0] - if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { - p.host = "localhost" - } - if len(parts) > 1 { - p.instance = parts[1] - } - p.database = params["database"] - p.user = params["user id"] - p.password = params["password"] - - p.port = 1433 - strport, ok := params["port"] - if ok { - var err error - p.port, err = strconv.ParseUint(strport, 10, 16) - if err != nil { - f := "Invalid tcp port '%v': %v" - return p, fmt.Errorf(f, strport, err.Error()) - } - } - - // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option - // Default packet size remains at 4096 bytes - p.packetSize = 4096 - strpsize, ok := params["packet size"] - if ok { - var err error - psize, err := strconv.ParseUint(strpsize, 0, 16) - if err != nil { - f := "Invalid packet size '%v': %v" - return p, fmt.Errorf(f, strpsize, err.Error()) - } - - // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes - // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request - // a higher packet size, the server will respond with an ENVCHANGE request to - // alter the packet size to 16383 bytes. - p.packetSize = uint16(psize) - if p.packetSize < 512 { - p.packetSize = 512 - } else if p.packetSize > 32767 { - p.packetSize = 32767 - } - } - - // https://msdn.microsoft.com/en-us/library/dd341108.aspx - // - // Do not set a connection timeout. Use Context to manage such things. - // Default to zero, but still allow it to be set. - if strconntimeout, ok := params["connection timeout"]; ok { - timeout, err := strconv.ParseUint(strconntimeout, 10, 64) - if err != nil { - f := "Invalid connection timeout '%v': %v" - return p, fmt.Errorf(f, strconntimeout, err.Error()) - } - p.conn_timeout = time.Duration(timeout) * time.Second - } - p.dial_timeout = 15 * time.Second - if strdialtimeout, ok := params["dial timeout"]; ok { - timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) - if err != nil { - f := "Invalid dial timeout '%v': %v" - return p, fmt.Errorf(f, strdialtimeout, err.Error()) - } - p.dial_timeout = time.Duration(timeout) * time.Second - } - - // default keep alive should be 30 seconds according to spec: - // https://msdn.microsoft.com/en-us/library/dd341108.aspx - p.keepAlive = 30 * time.Second - if keepAlive, ok := params["keepalive"]; ok { - timeout, err := strconv.ParseUint(keepAlive, 10, 64) - if err != nil { - f := "Invalid keepAlive value '%s': %s" - return p, fmt.Errorf(f, keepAlive, err.Error()) - } - p.keepAlive = time.Duration(timeout) * time.Second - } - encrypt, ok := params["encrypt"] - if ok { - if strings.EqualFold(encrypt, "DISABLE") { - p.disableEncryption = true - } else { - var err error - p.encrypt, err = strconv.ParseBool(encrypt) - if err != nil { - f := "Invalid encrypt '%s': %s" - return p, fmt.Errorf(f, encrypt, err.Error()) - } - } - } else { - p.trustServerCertificate = true - } - trust, ok := params["trustservercertificate"] - if ok { - var err error - p.trustServerCertificate, err = strconv.ParseBool(trust) - if err != nil { - f := "Invalid trust server certificate '%s': %s" - return p, fmt.Errorf(f, trust, err.Error()) - } - } - p.certificate = params["certificate"] - p.hostInCertificate, ok = params["hostnameincertificate"] - if ok { - p.hostInCertificateProvided = true - } else { - p.hostInCertificate = p.host - p.hostInCertificateProvided = false - } - - serverSPN, ok := params["serverspn"] - if ok { - p.serverSPN = serverSPN - } else { - p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port) - } - - workstation, ok := params["workstation id"] - if ok { - p.workstation = workstation - } else { - workstation, err := os.Hostname() - if err == nil { - p.workstation = workstation - } - } - - appname, ok := params["app name"] - if !ok { - appname = "go-mssqldb" - } - p.appname = appname - - appintent, ok := params["applicationintent"] - if ok { - if appintent == "ReadOnly" { - p.typeFlags |= fReadOnlyIntent - } - } - - failOverPartner, ok := params["failoverpartner"] - if ok { - p.failOverPartner = failOverPartner - } - - failOverPort, ok := params["failoverport"] - if ok { - var err error - p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16) - if err != nil { - f := "Invalid tcp port '%v': %v" - return p, fmt.Errorf(f, failOverPort, err.Error()) - } - } - - return p, nil -} - type auth interface { InitialBytes() ([]byte, error) NextBytes([]byte) ([]byte, error) @@ -1277,12 +816,12 @@ initiate_connection: // while SQL Server seems to expect one TCP segment per encrypted TDS package. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package config.DynamicRecordSizingDisabled = true - outbuf.transport = conn - toconn.buf = outbuf - tlsConn := tls.Client(toconn, &config) + // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream + handshakeConn := tlsHandshakeConn{buf: outbuf} + passthrough := passthroughConn{c: &handshakeConn} + tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() - - toconn.buf = nil + passthrough.c = toconn outbuf.transport = tlsConn if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) |