summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/tds.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/tds.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/tds.go1070
1 files changed, 1070 insertions, 0 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go
new file mode 100644
index 0000000000..fd42dba34a
--- /dev/null
+++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go
@@ -0,0 +1,1070 @@
+package mssql
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "os"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf16"
+ "unicode/utf8"
+)
+
+func parseInstances(msg []byte) map[string]map[string]string {
+ results := map[string]map[string]string{}
+ if len(msg) > 3 && msg[0] == 5 {
+ out_s := string(msg[3:])
+ tokens := strings.Split(out_s, ";")
+ instdict := map[string]string{}
+ got_name := false
+ var name string
+ for _, token := range tokens {
+ if got_name {
+ instdict[name] = token
+ got_name = false
+ } else {
+ name = token
+ if len(name) == 0 {
+ if len(instdict) == 0 {
+ break
+ }
+ results[strings.ToUpper(instdict["InstanceName"])] = instdict
+ instdict = map[string]string{}
+ continue
+ }
+ got_name = true
+ }
+ }
+ }
+ return results
+}
+
+func getInstances(address string) (map[string]map[string]string, error) {
+ conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+ conn.SetDeadline(time.Now().Add(5 * time.Second))
+ _, err = conn.Write([]byte{3})
+ if err != nil {
+ return nil, err
+ }
+ var resp = make([]byte, 16*1024-1)
+ read, err := conn.Read(resp)
+ if err != nil {
+ return nil, err
+ }
+ return parseInstances(resp[:read]), nil
+}
+
+// tds versions
+const (
+ verTDS70 = 0x70000000
+ verTDS71 = 0x71000000
+ verTDS71rev1 = 0x71000001
+ verTDS72 = 0x72090002
+ verTDS73A = 0x730A0003
+ verTDS73 = verTDS73A
+ verTDS73B = 0x730B0003
+ verTDS74 = 0x74000004
+)
+
+// packet types
+const (
+ packSQLBatch = 1
+ packRPCRequest = 3
+ packReply = 4
+ packCancel = 6
+ packBulkLoadBCP = 7
+ packTransMgrReq = 14
+ packNormal = 15
+ packLogin7 = 16
+ packSSPIMessage = 17
+ packPrelogin = 18
+)
+
+// prelogin fields
+// http://msdn.microsoft.com/en-us/library/dd357559.aspx
+const (
+ preloginVERSION = 0
+ preloginENCRYPTION = 1
+ preloginINSTOPT = 2
+ preloginTHREADID = 3
+ preloginMARS = 4
+ preloginTRACEID = 5
+ preloginTERMINATOR = 0xff
+)
+
+const (
+ encryptOff = 0 // Encryption is available but off.
+ encryptOn = 1 // Encryption is available and on.
+ encryptNotSup = 2 // Encryption is not available.
+ encryptReq = 3 // Encryption is required.
+)
+
+type tdsSession struct {
+ buf *tdsBuffer
+ loginAck loginAckStruct
+ database string
+ partner string
+ columns []columnStruct
+ tranid uint64
+ logFlags uint64
+ log *Logger
+ routedServer string
+ routedPort uint16
+}
+
+const (
+ logErrors = 1
+ logMessages = 2
+ logRows = 4
+ logSQL = 8
+ logParams = 16
+ logTransaction = 32
+)
+
+type columnStruct struct {
+ UserType uint32
+ Flags uint16
+ ColName string
+ ti typeInfo
+}
+
+type KeySlice []uint8
+
+func (p KeySlice) Len() int { return len(p) }
+func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] }
+func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+// http://msdn.microsoft.com/en-us/library/dd357559.aspx
+func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
+ var err error
+
+ w.BeginPacket(packPrelogin)
+ offset := uint16(5*len(fields) + 1)
+ keys := make(KeySlice, 0, len(fields))
+ for k, _ := range fields {
+ keys = append(keys, k)
+ }
+ sort.Sort(keys)
+ // writing header
+ for _, k := range keys {
+ err = w.WriteByte(k)
+ if err != nil {
+ return err
+ }
+ err = binary.Write(w, binary.BigEndian, offset)
+ if err != nil {
+ return err
+ }
+ v := fields[k]
+ size := uint16(len(v))
+ err = binary.Write(w, binary.BigEndian, size)
+ if err != nil {
+ return err
+ }
+ offset += size
+ }
+ err = w.WriteByte(preloginTERMINATOR)
+ if err != nil {
+ return err
+ }
+ // writing values
+ for _, k := range keys {
+ v := fields[k]
+ written, err := w.Write(v)
+ if err != nil {
+ return err
+ }
+ if written != len(v) {
+ return errors.New("Write method didn't write the whole value")
+ }
+ }
+ return w.FinishPacket()
+}
+
+func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
+ packet_type, err := r.BeginRead()
+ if err != nil {
+ return nil, err
+ }
+ struct_buf, err := ioutil.ReadAll(r)
+ if err != nil {
+ return nil, err
+ }
+ if packet_type != 4 {
+ return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
+ }
+ offset := 0
+ results := map[uint8][]byte{}
+ for true {
+ rec_type := struct_buf[offset]
+ if rec_type == preloginTERMINATOR {
+ break
+ }
+
+ rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
+ rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
+ value := struct_buf[rec_offset : rec_offset+rec_len]
+ results[rec_type] = value
+ offset += 5
+ }
+ return results, nil
+}
+
+// OptionFlags2
+// http://msdn.microsoft.com/en-us/library/dd304019.aspx
+const (
+ fLanguageFatal = 1
+ fODBC = 2
+ fTransBoundary = 4
+ fCacheConnect = 8
+ fIntSecurity = 0x80
+)
+
+// TypeFlags
+const (
+ // 4 bits for fSQLType
+ // 1 bit for fOLEDB
+ fReadOnlyIntent = 32
+)
+
+type login struct {
+ TDSVersion uint32
+ PacketSize uint32
+ ClientProgVer uint32
+ ClientPID uint32
+ ConnectionID uint32
+ OptionFlags1 uint8
+ OptionFlags2 uint8
+ TypeFlags uint8
+ OptionFlags3 uint8
+ ClientTimeZone int32
+ ClientLCID uint32
+ HostName string
+ UserName string
+ Password string
+ AppName string
+ ServerName string
+ CtlIntName string
+ Language string
+ Database string
+ ClientID [6]byte
+ SSPI []byte
+ AtchDBFile string
+ ChangePassword string
+}
+
+type loginHeader struct {
+ Length uint32
+ TDSVersion uint32
+ PacketSize uint32
+ ClientProgVer uint32
+ ClientPID uint32
+ ConnectionID uint32
+ OptionFlags1 uint8
+ OptionFlags2 uint8
+ TypeFlags uint8
+ OptionFlags3 uint8
+ ClientTimeZone int32
+ ClientLCID uint32
+ HostNameOffset uint16
+ HostNameLength uint16
+ UserNameOffset uint16
+ UserNameLength uint16
+ PasswordOffset uint16
+ PasswordLength uint16
+ AppNameOffset uint16
+ AppNameLength uint16
+ ServerNameOffset uint16
+ ServerNameLength uint16
+ ExtensionOffset uint16
+ ExtensionLenght uint16
+ CtlIntNameOffset uint16
+ CtlIntNameLength uint16
+ LanguageOffset uint16
+ LanguageLength uint16
+ DatabaseOffset uint16
+ DatabaseLength uint16
+ ClientID [6]byte
+ SSPIOffset uint16
+ SSPILength uint16
+ AtchDBFileOffset uint16
+ AtchDBFileLength uint16
+ ChangePasswordOffset uint16
+ ChangePasswordLength uint16
+ SSPILongLength uint32
+}
+
+// convert Go string to UTF-16 encoded []byte (littleEndian)
+// done manually rather than using bytes and binary packages
+// for performance reasons
+func str2ucs2(s string) []byte {
+ res := utf16.Encode([]rune(s))
+ ucs2 := make([]byte, 2*len(res))
+ for i := 0; i < len(res); i++ {
+ ucs2[2*i] = byte(res[i])
+ ucs2[2*i+1] = byte(res[i] >> 8)
+ }
+ return ucs2
+}
+
+func ucs22str(s []byte) (string, error) {
+ if len(s)%2 != 0 {
+ return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
+ }
+ buf := make([]uint16, len(s)/2)
+ for i := 0; i < len(s); i += 2 {
+ buf[i/2] = binary.LittleEndian.Uint16(s[i:])
+ }
+ return string(utf16.Decode(buf)), nil
+}
+
+func manglePassword(password string) []byte {
+ var ucs2password []byte = str2ucs2(password)
+ for i, ch := range ucs2password {
+ ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
+ }
+ return ucs2password
+}
+
+// http://msdn.microsoft.com/en-us/library/dd304019.aspx
+func sendLogin(w *tdsBuffer, login login) error {
+ w.BeginPacket(packLogin7)
+ hostname := str2ucs2(login.HostName)
+ username := str2ucs2(login.UserName)
+ password := manglePassword(login.Password)
+ appname := str2ucs2(login.AppName)
+ servername := str2ucs2(login.ServerName)
+ ctlintname := str2ucs2(login.CtlIntName)
+ language := str2ucs2(login.Language)
+ database := str2ucs2(login.Database)
+ atchdbfile := str2ucs2(login.AtchDBFile)
+ changepassword := str2ucs2(login.ChangePassword)
+ hdr := loginHeader{
+ TDSVersion: login.TDSVersion,
+ PacketSize: login.PacketSize,
+ ClientProgVer: login.ClientProgVer,
+ ClientPID: login.ClientPID,
+ ConnectionID: login.ConnectionID,
+ OptionFlags1: login.OptionFlags1,
+ OptionFlags2: login.OptionFlags2,
+ TypeFlags: login.TypeFlags,
+ OptionFlags3: login.OptionFlags3,
+ ClientTimeZone: login.ClientTimeZone,
+ ClientLCID: login.ClientLCID,
+ HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
+ UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
+ PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
+ AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
+ ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
+ CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
+ LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
+ DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
+ ClientID: login.ClientID,
+ SSPILength: uint16(len(login.SSPI)),
+ AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
+ ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
+ }
+ offset := uint16(binary.Size(hdr))
+ hdr.HostNameOffset = offset
+ offset += uint16(len(hostname))
+ hdr.UserNameOffset = offset
+ offset += uint16(len(username))
+ hdr.PasswordOffset = offset
+ offset += uint16(len(password))
+ hdr.AppNameOffset = offset
+ offset += uint16(len(appname))
+ hdr.ServerNameOffset = offset
+ offset += uint16(len(servername))
+ hdr.CtlIntNameOffset = offset
+ offset += uint16(len(ctlintname))
+ hdr.LanguageOffset = offset
+ offset += uint16(len(language))
+ hdr.DatabaseOffset = offset
+ offset += uint16(len(database))
+ hdr.SSPIOffset = offset
+ offset += uint16(len(login.SSPI))
+ hdr.AtchDBFileOffset = offset
+ offset += uint16(len(atchdbfile))
+ hdr.ChangePasswordOffset = offset
+ offset += uint16(len(changepassword))
+ hdr.Length = uint32(offset)
+ var err error
+ err = binary.Write(w, binary.LittleEndian, &hdr)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(hostname)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(username)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(password)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(appname)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(servername)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(ctlintname)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(language)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(database)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(login.SSPI)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(atchdbfile)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(changepassword)
+ if err != nil {
+ return err
+ }
+ return w.FinishPacket()
+}
+
+func readUcs2(r io.Reader, numchars int) (res string, err error) {
+ buf := make([]byte, numchars*2)
+ _, err = io.ReadFull(r, buf)
+ if err != nil {
+ return "", err
+ }
+ return ucs22str(buf)
+}
+
+func readUsVarChar(r io.Reader) (res string, err error) {
+ var numchars uint16
+ err = binary.Read(r, binary.LittleEndian, &numchars)
+ if err != nil {
+ return "", err
+ }
+ return readUcs2(r, int(numchars))
+}
+
+func writeUsVarChar(w io.Writer, s string) (err error) {
+ buf := str2ucs2(s)
+ var numchars int = len(buf) / 2
+ if numchars > 0xffff {
+ panic("invalid size for US_VARCHAR")
+ }
+ err = binary.Write(w, binary.LittleEndian, uint16(numchars))
+ if err != nil {
+ return
+ }
+ _, err = w.Write(buf)
+ return
+}
+
+func readBVarChar(r io.Reader) (res string, err error) {
+ var numchars uint8
+ err = binary.Read(r, binary.LittleEndian, &numchars)
+ if err != nil {
+ return "", err
+ }
+ return readUcs2(r, int(numchars))
+}
+
+func writeBVarChar(w io.Writer, s string) (err error) {
+ buf := str2ucs2(s)
+ var numchars int = len(buf) / 2
+ if numchars > 0xff {
+ panic("invalid size for B_VARCHAR")
+ }
+ err = binary.Write(w, binary.LittleEndian, uint8(numchars))
+ if err != nil {
+ return
+ }
+ _, err = w.Write(buf)
+ return
+}
+
+func readBVarByte(r io.Reader) (res []byte, err error) {
+ var length uint8
+ err = binary.Read(r, binary.LittleEndian, &length)
+ if err != nil {
+ return
+ }
+ res = make([]byte, length)
+ _, err = io.ReadFull(r, res)
+ return
+}
+
+func readUshort(r io.Reader) (res uint16, err error) {
+ err = binary.Read(r, binary.LittleEndian, &res)
+ return
+}
+
+func readByte(r io.Reader) (res byte, err error) {
+ var b [1]byte
+ _, err = r.Read(b[:])
+ res = b[0]
+ return
+}
+
+// Packet Data Stream Headers
+// http://msdn.microsoft.com/en-us/library/dd304953.aspx
+type headerStruct struct {
+ hdrtype uint16
+ data []byte
+}
+
+const (
+ dataStmHdrQueryNotif = 1 // query notifications
+ dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
+ dataStmHdrTraceActivity = 3
+)
+
+// Query Notifications Header
+// http://msdn.microsoft.com/en-us/library/dd304949.aspx
+type queryNotifHdr struct {
+ notifyId string
+ ssbDeployment string
+ notifyTimeout uint32
+}
+
+func (hdr queryNotifHdr) pack() (res []byte) {
+ notifyId := str2ucs2(hdr.notifyId)
+ ssbDeployment := str2ucs2(hdr.ssbDeployment)
+
+ res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
+ b := res
+
+ binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
+ b = b[2:]
+ copy(b, notifyId)
+ b = b[len(notifyId):]
+
+ binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
+ b = b[2:]
+ copy(b, ssbDeployment)
+ b = b[len(ssbDeployment):]
+
+ binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
+
+ return res
+}
+
+// MARS Transaction Descriptor Header
+// http://msdn.microsoft.com/en-us/library/dd340515.aspx
+type transDescrHdr struct {
+ transDescr uint64 // transaction descriptor returned from ENVCHANGE
+ outstandingReqCnt uint32 // outstanding request count
+}
+
+func (hdr transDescrHdr) pack() (res []byte) {
+ res = make([]byte, 8+4)
+ binary.LittleEndian.PutUint64(res, hdr.transDescr)
+ binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
+ return res
+}
+
+func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
+ // calculatint total length
+ var totallen uint32 = 4
+ for _, hdr := range headers {
+ totallen += 4 + 2 + uint32(len(hdr.data))
+ }
+ // writing
+ err = binary.Write(w, binary.LittleEndian, totallen)
+ if err != nil {
+ return err
+ }
+ for _, hdr := range headers {
+ var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
+ err = binary.Write(w, binary.LittleEndian, headerlen)
+ if err != nil {
+ return err
+ }
+ err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(hdr.data)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func sendSqlBatch72(buf *tdsBuffer,
+ sqltext string,
+ headers []headerStruct) (err error) {
+ buf.BeginPacket(packSQLBatch)
+
+ if err = writeAllHeaders(buf, headers); err != nil {
+ return
+ }
+
+ _, err = buf.Write(str2ucs2(sqltext))
+ if err != nil {
+ return
+ }
+ return buf.FinishPacket()
+}
+
+type connectParams struct {
+ logFlags uint64
+ port uint64
+ host string
+ instance string
+ database string
+ user string
+ password string
+ dial_timeout time.Duration
+ conn_timeout time.Duration
+ keepAlive time.Duration
+ encrypt bool
+ disableEncryption bool
+ trustServerCertificate bool
+ certificate string
+ hostInCertificate string
+ serverSPN string
+ workstation string
+ appname string
+ typeFlags uint8
+ failOverPartner string
+ failOverPort uint64
+}
+
+func splitConnectionString(dsn string) (res map[string]string) {
+ res = map[string]string{}
+ parts := strings.Split(dsn, ";")
+ for _, part := range parts {
+ if len(part) == 0 {
+ continue
+ }
+ lst := strings.SplitN(part, "=", 2)
+ name := strings.TrimSpace(strings.ToLower(lst[0]))
+ if len(name) == 0 {
+ continue
+ }
+ var value string = ""
+ if len(lst) > 1 {
+ value = strings.TrimSpace(lst[1])
+ }
+ res[name] = value
+ }
+ return res
+}
+
+func parseConnectParams(dsn string) (connectParams, error) {
+ params := splitConnectionString(dsn)
+ var p connectParams
+ strlog, ok := params["log"]
+ if ok {
+ var err error
+ p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
+ if err != nil {
+ return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
+ }
+ }
+ server := params["server"]
+ parts := strings.SplitN(server, "\\", 2)
+ p.host = parts[0]
+ if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
+ p.host = "localhost"
+ }
+ if len(parts) > 1 {
+ p.instance = parts[1]
+ }
+ p.database = params["database"]
+ p.user = params["user id"]
+ p.password = params["password"]
+
+ p.port = 1433
+ strport, ok := params["port"]
+ if ok {
+ var err error
+ p.port, err = strconv.ParseUint(strport, 0, 16)
+ if err != nil {
+ f := "Invalid tcp port '%v': %v"
+ return p, fmt.Errorf(f, strport, err.Error())
+ }
+ }
+
+ p.dial_timeout = 5 * time.Second
+ p.conn_timeout = 30 * time.Second
+ strconntimeout, ok := params["connection timeout"]
+ if ok {
+ timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
+ if err != nil {
+ f := "Invalid connection timeout '%v': %v"
+ return p, fmt.Errorf(f, strconntimeout, err.Error())
+ }
+ p.conn_timeout = time.Duration(timeout) * time.Second
+ }
+ strdialtimeout, ok := params["dial timeout"]
+ if ok {
+ timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
+ if err != nil {
+ f := "Invalid dial timeout '%v': %v"
+ return p, fmt.Errorf(f, strdialtimeout, err.Error())
+ }
+ p.dial_timeout = time.Duration(timeout) * time.Second
+ }
+ keepAlive, ok := params["keepalive"]
+ if ok {
+ timeout, err := strconv.ParseUint(keepAlive, 0, 16)
+ if err != nil {
+ f := "Invalid keepAlive value '%s': %s"
+ return p, fmt.Errorf(f, keepAlive, err.Error())
+ }
+ p.keepAlive = time.Duration(timeout) * time.Second
+ }
+ encrypt, ok := params["encrypt"]
+ if ok {
+ if strings.ToUpper(encrypt) == "DISABLE" {
+ p.disableEncryption = true
+ } else {
+ var err error
+ p.encrypt, err = strconv.ParseBool(encrypt)
+ if err != nil {
+ f := "Invalid encrypt '%s': %s"
+ return p, fmt.Errorf(f, encrypt, err.Error())
+ }
+ }
+ } else {
+ p.trustServerCertificate = true
+ }
+ trust, ok := params["trustservercertificate"]
+ if ok {
+ var err error
+ p.trustServerCertificate, err = strconv.ParseBool(trust)
+ if err != nil {
+ f := "Invalid trust server certificate '%s': %s"
+ return p, fmt.Errorf(f, trust, err.Error())
+ }
+ }
+ p.certificate = params["certificate"]
+ p.hostInCertificate, ok = params["hostnameincertificate"]
+ if !ok {
+ p.hostInCertificate = p.host
+ }
+
+ serverSPN, ok := params["serverspn"]
+ if ok {
+ p.serverSPN = serverSPN
+ } else {
+ p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
+ }
+
+ workstation, ok := params["workstation id"]
+ if ok {
+ p.workstation = workstation
+ } else {
+ workstation, err := os.Hostname()
+ if err == nil {
+ p.workstation = workstation
+ }
+ }
+
+ appname, ok := params["app name"]
+ if !ok {
+ appname = "go-mssqldb"
+ }
+ p.appname = appname
+
+ appintent, ok := params["applicationintent"]
+ if ok {
+ if appintent == "ReadOnly" {
+ p.typeFlags |= fReadOnlyIntent
+ }
+ }
+
+ failOverPartner, ok := params["failoverpartner"]
+ if ok {
+ p.failOverPartner = failOverPartner
+ }
+
+ failOverPort, ok := params["failoverport"]
+ if ok {
+ var err error
+ p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
+ if err != nil {
+ f := "Invalid tcp port '%v': %v"
+ return p, fmt.Errorf(f, failOverPort, err.Error())
+ }
+ }
+
+ return p, nil
+}
+
+type Auth interface {
+ InitialBytes() ([]byte, error)
+ NextBytes([]byte) ([]byte, error)
+ Free()
+}
+
+// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
+// list of IP addresses. So if there is more than one, try them all and
+// use the first one that allows a connection.
+func dialConnection(p connectParams) (conn net.Conn, err error) {
+ var ips []net.IP
+ ips, err = net.LookupIP(p.host)
+ if err != nil {
+ ip := net.ParseIP(p.host)
+ if ip == nil {
+ return nil, err
+ }
+ ips = []net.IP{ip}
+ }
+ if len(ips) == 1 {
+ d := createDialer(p)
+ addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
+ conn, err = d.Dial("tcp", addr)
+
+ } else {
+ //Try Dials in parallel to avoid waiting for timeouts.
+ connChan := make(chan net.Conn, len(ips))
+ errChan := make(chan error, len(ips))
+ portStr := strconv.Itoa(int(p.port))
+ for _, ip := range ips {
+ go func(ip net.IP) {
+ d := createDialer(p)
+ addr := net.JoinHostPort(ip.String(), portStr)
+ conn, err := d.Dial("tcp", addr)
+ if err == nil {
+ connChan <- conn
+ } else {
+ errChan <- err
+ }
+ }(ip)
+ }
+ // Wait for either the *first* successful connection, or all the errors
+ wait_loop:
+ for i, _ := range ips {
+ select {
+ case conn = <-connChan:
+ // Got a connection to use, close any others
+ go func(n int) {
+ for i := 0; i < n; i++ {
+ select {
+ case conn := <-connChan:
+ conn.Close()
+ case <-errChan:
+ }
+ }
+ }(len(ips) - i - 1)
+ // Remove any earlier errors we may have collected
+ err = nil
+ break wait_loop
+ case err = <-errChan:
+ }
+ }
+ }
+ // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
+ if conn == nil {
+ f := "Unable to open tcp connection with host '%v:%v': %v"
+ return nil, fmt.Errorf(f, p.host, p.port, err.Error())
+ }
+
+ return conn, err
+}
+
+func connect(p connectParams) (res *tdsSession, err error) {
+ res = nil
+ // if instance is specified use instance resolution service
+ if p.instance != "" {
+ p.instance = strings.ToUpper(p.instance)
+ instances, err := getInstances(p.host)
+ if err != nil {
+ f := "Unable to get instances from Sql Server Browser on host %v: %v"
+ return nil, fmt.Errorf(f, p.host, err.Error())
+ }
+ strport, ok := instances[p.instance]["tcp"]
+ if !ok {
+ f := "No instance matching '%v' returned from host '%v'"
+ return nil, fmt.Errorf(f, p.instance, p.host)
+ }
+ p.port, err = strconv.ParseUint(strport, 0, 16)
+ if err != nil {
+ f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
+ return nil, fmt.Errorf(f, strport, err.Error())
+ }
+ }
+
+initiate_connection:
+ conn, err := dialConnection(p)
+ if err != nil {
+ return nil, err
+ }
+
+ toconn := NewTimeoutConn(conn, p.conn_timeout)
+
+ outbuf := newTdsBuffer(4096, toconn)
+ sess := tdsSession{
+ buf: outbuf,
+ logFlags: p.logFlags,
+ }
+
+ instance_buf := []byte(p.instance)
+ instance_buf = append(instance_buf, 0) // zero terminate instance name
+ var encrypt byte
+ if p.disableEncryption {
+ encrypt = encryptNotSup
+ } else if p.encrypt {
+ encrypt = encryptOn
+ } else {
+ encrypt = encryptOff
+ }
+ fields := map[uint8][]byte{
+ preloginVERSION: {0, 0, 0, 0, 0, 0},
+ preloginENCRYPTION: {encrypt},
+ preloginINSTOPT: instance_buf,
+ preloginTHREADID: {0, 0, 0, 0},
+ preloginMARS: {0}, // MARS disabled
+ }
+
+ err = writePrelogin(outbuf, fields)
+ if err != nil {
+ return nil, err
+ }
+
+ fields, err = readPrelogin(outbuf)
+ if err != nil {
+ return nil, err
+ }
+
+ encryptBytes, ok := fields[preloginENCRYPTION]
+ if !ok {
+ return nil, fmt.Errorf("Encrypt negotiation failed")
+ }
+ encrypt = encryptBytes[0]
+ if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
+ return nil, fmt.Errorf("Server does not support encryption")
+ }
+
+ if encrypt != encryptNotSup {
+ var config tls.Config
+ if p.certificate != "" {
+ pem, err := ioutil.ReadFile(p.certificate)
+ if err != nil {
+ f := "Cannot read certificate '%s': %s"
+ return nil, fmt.Errorf(f, p.certificate, err.Error())
+ }
+ certs := x509.NewCertPool()
+ certs.AppendCertsFromPEM(pem)
+ config.RootCAs = certs
+ }
+ if p.trustServerCertificate {
+ config.InsecureSkipVerify = true
+ }
+ config.ServerName = p.hostInCertificate
+ outbuf.transport = conn
+ toconn.buf = outbuf
+ tlsConn := tls.Client(toconn, &config)
+ err = tlsConn.Handshake()
+ toconn.buf = nil
+ outbuf.transport = tlsConn
+ if err != nil {
+ f := "TLS Handshake failed: %s"
+ return nil, fmt.Errorf(f, err.Error())
+ }
+ if encrypt == encryptOff {
+ outbuf.afterFirst = func() {
+ outbuf.transport = toconn
+ }
+ }
+ }
+
+ login := login{
+ TDSVersion: verTDS74,
+ PacketSize: uint32(len(outbuf.buf)),
+ Database: p.database,
+ OptionFlags2: fODBC, // to get unlimited TEXTSIZE
+ HostName: p.workstation,
+ ServerName: p.host,
+ AppName: p.appname,
+ TypeFlags: p.typeFlags,
+ }
+ auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
+ if auth_ok {
+ login.SSPI, err = auth.InitialBytes()
+ if err != nil {
+ return nil, err
+ }
+ login.OptionFlags2 |= fIntSecurity
+ defer auth.Free()
+ } else {
+ login.UserName = p.user
+ login.Password = p.password
+ }
+ err = sendLogin(outbuf, login)
+ if err != nil {
+ return nil, err
+ }
+
+ // processing login response
+ var sspi_msg []byte
+continue_login:
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(&sess, tokchan)
+ success := false
+ for tok := range tokchan {
+ switch token := tok.(type) {
+ case sspiMsg:
+ sspi_msg, err = auth.NextBytes(token)
+ if err != nil {
+ return nil, err
+ }
+ case loginAckStruct:
+ success = true
+ sess.loginAck = token
+ case error:
+ return nil, fmt.Errorf("Login error: %s", token.Error())
+ }
+ }
+ if sspi_msg != nil {
+ outbuf.BeginPacket(packSSPIMessage)
+ _, err = outbuf.Write(sspi_msg)
+ if err != nil {
+ return nil, err
+ }
+ err = outbuf.FinishPacket()
+ if err != nil {
+ return nil, err
+ }
+ sspi_msg = nil
+ goto continue_login
+ }
+ if !success {
+ return nil, fmt.Errorf("Login failed")
+ }
+ if sess.routedServer != "" {
+ toconn.Close()
+ p.host = sess.routedServer
+ p.port = uint64(sess.routedPort)
+ goto initiate_connection
+ }
+ return &sess, nil
+}