You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tds.go 25KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037
  1. package mssql
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "unicode/utf16"
  16. "unicode/utf8"
  17. )
  // parseInstances decodes a SQL Server Browser reply into a map from the
// upper-cased instance name to that instance's key/value properties
// (for example "InstanceName" or "tcp").
//
// msg must start with the byte 0x05 (presumably SVR_RESP in [MC-SQLR]); the
// two bytes after it are skipped. The rest is a ";"-separated key;value list.
// An empty token in key position closes the current instance record. An empty
// token when no record is open, or running out of tokens, ends parsing, so a
// trailing record that was never closed is dropped. Any other input yields an
// empty map.
func parseInstances(msg []byte) map[string]map[string]string {
	results := map[string]map[string]string{}
	if len(msg) <= 3 || msg[0] != 5 {
		return results
	}
	current := map[string]string{}
	var key string
	expectValue := false
	for _, tok := range strings.Split(string(msg[3:]), ";") {
		if expectValue {
			current[key] = tok
			expectValue = false
			continue
		}
		if tok != "" {
			key = tok
			expectValue = true
			continue
		}
		// Empty token where a key was expected: end of a record, or of the data.
		if len(current) == 0 {
			break
		}
		results[strings.ToUpper(current["InstanceName"])] = current
		current = map[string]string{}
	}
	return results
}
  46. func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
  47. conn, err := d.DialContext(ctx, "udp", address+":1434")
  48. if err != nil {
  49. return nil, err
  50. }
  51. defer conn.Close()
  52. deadline, _ := ctx.Deadline()
  53. conn.SetDeadline(deadline)
  54. _, err = conn.Write([]byte{3})
  55. if err != nil {
  56. return nil, err
  57. }
  58. var resp = make([]byte, 16*1024-1)
  59. read, err := conn.Read(resp)
  60. if err != nil {
  61. return nil, err
  62. }
  63. return parseInstances(resp[:read]), nil
  64. }
  // TDS protocol versions, as sent in the LOGIN7 TDSVersion field. Going by the
// names, the high byte holds the major/minor version (0x74 -> TDS 7.4).
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002

	verTDS73A = 0x730A0003
	verTDS73B = 0x730B0003
	// verTDS73 is an alias for revision A of TDS 7.3.
	verTDS73 = verTDS73A

	verTDS74 = 0x74000004
)
  76. // packet types
  77. // https://msdn.microsoft.com/en-us/library/dd304214.aspx
  78. const (
  79. packSQLBatch packetType = 1
  80. packRPCRequest = 3
  81. packReply = 4
  82. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  83. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  84. packAttention = 6
  85. packBulkLoadBCP = 7
  86. packTransMgrReq = 14
  87. packNormal = 15
  88. packLogin7 = 16
  89. packSSPIMessage = 17
  90. packPrelogin = 18
  91. )
  // PRELOGIN option tokens (the first byte of each option-table entry).
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
const (
	preloginVERSION         = 0x00
	preloginENCRYPTION      = 0x01
	preloginINSTOPT         = 0x02
	preloginTHREADID        = 0x03
	preloginMARS            = 0x04
	preloginTRACEID         = 0x05
	preloginFEDAUTHREQUIRED = 0x06
	preloginNONCEOPT        = 0x07
	preloginTERMINATOR      = 0xFF // marks the end of the option table
)
  // Values of the PRELOGIN ENCRYPTION option. The client proposes one of these
// and the server answers with its own choice in the PRELOGIN response.
const (
	encryptOff    = 0x00 // Encryption is available but off.
	encryptOn     = 0x01 // Encryption is available and on.
	encryptNotSup = 0x02 // Encryption is not available.
	encryptReq    = 0x03 // Encryption is required.
)
  111. type tdsSession struct {
  112. buf *tdsBuffer
  113. loginAck loginAckStruct
  114. database string
  115. partner string
  116. columns []columnStruct
  117. tranid uint64
  118. logFlags uint64
  119. log optionalLogger
  120. routedServer string
  121. routedPort uint16
  122. }
  // Bit flags combined into tdsSession.logFlags (copied from the connection
// parameters' logFlags), one bit per logging category.
const (
	logErrors      = 1 << iota // 1
	logMessages                // 2
	logRows                    // 4
	logSQL                     // 8
	logParams                  // 16
	logTransaction             // 32
	logDebug                   // 64
)
  // columnStruct describes one column of a result set; tdsSession.columns holds
  // the current set as a slice of these.
  132. type columnStruct struct {
  133. UserType uint32 // user type ID from the column metadata — TODO confirm
  134. Flags uint16 // column flags bitfield — TODO confirm meaning of bits
  135. ColName string // column name, presumably as reported by the server
  136. ti typeInfo // type information for the column's values (typeInfo is declared elsewhere)
  137. }
  // keySlice implements sort.Interface over prelogin option tokens so that
// writePrelogin can emit options in ascending token order.
type keySlice []uint8

func (ks keySlice) Len() int { return len(ks) }

func (ks keySlice) Less(a, b int) bool { return ks[b] > ks[a] }

func (ks keySlice) Swap(a, b int) {
	tmp := ks[a]
	ks[a] = ks[b]
	ks[b] = tmp
}
  142. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  143. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  144. var err error
  145. w.BeginPacket(packPrelogin, false)
  146. offset := uint16(5*len(fields) + 1)
  147. keys := make(keySlice, 0, len(fields))
  148. for k, _ := range fields {
  149. keys = append(keys, k)
  150. }
  151. sort.Sort(keys)
  152. // writing header
  153. for _, k := range keys {
  154. err = w.WriteByte(k)
  155. if err != nil {
  156. return err
  157. }
  158. err = binary.Write(w, binary.BigEndian, offset)
  159. if err != nil {
  160. return err
  161. }
  162. v := fields[k]
  163. size := uint16(len(v))
  164. err = binary.Write(w, binary.BigEndian, size)
  165. if err != nil {
  166. return err
  167. }
  168. offset += size
  169. }
  170. err = w.WriteByte(preloginTERMINATOR)
  171. if err != nil {
  172. return err
  173. }
  174. // writing values
  175. for _, k := range keys {
  176. v := fields[k]
  177. written, err := w.Write(v)
  178. if err != nil {
  179. return err
  180. }
  181. if written != len(v) {
  182. return errors.New("Write method didn't write the whole value")
  183. }
  184. }
  185. return w.FinishPacket()
  186. }
  187. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  188. packet_type, err := r.BeginRead()
  189. if err != nil {
  190. return nil, err
  191. }
  192. struct_buf, err := ioutil.ReadAll(r)
  193. if err != nil {
  194. return nil, err
  195. }
  196. if packet_type != 4 {
  197. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  198. }
  199. offset := 0
  200. results := map[uint8][]byte{}
  201. for true {
  202. rec_type := struct_buf[offset]
  203. if rec_type == preloginTERMINATOR {
  204. break
  205. }
  206. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  207. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  208. value := struct_buf[rec_offset : rec_offset+rec_len]
  209. results[rec_type] = value
  210. offset += 5
  211. }
  212. return results, nil
  213. }
  // Bits of the LOGIN7 OptionFlags2 byte.
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
	fLanguageFatal = 0x01
	fODBC          = 0x02
	fTransBoundary = 0x04
	fCacheConnect  = 0x08
	fIntSecurity   = 0x80
)

// Bits of the LOGIN7 TypeFlags byte. The low 4 bits hold fSQLType and the next
// bit is fOLEDB; neither has a named constant here.
const fReadOnlyIntent = 0x20

// Bits of the LOGIN7 OptionFlags3 byte.
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
const fExtension = 0x10
  234. type login struct {
  235. TDSVersion uint32
  236. PacketSize uint32
  237. ClientProgVer uint32
  238. ClientPID uint32
  239. ConnectionID uint32
  240. OptionFlags1 uint8
  241. OptionFlags2 uint8
  242. TypeFlags uint8
  243. OptionFlags3 uint8
  244. ClientTimeZone int32
  245. ClientLCID uint32
  246. HostName string
  247. UserName string
  248. Password string
  249. AppName string
  250. ServerName string
  251. CtlIntName string
  252. Language string
  253. Database string
  254. ClientID [6]byte
  255. SSPI []byte
  256. AtchDBFile string
  257. ChangePassword string
  258. FeatureExt featureExts
  259. }
  // featureExts collects the feature extensions of a LOGIN7 packet, at most one
// per feature ID. The zero value is an empty, ready-to-use set.
type featureExts struct {
	features map[byte]featureExt
}

// featureExt is a single LOGIN7 feature extension: its FeatureId byte and the
// FeatureData bytes that follow the length field on the wire.
type featureExt interface {
	featureID() byte
	toBytes() []byte
}

// Add registers f in the set. A nil f is ignored. Adding a feature whose ID is
// already present returns an error and leaves the set unchanged.
func (e *featureExts) Add(f featureExt) error {
	if f == nil {
		return nil
	}
	id := f.featureID()
	if _, exists := e.features[id]; exists {
		return fmt.Errorf("Login error: Feature with ID '%v' is already present in FeatureExt block.", id)
	}
	if e.features == nil {
		e.features = make(map[byte]featureExt)
	}
	e.features[id] = f
	return nil
}

// toBytes serializes the set as a FeatureExt block. Each feature is written as
// its FeatureId byte, a little-endian DWORD FeatureDataLen and its
// FeatureData. A 0xFF terminator ends the block. An empty set yields nil.
//
// Features are emitted in ascending ID order. Go randomizes map iteration, so
// without sorting the packet bytes would differ from run to run.
func (e featureExts) toBytes() []byte {
	if len(e.features) == 0 {
		return nil
	}
	ids := make([]byte, 0, len(e.features))
	for id := range e.features {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	var d []byte
	for _, id := range ids {
		featureData := e.features[id].toBytes()
		hdr := make([]byte, 5)
		hdr[0] = id                                                      // FeatureId BYTE
		binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD
		d = append(d, hdr...)
		d = append(d, featureData...) // FeatureData *BYTE
	}
	return append(d, 0xff) // Terminator
}
  300. type featureExtFedAuthSTS struct {
  301. FedAuthEcho bool
  302. FedAuthToken string
  303. Nonce []byte
  304. }
  305. func (e *featureExtFedAuthSTS) featureID() byte {
  306. return 0x02
  307. }
  308. func (e *featureExtFedAuthSTS) toBytes() []byte {
  309. if e == nil {
  310. return nil
  311. }
  312. options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
  313. if e.FedAuthEcho {
  314. options |= 1 // fFedAuthEcho
  315. }
  316. d := make([]byte, 5)
  317. d[0] = options
  318. // looks like string in
  319. // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
  320. tokenBytes := str2ucs2(e.FedAuthToken)
  321. binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
  322. d = append(d, tokenBytes...)
  323. if len(e.Nonce) == 32 {
  324. d = append(d, e.Nonce...)
  325. }
  326. return d
  327. }
  // loginHeader is the fixed-length part of a LOGIN7 packet. sendLogin writes it
// with binary.Write in little-endian order. Field order and sizes therefore ARE
// the wire layout (94 bytes) and must not change. Each ...Offset/...Length pair
// locates one variable-length item in the data section that follows the
// header.
type loginHeader struct {
	Length, TDSVersion, PacketSize, ClientProgVer, ClientPID, ConnectionID uint32

	OptionFlags1, OptionFlags2, TypeFlags, OptionFlags3 uint8

	ClientTimeZone int32
	ClientLCID     uint32

	HostNameOffset, HostNameLength     uint16
	UserNameOffset, UserNameLength     uint16
	PasswordOffset, PasswordLength     uint16
	AppNameOffset, AppNameLength       uint16
	ServerNameOffset, ServerNameLength uint16
	ExtensionOffset, ExtensionLength   uint16
	CtlIntNameOffset, CtlIntNameLength uint16
	LanguageOffset, LanguageLength     uint16
	DatabaseOffset, DatabaseLength     uint16

	ClientID [6]byte

	SSPIOffset, SSPILength                     uint16
	AtchDBFileOffset, AtchDBFileLength         uint16
	ChangePasswordOffset, ChangePasswordLength uint16

	SSPILongLength uint32
}
  // str2ucs2 encodes s as UTF-16 little-endian bytes. Characters outside the
// Basic Multilingual Plane become surrogate pairs, and invalid UTF-8 bytes
// become U+FFFD. The bytes are packed by hand rather than through the
// bytes/binary packages to avoid their overhead.
func str2ucs2(s string) []byte {
	units := utf16.Encode([]rune(s))
	out := make([]byte, 0, 2*len(units))
	for _, u := range units {
		out = append(out, byte(u), byte(u>>8))
	}
	return out
}

// ucs22str decodes UTF-16 little-endian bytes into a Go string. Unpaired
// surrogates decode to U+FFFD. An odd byte count is rejected.
func ucs22str(s []byte) (string, error) {
	if n := len(s); n%2 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", n)
	}
	units := make([]uint16, 0, len(s)/2)
	for rest := s; len(rest) > 0; rest = rest[2:] {
		units = append(units, binary.LittleEndian.Uint16(rest))
	}
	return string(utf16.Decode(units)), nil
}

// manglePassword returns password encoded as UTF-16LE with the LOGIN7
// password obfuscation applied to every byte: swap the high and low nibbles,
// then XOR with 0xA5.
func manglePassword(password string) []byte {
	encoded := str2ucs2(password)
	for i := range encoded {
		b := encoded[i]
		encoded[i] = (b<<4 | b>>4) ^ 0xA5
	}
	return encoded
}
  397. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  398. func sendLogin(w *tdsBuffer, login login) error {
  399. w.BeginPacket(packLogin7, false)
  400. hostname := str2ucs2(login.HostName)
  401. username := str2ucs2(login.UserName)
  402. password := manglePassword(login.Password)
  403. appname := str2ucs2(login.AppName)
  404. servername := str2ucs2(login.ServerName)
  405. ctlintname := str2ucs2(login.CtlIntName)
  406. language := str2ucs2(login.Language)
  407. database := str2ucs2(login.Database)
  408. atchdbfile := str2ucs2(login.AtchDBFile)
  409. changepassword := str2ucs2(login.ChangePassword)
  410. featureExt := login.FeatureExt.toBytes()
  411. hdr := loginHeader{
  412. TDSVersion: login.TDSVersion,
  413. PacketSize: login.PacketSize,
  414. ClientProgVer: login.ClientProgVer,
  415. ClientPID: login.ClientPID,
  416. ConnectionID: login.ConnectionID,
  417. OptionFlags1: login.OptionFlags1,
  418. OptionFlags2: login.OptionFlags2,
  419. TypeFlags: login.TypeFlags,
  420. OptionFlags3: login.OptionFlags3,
  421. ClientTimeZone: login.ClientTimeZone,
  422. ClientLCID: login.ClientLCID,
  423. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  424. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  425. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  426. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  427. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  428. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  429. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  430. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  431. ClientID: login.ClientID,
  432. SSPILength: uint16(len(login.SSPI)),
  433. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  434. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  435. }
  436. offset := uint16(binary.Size(hdr))
  437. hdr.HostNameOffset = offset
  438. offset += uint16(len(hostname))
  439. hdr.UserNameOffset = offset
  440. offset += uint16(len(username))
  441. hdr.PasswordOffset = offset
  442. offset += uint16(len(password))
  443. hdr.AppNameOffset = offset
  444. offset += uint16(len(appname))
  445. hdr.ServerNameOffset = offset
  446. offset += uint16(len(servername))
  447. hdr.CtlIntNameOffset = offset
  448. offset += uint16(len(ctlintname))
  449. hdr.LanguageOffset = offset
  450. offset += uint16(len(language))
  451. hdr.DatabaseOffset = offset
  452. offset += uint16(len(database))
  453. hdr.SSPIOffset = offset
  454. offset += uint16(len(login.SSPI))
  455. hdr.AtchDBFileOffset = offset
  456. offset += uint16(len(atchdbfile))
  457. hdr.ChangePasswordOffset = offset
  458. offset += uint16(len(changepassword))
  459. featureExtOffset := uint32(0)
  460. featureExtLen := len(featureExt)
  461. if featureExtLen > 0 {
  462. hdr.OptionFlags3 |= fExtension
  463. hdr.ExtensionOffset = offset
  464. hdr.ExtensionLength = 4
  465. offset += hdr.ExtensionLength // DWORD
  466. featureExtOffset = uint32(offset)
  467. }
  468. hdr.Length = uint32(offset) + uint32(featureExtLen)
  469. var err error
  470. err = binary.Write(w, binary.LittleEndian, &hdr)
  471. if err != nil {
  472. return err
  473. }
  474. _, err = w.Write(hostname)
  475. if err != nil {
  476. return err
  477. }
  478. _, err = w.Write(username)
  479. if err != nil {
  480. return err
  481. }
  482. _, err = w.Write(password)
  483. if err != nil {
  484. return err
  485. }
  486. _, err = w.Write(appname)
  487. if err != nil {
  488. return err
  489. }
  490. _, err = w.Write(servername)
  491. if err != nil {
  492. return err
  493. }
  494. _, err = w.Write(ctlintname)
  495. if err != nil {
  496. return err
  497. }
  498. _, err = w.Write(language)
  499. if err != nil {
  500. return err
  501. }
  502. _, err = w.Write(database)
  503. if err != nil {
  504. return err
  505. }
  506. _, err = w.Write(login.SSPI)
  507. if err != nil {
  508. return err
  509. }
  510. _, err = w.Write(atchdbfile)
  511. if err != nil {
  512. return err
  513. }
  514. _, err = w.Write(changepassword)
  515. if err != nil {
  516. return err
  517. }
  518. if featureExtOffset > 0 {
  519. err = binary.Write(w, binary.LittleEndian, featureExtOffset)
  520. if err != nil {
  521. return err
  522. }
  523. _, err = w.Write(featureExt)
  524. if err != nil {
  525. return err
  526. }
  527. }
  528. return w.FinishPacket()
  529. }
  530. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  531. buf := make([]byte, numchars*2)
  532. _, err = io.ReadFull(r, buf)
  533. if err != nil {
  534. return "", err
  535. }
  536. return ucs22str(buf)
  537. }
  538. func readUsVarChar(r io.Reader) (res string, err error) {
  539. numchars, err := readUshort(r)
  540. if err != nil {
  541. return
  542. }
  543. return readUcs2(r, int(numchars))
  544. }
  545. func writeUsVarChar(w io.Writer, s string) (err error) {
  546. buf := str2ucs2(s)
  547. var numchars int = len(buf) / 2
  548. if numchars > 0xffff {
  549. panic("invalid size for US_VARCHAR")
  550. }
  551. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  552. if err != nil {
  553. return
  554. }
  555. _, err = w.Write(buf)
  556. return
  557. }
  558. func readBVarChar(r io.Reader) (res string, err error) {
  559. numchars, err := readByte(r)
  560. if err != nil {
  561. return "", err
  562. }
  563. // A zero length could be returned, return an empty string
  564. if numchars == 0 {
  565. return "", nil
  566. }
  567. return readUcs2(r, int(numchars))
  568. }
  569. func writeBVarChar(w io.Writer, s string) (err error) {
  570. buf := str2ucs2(s)
  571. var numchars int = len(buf) / 2
  572. if numchars > 0xff {
  573. panic("invalid size for B_VARCHAR")
  574. }
  575. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  576. if err != nil {
  577. return
  578. }
  579. _, err = w.Write(buf)
  580. return
  581. }
  // readBVarByte reads a B_VARBYTE: a one-byte length followed by that many raw
// bytes. On a short read, the partially filled slice is returned together with
// the io.ReadFull error.
func readBVarByte(r io.Reader) (res []byte, err error) {
	length, err := readByte(r)
	if err != nil {
		return nil, err
	}
	res = make([]byte, length)
	_, err = io.ReadFull(r, res)
	return res, err
}

// readUshort reads a little-endian USHORT from r. It decodes directly instead
// of going through reflection-based binary.Read. Errors are those of
// io.ReadFull: io.EOF if nothing was read, io.ErrUnexpectedEOF if one byte was.
func readUshort(r io.Reader) (res uint16, err error) {
	var b [2]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint16(b[:]), nil
}

// readByte reads exactly one byte from r; a clean end of stream yields io.EOF.
//
// It uses io.ReadFull rather than a single r.Read. The io.Reader contract lets
// Read return (0, nil), which would be misreported as a 0 byte. It also lets
// Read return the last byte together with io.EOF, which would discard a valid
// byte as an error.
func readByte(r io.Reader) (res byte, err error) {
	var b [1]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
  // headerStruct is one entry of an ALL_HEADERS block as written by
// writeAllHeaders: a header type and its already-encoded data.
// Packet Data Stream Headers: http://msdn.microsoft.com/en-us/library/dd304953.aspx
type headerStruct struct {
	hdrtype uint16
	data    []byte
}

// Header type values for headerStruct.hdrtype.
const (
	dataStmHdrQueryNotif    = 0x0001 // query notifications
	dataStmHdrTransDescr    = 0x0002 // MARS transaction descriptor (required)
	dataStmHdrTraceActivity = 0x0003
)
  612. // Query Notifications Header
  613. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  614. type queryNotifHdr struct {
  615. notifyId string
  616. ssbDeployment string
  617. notifyTimeout uint32
  618. }
  619. func (hdr queryNotifHdr) pack() (res []byte) {
  620. notifyId := str2ucs2(hdr.notifyId)
  621. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  622. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  623. b := res
  624. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  625. b = b[2:]
  626. copy(b, notifyId)
  627. b = b[len(notifyId):]
  628. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  629. b = b[2:]
  630. copy(b, ssbDeployment)
  631. b = b[len(ssbDeployment):]
  632. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  633. return res
  634. }
  // transDescrHdr is the MARS Transaction Descriptor stream header.
// http://msdn.microsoft.com/en-us/library/dd340515.aspx
type transDescrHdr struct {
	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
	outstandingReqCnt uint32 // outstanding request count
}

// pack encodes the header data as a 12-byte slice: the transaction
// descriptor as a little-endian QWORD, then the outstanding request count as a
// little-endian DWORD.
func (hdr transDescrHdr) pack() (res []byte) {
	var out [12]byte
	binary.LittleEndian.PutUint64(out[0:8], hdr.transDescr)
	binary.LittleEndian.PutUint32(out[8:12], hdr.outstandingReqCnt)
	return out[:]
}
  647. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  648. // Calculating total length.
  649. var totallen uint32 = 4
  650. for _, hdr := range headers {
  651. totallen += 4 + 2 + uint32(len(hdr.data))
  652. }
  653. // writing
  654. err = binary.Write(w, binary.LittleEndian, totallen)
  655. if err != nil {
  656. return err
  657. }
  658. for _, hdr := range headers {
  659. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  660. err = binary.Write(w, binary.LittleEndian, headerlen)
  661. if err != nil {
  662. return err
  663. }
  664. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  665. if err != nil {
  666. return err
  667. }
  668. _, err = w.Write(hdr.data)
  669. if err != nil {
  670. return err
  671. }
  672. }
  673. return nil
  674. }
  675. func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
  676. buf.BeginPacket(packSQLBatch, resetSession)
  677. if err = writeAllHeaders(buf, headers); err != nil {
  678. return
  679. }
  680. _, err = buf.Write(str2ucs2(sqltext))
  681. if err != nil {
  682. return
  683. }
  684. return buf.FinishPacket()
  685. }
  686. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  687. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  688. func sendAttention(buf *tdsBuffer) error {
  689. buf.BeginPacket(packAttention, false)
  690. return buf.FinishPacket()
  691. }
  692. type auth interface {
  693. InitialBytes() ([]byte, error)
  694. NextBytes([]byte) ([]byte, error)
  695. Free()
  696. }
  697. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  698. // list of IP addresses. So if there is more than one, try them all and
  699. // use the first one that allows a connection.
  700. func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
  701. var ips []net.IP
  702. ips, err = net.LookupIP(p.host)
  703. if err != nil {
  704. ip := net.ParseIP(p.host)
  705. if ip == nil {
  706. return nil, err
  707. }
  708. ips = []net.IP{ip}
  709. }
  710. if len(ips) == 1 {
  711. d := c.getDialer(&p)
  712. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port))))
  713. conn, err = d.DialContext(ctx, "tcp", addr)
  714. } else {
  715. //Try Dials in parallel to avoid waiting for timeouts.
  716. connChan := make(chan net.Conn, len(ips))
  717. errChan := make(chan error, len(ips))
  718. portStr := strconv.Itoa(int(resolveServerPort(p.port)))
  719. for _, ip := range ips {
  720. go func(ip net.IP) {
  721. d := c.getDialer(&p)
  722. addr := net.JoinHostPort(ip.String(), portStr)
  723. conn, err := d.DialContext(ctx, "tcp", addr)
  724. if err == nil {
  725. connChan <- conn
  726. } else {
  727. errChan <- err
  728. }
  729. }(ip)
  730. }
  731. // Wait for either the *first* successful connection, or all the errors
  732. wait_loop:
  733. for i, _ := range ips {
  734. select {
  735. case conn = <-connChan:
  736. // Got a connection to use, close any others
  737. go func(n int) {
  738. for i := 0; i < n; i++ {
  739. select {
  740. case conn := <-connChan:
  741. conn.Close()
  742. case <-errChan:
  743. }
  744. }
  745. }(len(ips) - i - 1)
  746. // Remove any earlier errors we may have collected
  747. err = nil
  748. break wait_loop
  749. case err = <-errChan:
  750. }
  751. }
  752. }
  753. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  754. if conn == nil {
  755. f := "Unable to open tcp connection with host '%v:%v': %v"
  756. return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
  757. }
  758. return conn, err
  759. }
  760. func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
  761. dialCtx := ctx
  762. if p.dial_timeout > 0 {
  763. var cancel func()
  764. dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
  765. defer cancel()
  766. }
  767. // if instance is specified use instance resolution service
  768. if p.instance != "" && p.port == 0 {
  769. p.instance = strings.ToUpper(p.instance)
  770. d := c.getDialer(&p)
  771. instances, err := getInstances(dialCtx, d, p.host)
  772. if err != nil {
  773. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  774. return nil, fmt.Errorf(f, p.host, err.Error())
  775. }
  776. strport, ok := instances[p.instance]["tcp"]
  777. if !ok {
  778. f := "No instance matching '%v' returned from host '%v'"
  779. return nil, fmt.Errorf(f, p.instance, p.host)
  780. }
  781. port, err := strconv.ParseUint(strport, 0, 16)
  782. if err != nil {
  783. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  784. return nil, fmt.Errorf(f, strport, err.Error())
  785. }
  786. p.port = port
  787. }
  788. initiate_connection:
  789. conn, err := dialConnection(dialCtx, c, p)
  790. if err != nil {
  791. return nil, err
  792. }
  793. toconn := newTimeoutConn(conn, p.conn_timeout)
  794. outbuf := newTdsBuffer(p.packetSize, toconn)
  795. sess := tdsSession{
  796. buf: outbuf,
  797. log: log,
  798. logFlags: p.logFlags,
  799. }
  800. instance_buf := []byte(p.instance)
  801. instance_buf = append(instance_buf, 0) // zero terminate instance name
  802. var encrypt byte
  803. if p.disableEncryption {
  804. encrypt = encryptNotSup
  805. } else if p.encrypt {
  806. encrypt = encryptOn
  807. } else {
  808. encrypt = encryptOff
  809. }
  810. fields := map[uint8][]byte{
  811. preloginVERSION: {0, 0, 0, 0, 0, 0},
  812. preloginENCRYPTION: {encrypt},
  813. preloginINSTOPT: instance_buf,
  814. preloginTHREADID: {0, 0, 0, 0},
  815. preloginMARS: {0}, // MARS disabled
  816. }
  817. err = writePrelogin(outbuf, fields)
  818. if err != nil {
  819. return nil, err
  820. }
  821. fields, err = readPrelogin(outbuf)
  822. if err != nil {
  823. return nil, err
  824. }
  825. encryptBytes, ok := fields[preloginENCRYPTION]
  826. if !ok {
  827. return nil, fmt.Errorf("Encrypt negotiation failed")
  828. }
  829. encrypt = encryptBytes[0]
  830. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  831. return nil, fmt.Errorf("Server does not support encryption")
  832. }
  833. if encrypt != encryptNotSup {
  834. var config tls.Config
  835. if p.certificate != "" {
  836. pem, err := ioutil.ReadFile(p.certificate)
  837. if err != nil {
  838. return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
  839. }
  840. certs := x509.NewCertPool()
  841. certs.AppendCertsFromPEM(pem)
  842. config.RootCAs = certs
  843. }
  844. if p.trustServerCertificate {
  845. config.InsecureSkipVerify = true
  846. }
  847. config.ServerName = p.hostInCertificate
  848. // fix for https://github.com/denisenkom/go-mssqldb/issues/166
  849. // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
  850. // while SQL Server seems to expect one TCP segment per encrypted TDS package.
  851. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
  852. config.DynamicRecordSizingDisabled = true
  853. // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
  854. handshakeConn := tlsHandshakeConn{buf: outbuf}
  855. passthrough := passthroughConn{c: &handshakeConn}
  856. tlsConn := tls.Client(&passthrough, &config)
  857. err = tlsConn.Handshake()
  858. passthrough.c = toconn
  859. outbuf.transport = tlsConn
  860. if err != nil {
  861. return nil, fmt.Errorf("TLS Handshake failed: %v", err)
  862. }
  863. if encrypt == encryptOff {
  864. outbuf.afterFirst = func() {
  865. outbuf.transport = toconn
  866. }
  867. }
  868. }
  869. login := login{
  870. TDSVersion: verTDS74,
  871. PacketSize: uint32(outbuf.PackageSize()),
  872. Database: p.database,
  873. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  874. HostName: p.workstation,
  875. ServerName: p.host,
  876. AppName: p.appname,
  877. TypeFlags: p.typeFlags,
  878. }
  879. auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  880. switch {
  881. case p.fedAuthAccessToken != "": // accesstoken ignores user/password
  882. featurext := &featureExtFedAuthSTS{
  883. FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
  884. FedAuthToken: p.fedAuthAccessToken,
  885. Nonce: fields[preloginNONCEOPT],
  886. }
  887. login.FeatureExt.Add(featurext)
  888. case authOk:
  889. login.SSPI, err = auth.InitialBytes()
  890. if err != nil {
  891. return nil, err
  892. }
  893. login.OptionFlags2 |= fIntSecurity
  894. defer auth.Free()
  895. default:
  896. login.UserName = p.user
  897. login.Password = p.password
  898. }
  899. err = sendLogin(outbuf, login)
  900. if err != nil {
  901. return nil, err
  902. }
  903. // processing login response
  904. success := false
  905. for {
  906. tokchan := make(chan tokenStruct, 5)
  907. go processResponse(context.Background(), &sess, tokchan, nil)
  908. for tok := range tokchan {
  909. switch token := tok.(type) {
  910. case sspiMsg:
  911. sspi_msg, err := auth.NextBytes(token)
  912. if err != nil {
  913. return nil, err
  914. }
  915. if sspi_msg != nil && len(sspi_msg) > 0 {
  916. outbuf.BeginPacket(packSSPIMessage, false)
  917. _, err = outbuf.Write(sspi_msg)
  918. if err != nil {
  919. return nil, err
  920. }
  921. err = outbuf.FinishPacket()
  922. if err != nil {
  923. return nil, err
  924. }
  925. sspi_msg = nil
  926. }
  927. case loginAckStruct:
  928. success = true
  929. sess.loginAck = token
  930. case error:
  931. return nil, fmt.Errorf("Login error: %s", token.Error())
  932. case doneStruct:
  933. if token.isError() {
  934. return nil, fmt.Errorf("Login error: %s", token.getError())
  935. }
  936. goto loginEnd
  937. }
  938. }
  939. }
  940. loginEnd:
  941. if !success {
  942. return nil, fmt.Errorf("Login failed")
  943. }
  944. if sess.routedServer != "" {
  945. toconn.Close()
  946. p.host = sess.routedServer
  947. p.port = uint64(sess.routedPort)
  948. if !p.hostInCertificateProvided {
  949. p.hostInCertificate = sess.routedServer
  950. }
  951. goto initiate_connection
  952. }
  953. return &sess, nil
  954. }