summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/tds.go
diff options
context:
space:
mode:
authorAntoine GIRARD <sapk@users.noreply.github.com>2019-10-02 02:32:12 +0200
committerGitHub <noreply@github.com>2019-10-02 02:32:12 +0200
commit149758c912842bedda86b5087cffd59ce0682e58 (patch)
treed8064fede0150c4e9bedae985f400458489a9db6 /vendor/github.com/denisenkom/go-mssqldb/tds.go
parent3a7e3dbfb40b892bf2b90e3d6bf30a028eae478a (diff)
downloadgitea-149758c912842bedda86b5087cffd59ce0682e58.tar.gz
gitea-149758c912842bedda86b5087cffd59ce0682e58.zip
Update to github.com/lafriks/xormstore@v1.3.0 (#8317)
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tds.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/tds.go483
1 files changed, 11 insertions, 472 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go
index a924e90109..5a9f53b705 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/tds.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go
@@ -10,13 +10,9 @@ import (
"io"
"io/ioutil"
"net"
- "net/url"
- "os"
"sort"
"strconv"
"strings"
- "time"
- "unicode"
"unicode/utf16"
"unicode/utf8"
)
@@ -51,15 +47,13 @@ func parseInstances(msg []byte) map[string]map[string]string {
}
func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
- maxTime := 5 * time.Second
- ctx, cancel := context.WithTimeout(ctx, maxTime)
- defer cancel()
conn, err := d.DialContext(ctx, "udp", address+":1434")
if err != nil {
return nil, err
}
defer conn.Close()
- conn.SetDeadline(time.Now().Add(maxTime))
+ deadline, _ := ctx.Deadline()
+ conn.SetDeadline(deadline)
_, err = conn.Write([]byte{3})
if err != nil {
return nil, err
@@ -474,10 +468,9 @@ func readUcs2(r io.Reader, numchars int) (res string, err error) {
}
func readUsVarChar(r io.Reader) (res string, err error) {
- var numchars uint16
- err = binary.Read(r, binary.LittleEndian, &numchars)
+ numchars, err := readUshort(r)
if err != nil {
- return "", err
+ return
}
return readUcs2(r, int(numchars))
}
@@ -497,8 +490,7 @@ func writeUsVarChar(w io.Writer, s string) (err error) {
}
func readBVarChar(r io.Reader) (res string, err error) {
- var numchars uint8
- err = binary.Read(r, binary.LittleEndian, &numchars)
+ numchars, err := readByte(r)
if err != nil {
return "", err
}
@@ -525,8 +517,7 @@ func writeBVarChar(w io.Writer, s string) (err error) {
}
func readBVarByte(r io.Reader) (res []byte, err error) {
- var length uint8
- err = binary.Read(r, binary.LittleEndian, &length)
+ length, err := readByte(r)
if err != nil {
return
}
@@ -654,458 +645,6 @@ func sendAttention(buf *tdsBuffer) error {
return buf.FinishPacket()
}
-type connectParams struct {
- logFlags uint64
- port uint64
- host string
- instance string
- database string
- user string
- password string
- dial_timeout time.Duration
- conn_timeout time.Duration
- keepAlive time.Duration
- encrypt bool
- disableEncryption bool
- trustServerCertificate bool
- certificate string
- hostInCertificate string
- hostInCertificateProvided bool
- serverSPN string
- workstation string
- appname string
- typeFlags uint8
- failOverPartner string
- failOverPort uint64
- packetSize uint16
-}
-
-func splitConnectionString(dsn string) (res map[string]string) {
- res = map[string]string{}
- parts := strings.Split(dsn, ";")
- for _, part := range parts {
- if len(part) == 0 {
- continue
- }
- lst := strings.SplitN(part, "=", 2)
- name := strings.TrimSpace(strings.ToLower(lst[0]))
- if len(name) == 0 {
- continue
- }
- var value string = ""
- if len(lst) > 1 {
- value = strings.TrimSpace(lst[1])
- }
- res[name] = value
- }
- return res
-}
-
-// Splits a URL in the ODBC format
-func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
- res := map[string]string{}
-
- type parserState int
- const (
- // Before the start of a key
- parserStateBeforeKey parserState = iota
-
- // Inside a key
- parserStateKey
-
- // Beginning of a value. May be bare or braced
- parserStateBeginValue
-
- // Inside a bare value
- parserStateBareValue
-
- // Inside a braced value
- parserStateBracedValue
-
- // A closing brace inside a braced value.
- // May be the end of the value or an escaped closing brace, depending on the next character
- parserStateBracedValueClosingBrace
-
- // After a value. Next character should be a semicolon or whitespace.
- parserStateEndValue
- )
-
- var state = parserStateBeforeKey
-
- var key string
- var value string
-
- for i, c := range dsn {
- switch state {
- case parserStateBeforeKey:
- switch {
- case c == '=':
- return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
- case !unicode.IsSpace(c) && c != ';':
- state = parserStateKey
- key += string(c)
- }
-
- case parserStateKey:
- switch c {
- case '=':
- key = normalizeOdbcKey(key)
- if len(key) == 0 {
- return res, fmt.Errorf("Unexpected end of key at index %d.", i)
- }
-
- state = parserStateBeginValue
-
- case ';':
- // Key without value
- key = normalizeOdbcKey(key)
- if len(key) == 0 {
- return res, fmt.Errorf("Unexpected end of key at index %d.", i)
- }
-
- res[key] = value
- key = ""
- value = ""
- state = parserStateBeforeKey
-
- default:
- key += string(c)
- }
-
- case parserStateBeginValue:
- switch {
- case c == '{':
- state = parserStateBracedValue
- case c == ';':
- // Empty value
- res[key] = value
- key = ""
- state = parserStateBeforeKey
- case unicode.IsSpace(c):
- // Ignore whitespace
- default:
- state = parserStateBareValue
- value += string(c)
- }
-
- case parserStateBareValue:
- if c == ';' {
- res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
- key = ""
- value = ""
- state = parserStateBeforeKey
- } else {
- value += string(c)
- }
-
- case parserStateBracedValue:
- if c == '}' {
- state = parserStateBracedValueClosingBrace
- } else {
- value += string(c)
- }
-
- case parserStateBracedValueClosingBrace:
- if c == '}' {
- // Escaped closing brace
- value += string(c)
- state = parserStateBracedValue
- continue
- }
-
- // End of braced value
- res[key] = value
- key = ""
- value = ""
-
- // This character is the first character past the end,
- // so it needs to be parsed like the parserStateEndValue state.
- state = parserStateEndValue
- switch {
- case c == ';':
- state = parserStateBeforeKey
- case unicode.IsSpace(c):
- // Ignore whitespace
- default:
- return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
- }
-
- case parserStateEndValue:
- switch {
- case c == ';':
- state = parserStateBeforeKey
- case unicode.IsSpace(c):
- // Ignore whitespace
- default:
- return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
- }
- }
- }
-
- switch state {
- case parserStateBeforeKey: // Okay
- case parserStateKey: // Unfinished key. Treat as key without value.
- key = normalizeOdbcKey(key)
- if len(key) == 0 {
- return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
- }
- res[key] = value
- case parserStateBeginValue: // Empty value
- res[key] = value
- case parserStateBareValue:
- res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
- case parserStateBracedValue:
- return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
- case parserStateBracedValueClosingBrace: // End of braced value
- res[key] = value
- case parserStateEndValue: // Okay
- }
-
- return res, nil
-}
-
-// Normalizes the given string as an ODBC-format key
-func normalizeOdbcKey(s string) string {
- return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
-}
-
-// Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
-func splitConnectionStringURL(dsn string) (map[string]string, error) {
- res := map[string]string{}
-
- u, err := url.Parse(dsn)
- if err != nil {
- return res, err
- }
-
- if u.Scheme != "sqlserver" {
- return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
- }
-
- if u.User != nil {
- res["user id"] = u.User.Username()
- p, exists := u.User.Password()
- if exists {
- res["password"] = p
- }
- }
-
- host, port, err := net.SplitHostPort(u.Host)
- if err != nil {
- host = u.Host
- }
-
- if len(u.Path) > 0 {
- res["server"] = host + "\\" + u.Path[1:]
- } else {
- res["server"] = host
- }
-
- if len(port) > 0 {
- res["port"] = port
- }
-
- query := u.Query()
- for k, v := range query {
- if len(v) > 1 {
- return res, fmt.Errorf("key %s provided more than once", k)
- }
- res[strings.ToLower(k)] = v[0]
- }
-
- return res, nil
-}
-
-func parseConnectParams(dsn string) (connectParams, error) {
- var p connectParams
-
- var params map[string]string
- if strings.HasPrefix(dsn, "odbc:") {
- parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
- if err != nil {
- return p, err
- }
- params = parameters
- } else if strings.HasPrefix(dsn, "sqlserver://") {
- parameters, err := splitConnectionStringURL(dsn)
- if err != nil {
- return p, err
- }
- params = parameters
- } else {
- params = splitConnectionString(dsn)
- }
-
- strlog, ok := params["log"]
- if ok {
- var err error
- p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
- if err != nil {
- return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
- }
- }
- server := params["server"]
- parts := strings.SplitN(server, `\`, 2)
- p.host = parts[0]
- if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
- p.host = "localhost"
- }
- if len(parts) > 1 {
- p.instance = parts[1]
- }
- p.database = params["database"]
- p.user = params["user id"]
- p.password = params["password"]
-
- p.port = 1433
- strport, ok := params["port"]
- if ok {
- var err error
- p.port, err = strconv.ParseUint(strport, 10, 16)
- if err != nil {
- f := "Invalid tcp port '%v': %v"
- return p, fmt.Errorf(f, strport, err.Error())
- }
- }
-
- // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
- // Default packet size remains at 4096 bytes
- p.packetSize = 4096
- strpsize, ok := params["packet size"]
- if ok {
- var err error
- psize, err := strconv.ParseUint(strpsize, 0, 16)
- if err != nil {
- f := "Invalid packet size '%v': %v"
- return p, fmt.Errorf(f, strpsize, err.Error())
- }
-
- // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
- // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
- // a higher packet size, the server will respond with an ENVCHANGE request to
- // alter the packet size to 16383 bytes.
- p.packetSize = uint16(psize)
- if p.packetSize < 512 {
- p.packetSize = 512
- } else if p.packetSize > 32767 {
- p.packetSize = 32767
- }
- }
-
- // https://msdn.microsoft.com/en-us/library/dd341108.aspx
- //
- // Do not set a connection timeout. Use Context to manage such things.
- // Default to zero, but still allow it to be set.
- if strconntimeout, ok := params["connection timeout"]; ok {
- timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
- if err != nil {
- f := "Invalid connection timeout '%v': %v"
- return p, fmt.Errorf(f, strconntimeout, err.Error())
- }
- p.conn_timeout = time.Duration(timeout) * time.Second
- }
- p.dial_timeout = 15 * time.Second
- if strdialtimeout, ok := params["dial timeout"]; ok {
- timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
- if err != nil {
- f := "Invalid dial timeout '%v': %v"
- return p, fmt.Errorf(f, strdialtimeout, err.Error())
- }
- p.dial_timeout = time.Duration(timeout) * time.Second
- }
-
- // default keep alive should be 30 seconds according to spec:
- // https://msdn.microsoft.com/en-us/library/dd341108.aspx
- p.keepAlive = 30 * time.Second
- if keepAlive, ok := params["keepalive"]; ok {
- timeout, err := strconv.ParseUint(keepAlive, 10, 64)
- if err != nil {
- f := "Invalid keepAlive value '%s': %s"
- return p, fmt.Errorf(f, keepAlive, err.Error())
- }
- p.keepAlive = time.Duration(timeout) * time.Second
- }
- encrypt, ok := params["encrypt"]
- if ok {
- if strings.EqualFold(encrypt, "DISABLE") {
- p.disableEncryption = true
- } else {
- var err error
- p.encrypt, err = strconv.ParseBool(encrypt)
- if err != nil {
- f := "Invalid encrypt '%s': %s"
- return p, fmt.Errorf(f, encrypt, err.Error())
- }
- }
- } else {
- p.trustServerCertificate = true
- }
- trust, ok := params["trustservercertificate"]
- if ok {
- var err error
- p.trustServerCertificate, err = strconv.ParseBool(trust)
- if err != nil {
- f := "Invalid trust server certificate '%s': %s"
- return p, fmt.Errorf(f, trust, err.Error())
- }
- }
- p.certificate = params["certificate"]
- p.hostInCertificate, ok = params["hostnameincertificate"]
- if ok {
- p.hostInCertificateProvided = true
- } else {
- p.hostInCertificate = p.host
- p.hostInCertificateProvided = false
- }
-
- serverSPN, ok := params["serverspn"]
- if ok {
- p.serverSPN = serverSPN
- } else {
- p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
- }
-
- workstation, ok := params["workstation id"]
- if ok {
- p.workstation = workstation
- } else {
- workstation, err := os.Hostname()
- if err == nil {
- p.workstation = workstation
- }
- }
-
- appname, ok := params["app name"]
- if !ok {
- appname = "go-mssqldb"
- }
- p.appname = appname
-
- appintent, ok := params["applicationintent"]
- if ok {
- if appintent == "ReadOnly" {
- p.typeFlags |= fReadOnlyIntent
- }
- }
-
- failOverPartner, ok := params["failoverpartner"]
- if ok {
- p.failOverPartner = failOverPartner
- }
-
- failOverPort, ok := params["failoverport"]
- if ok {
- var err error
- p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
- if err != nil {
- f := "Invalid tcp port '%v': %v"
- return p, fmt.Errorf(f, failOverPort, err.Error())
- }
- }
-
- return p, nil
-}
-
type auth interface {
InitialBytes() ([]byte, error)
NextBytes([]byte) ([]byte, error)
@@ -1277,12 +816,12 @@ initiate_connection:
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
- outbuf.transport = conn
- toconn.buf = outbuf
- tlsConn := tls.Client(toconn, &config)
+ // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
+ handshakeConn := tlsHandshakeConn{buf: outbuf}
+ passthrough := passthroughConn{c: &handshakeConn}
+ tlsConn := tls.Client(&passthrough, &config)
err = tlsConn.Handshake()
-
- toconn.buf = nil
+ passthrough.c = toconn
outbuf.transport = tlsConn
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)