You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tds.go 34KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375
  1. package mssql
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "net"
  12. "net/url"
  13. "os"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. "time"
  18. "unicode"
  19. "unicode/utf16"
  20. "unicode/utf8"
  21. )
  22. func parseInstances(msg []byte) map[string]map[string]string {
  23. results := map[string]map[string]string{}
  24. if len(msg) > 3 && msg[0] == 5 {
  25. out_s := string(msg[3:])
  26. tokens := strings.Split(out_s, ";")
  27. instdict := map[string]string{}
  28. got_name := false
  29. var name string
  30. for _, token := range tokens {
  31. if got_name {
  32. instdict[name] = token
  33. got_name = false
  34. } else {
  35. name = token
  36. if len(name) == 0 {
  37. if len(instdict) == 0 {
  38. break
  39. }
  40. results[strings.ToUpper(instdict["InstanceName"])] = instdict
  41. instdict = map[string]string{}
  42. continue
  43. }
  44. got_name = true
  45. }
  46. }
  47. }
  48. return results
  49. }
  50. func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
  51. maxTime := 5 * time.Second
  52. ctx, cancel := context.WithTimeout(ctx, maxTime)
  53. defer cancel()
  54. conn, err := d.DialContext(ctx, "udp", address+":1434")
  55. if err != nil {
  56. return nil, err
  57. }
  58. defer conn.Close()
  59. conn.SetDeadline(time.Now().Add(maxTime))
  60. _, err = conn.Write([]byte{3})
  61. if err != nil {
  62. return nil, err
  63. }
  64. var resp = make([]byte, 16*1024-1)
  65. read, err := conn.Read(resp)
  66. if err != nil {
  67. return nil, err
  68. }
  69. return parseInstances(resp[:read]), nil
  70. }
  71. // TDS protocol version identifiers, as 32-bit values.
  // NOTE(review): presumably these are the values carried in
  // login.TDSVersion / loginHeader.TDSVersion — confirm against callers.
  72. const (
  73. verTDS70 = 0x70000000
  74. verTDS71 = 0x71000000
  75. verTDS71rev1 = 0x71000001
  76. verTDS72 = 0x72090002
  77. verTDS73A = 0x730A0003
  // verTDS73 is an alias with the same value as verTDS73A.
  78. verTDS73 = verTDS73A
  79. verTDS73B = 0x730B0003
  80. verTDS74 = 0x74000004
  81. )
  82. // Packet type codes passed to tdsBuffer.BeginPacket.
  83. // https://msdn.microsoft.com/en-us/library/dd304214.aspx
  //
  // Only packSQLBatch is explicitly typed as packetType; the remaining
  // names are untyped integer constants.
  84. const (
  85. packSQLBatch packetType = 1
  86. packRPCRequest = 3
  // packReply (4) is the packet type readPrelogin requires for the
  // PRELOGIN response.
  87. packReply = 4
  88. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  89. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  90. packAttention = 6
  91. packBulkLoadBCP = 7
  92. packTransMgrReq = 14
  93. packNormal = 15
  94. packLogin7 = 16
  95. packSSPIMessage = 17
  96. packPrelogin = 18
  97. )
  98. // PRELOGIN option tokens.
  99. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  //
  // NOTE(review): presumably these are the keys of the field maps handled by
  // writePrelogin and readPrelogin — confirm against callers.
  100. const (
  101. preloginVERSION = 0
  102. preloginENCRYPTION = 1
  103. preloginINSTOPT = 2
  104. preloginTHREADID = 3
  105. preloginMARS = 4
  106. preloginTRACEID = 5
  // preloginTERMINATOR ends the option header list; writePrelogin emits it
  // after the last header and readPrelogin stops parsing when it sees it.
  107. preloginTERMINATOR = 0xff
  108. )
  // Encryption negotiation values.
  // NOTE(review): presumably the payload of the preloginENCRYPTION option —
  // confirm against the code that builds the PRELOGIN fields.
  109. const (
  110. encryptOff = 0 // Encryption is available but off.
  111. encryptOn = 1 // Encryption is available and on.
  112. encryptNotSup = 2 // Encryption is not available.
  113. encryptReq = 3 // Encryption is required.
  114. )
  // tdsSession groups the per-connection state of a TDS session.
  // NOTE(review): the field notes below are inferred from names and types
  // only; the code that populates them is outside this chunk — confirm.
  115. type tdsSession struct {
  // buf is the packet buffer used for protocol I/O (see writePrelogin,
  // sendLogin and friends, which all operate on a *tdsBuffer).
  116. buf *tdsBuffer
  117. loginAck loginAckStruct
  118. database string
  119. partner string
  120. columns []columnStruct
  121. tranid uint64
  // logFlags is presumably a combination of the log* bit constants.
  122. logFlags uint64
  123. log optionalLogger
  124. routedServer string
  125. routedPort uint16
  126. }
  // Logging categories. Each value is a distinct power of two, so several
  // can be OR-ed together.
  // NOTE(review): presumably matched against the "log" connection parameter
  // stored in connectParams.logFlags / tdsSession.logFlags — confirm.
  127. const (
  128. logErrors = 1
  129. logMessages = 2
  130. logRows = 4
  131. logSQL = 8
  132. logParams = 16
  133. logTransaction = 32
  134. logDebug = 64
  135. )
  // columnStruct describes one column: a user type, a flag word, the column
  // name and its typeInfo.
  // NOTE(review): presumably filled from result-set metadata; the decoding
  // code is outside this chunk — confirm.
  136. type columnStruct struct {
  137. UserType uint32
  138. Flags uint16
  139. ColName string
  140. ti typeInfo
  141. }
  142. type keySlice []uint8
  143. func (p keySlice) Len() int { return len(p) }
  144. func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
  145. func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  146. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  147. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  148. var err error
  149. w.BeginPacket(packPrelogin, false)
  150. offset := uint16(5*len(fields) + 1)
  151. keys := make(keySlice, 0, len(fields))
  152. for k, _ := range fields {
  153. keys = append(keys, k)
  154. }
  155. sort.Sort(keys)
  156. // writing header
  157. for _, k := range keys {
  158. err = w.WriteByte(k)
  159. if err != nil {
  160. return err
  161. }
  162. err = binary.Write(w, binary.BigEndian, offset)
  163. if err != nil {
  164. return err
  165. }
  166. v := fields[k]
  167. size := uint16(len(v))
  168. err = binary.Write(w, binary.BigEndian, size)
  169. if err != nil {
  170. return err
  171. }
  172. offset += size
  173. }
  174. err = w.WriteByte(preloginTERMINATOR)
  175. if err != nil {
  176. return err
  177. }
  178. // writing values
  179. for _, k := range keys {
  180. v := fields[k]
  181. written, err := w.Write(v)
  182. if err != nil {
  183. return err
  184. }
  185. if written != len(v) {
  186. return errors.New("Write method didn't write the whole value")
  187. }
  188. }
  189. return w.FinishPacket()
  190. }
  191. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  192. packet_type, err := r.BeginRead()
  193. if err != nil {
  194. return nil, err
  195. }
  196. struct_buf, err := ioutil.ReadAll(r)
  197. if err != nil {
  198. return nil, err
  199. }
  200. if packet_type != 4 {
  201. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  202. }
  203. offset := 0
  204. results := map[uint8][]byte{}
  205. for true {
  206. rec_type := struct_buf[offset]
  207. if rec_type == preloginTERMINATOR {
  208. break
  209. }
  210. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  211. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  212. value := struct_buf[rec_offset : rec_offset+rec_len]
  213. results[rec_type] = value
  214. offset += 5
  215. }
  216. return results, nil
  217. }
  218. // OptionFlags2 bit values.
  219. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  //
  // NOTE(review): presumably OR-ed into login.OptionFlags2 before sendLogin
  // is called — confirm against callers.
  220. const (
  221. fLanguageFatal = 1
  222. fODBC = 2
  223. fTransBoundary = 4
  224. fCacheConnect = 8
  225. fIntSecurity = 0x80
  226. )
  227. // TypeFlags bit values.
  // NOTE(review): presumably OR-ed into login.TypeFlags /
  // connectParams.typeFlags — confirm against callers.
  228. const (
  229. // 4 bits for fSQLType
  230. // 1 bit for fOLEDB
  // fReadOnlyIntent is bit 5 (value 32), the bit after the five bits
  // described above.
  231. fReadOnlyIntent = 32
  232. )
  // login holds the values sendLogin serializes into a login packet
  // (packLogin7). The fixed-width fields are copied into loginHeader
  // unchanged; the string fields are converted with str2ucs2 and appended
  // after the header.
  233. type login struct {
  234. TDSVersion uint32
  235. PacketSize uint32
  236. ClientProgVer uint32
  237. ClientPID uint32
  238. ConnectionID uint32
  239. OptionFlags1 uint8
  240. OptionFlags2 uint8
  241. TypeFlags uint8
  242. OptionFlags3 uint8
  243. ClientTimeZone int32
  244. ClientLCID uint32
  245. HostName string
  246. UserName string
  // Password is sent obfuscated: sendLogin passes it through manglePassword.
  247. Password string
  248. AppName string
  249. ServerName string
  250. CtlIntName string
  251. Language string
  252. Database string
  253. ClientID [6]byte
  // SSPI is written to the packet as raw bytes, without UCS-2 conversion.
  254. SSPI []byte
  255. AtchDBFile string
  256. ChangePassword string
  257. }
  // loginHeader is the fixed-size header of the login packet. sendLogin
  // writes it with binary.Write in little-endian order, followed by the
  // variable-length data that the Offset/Length pairs describe. Offsets are
  // counted from the start of the header.
  258. type loginHeader struct {
  // Length is the header size plus the size of all variable-length data.
  259. Length uint32
  260. TDSVersion uint32
  261. PacketSize uint32
  262. ClientProgVer uint32
  263. ClientPID uint32
  264. ConnectionID uint32
  265. OptionFlags1 uint8
  266. OptionFlags2 uint8
  267. TypeFlags uint8
  268. OptionFlags3 uint8
  269. ClientTimeZone int32
  270. ClientLCID uint32
  271. HostNameOffset uint16
  272. HostNameLength uint16
  273. UserNameOffset uint16
  274. UserNameLength uint16
  275. PasswordOffset uint16
  276. PasswordLength uint16
  277. AppNameOffset uint16
  278. AppNameLength uint16
  279. ServerNameOffset uint16
  280. ServerNameLength uint16
  // ExtensionOffset and ExtensionLenght (sic) are never set by sendLogin
  // and therefore go out as zero.
  281. ExtensionOffset uint16
  282. ExtensionLenght uint16
  283. CtlIntNameOffset uint16
  284. CtlIntNameLength uint16
  285. LanguageOffset uint16
  286. LanguageLength uint16
  287. DatabaseOffset uint16
  288. DatabaseLength uint16
  289. ClientID [6]byte
  // SSPILength is a byte count, unlike the string lengths above.
  290. SSPIOffset uint16
  291. SSPILength uint16
  292. AtchDBFileOffset uint16
  293. AtchDBFileLength uint16
  294. ChangePasswordOffset uint16
  295. ChangePasswordLength uint16
  // SSPILongLength is never set by sendLogin and goes out as zero.
  296. SSPILongLength uint32
  297. }
  298. // convert Go string to UTF-16 encoded []byte (littleEndian)
  299. // done manually rather than using bytes and binary packages
  300. // for performance reasons
  301. func str2ucs2(s string) []byte {
  302. res := utf16.Encode([]rune(s))
  303. ucs2 := make([]byte, 2*len(res))
  304. for i := 0; i < len(res); i++ {
  305. ucs2[2*i] = byte(res[i])
  306. ucs2[2*i+1] = byte(res[i] >> 8)
  307. }
  308. return ucs2
  309. }
  310. func ucs22str(s []byte) (string, error) {
  311. if len(s)%2 != 0 {
  312. return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  313. }
  314. buf := make([]uint16, len(s)/2)
  315. for i := 0; i < len(s); i += 2 {
  316. buf[i/2] = binary.LittleEndian.Uint16(s[i:])
  317. }
  318. return string(utf16.Decode(buf)), nil
  319. }
  320. func manglePassword(password string) []byte {
  321. var ucs2password []byte = str2ucs2(password)
  322. for i, ch := range ucs2password {
  323. ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
  324. }
  325. return ucs2password
  326. }
  327. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  328. func sendLogin(w *tdsBuffer, login login) error {
  329. w.BeginPacket(packLogin7, false)
  330. hostname := str2ucs2(login.HostName)
  331. username := str2ucs2(login.UserName)
  332. password := manglePassword(login.Password)
  333. appname := str2ucs2(login.AppName)
  334. servername := str2ucs2(login.ServerName)
  335. ctlintname := str2ucs2(login.CtlIntName)
  336. language := str2ucs2(login.Language)
  337. database := str2ucs2(login.Database)
  338. atchdbfile := str2ucs2(login.AtchDBFile)
  339. changepassword := str2ucs2(login.ChangePassword)
  340. hdr := loginHeader{
  341. TDSVersion: login.TDSVersion,
  342. PacketSize: login.PacketSize,
  343. ClientProgVer: login.ClientProgVer,
  344. ClientPID: login.ClientPID,
  345. ConnectionID: login.ConnectionID,
  346. OptionFlags1: login.OptionFlags1,
  347. OptionFlags2: login.OptionFlags2,
  348. TypeFlags: login.TypeFlags,
  349. OptionFlags3: login.OptionFlags3,
  350. ClientTimeZone: login.ClientTimeZone,
  351. ClientLCID: login.ClientLCID,
  352. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  353. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  354. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  355. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  356. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  357. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  358. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  359. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  360. ClientID: login.ClientID,
  361. SSPILength: uint16(len(login.SSPI)),
  362. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  363. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  364. }
  365. offset := uint16(binary.Size(hdr))
  366. hdr.HostNameOffset = offset
  367. offset += uint16(len(hostname))
  368. hdr.UserNameOffset = offset
  369. offset += uint16(len(username))
  370. hdr.PasswordOffset = offset
  371. offset += uint16(len(password))
  372. hdr.AppNameOffset = offset
  373. offset += uint16(len(appname))
  374. hdr.ServerNameOffset = offset
  375. offset += uint16(len(servername))
  376. hdr.CtlIntNameOffset = offset
  377. offset += uint16(len(ctlintname))
  378. hdr.LanguageOffset = offset
  379. offset += uint16(len(language))
  380. hdr.DatabaseOffset = offset
  381. offset += uint16(len(database))
  382. hdr.SSPIOffset = offset
  383. offset += uint16(len(login.SSPI))
  384. hdr.AtchDBFileOffset = offset
  385. offset += uint16(len(atchdbfile))
  386. hdr.ChangePasswordOffset = offset
  387. offset += uint16(len(changepassword))
  388. hdr.Length = uint32(offset)
  389. var err error
  390. err = binary.Write(w, binary.LittleEndian, &hdr)
  391. if err != nil {
  392. return err
  393. }
  394. _, err = w.Write(hostname)
  395. if err != nil {
  396. return err
  397. }
  398. _, err = w.Write(username)
  399. if err != nil {
  400. return err
  401. }
  402. _, err = w.Write(password)
  403. if err != nil {
  404. return err
  405. }
  406. _, err = w.Write(appname)
  407. if err != nil {
  408. return err
  409. }
  410. _, err = w.Write(servername)
  411. if err != nil {
  412. return err
  413. }
  414. _, err = w.Write(ctlintname)
  415. if err != nil {
  416. return err
  417. }
  418. _, err = w.Write(language)
  419. if err != nil {
  420. return err
  421. }
  422. _, err = w.Write(database)
  423. if err != nil {
  424. return err
  425. }
  426. _, err = w.Write(login.SSPI)
  427. if err != nil {
  428. return err
  429. }
  430. _, err = w.Write(atchdbfile)
  431. if err != nil {
  432. return err
  433. }
  434. _, err = w.Write(changepassword)
  435. if err != nil {
  436. return err
  437. }
  438. return w.FinishPacket()
  439. }
  440. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  441. buf := make([]byte, numchars*2)
  442. _, err = io.ReadFull(r, buf)
  443. if err != nil {
  444. return "", err
  445. }
  446. return ucs22str(buf)
  447. }
  448. func readUsVarChar(r io.Reader) (res string, err error) {
  449. var numchars uint16
  450. err = binary.Read(r, binary.LittleEndian, &numchars)
  451. if err != nil {
  452. return "", err
  453. }
  454. return readUcs2(r, int(numchars))
  455. }
  456. func writeUsVarChar(w io.Writer, s string) (err error) {
  457. buf := str2ucs2(s)
  458. var numchars int = len(buf) / 2
  459. if numchars > 0xffff {
  460. panic("invalid size for US_VARCHAR")
  461. }
  462. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  463. if err != nil {
  464. return
  465. }
  466. _, err = w.Write(buf)
  467. return
  468. }
  469. func readBVarChar(r io.Reader) (res string, err error) {
  470. var numchars uint8
  471. err = binary.Read(r, binary.LittleEndian, &numchars)
  472. if err != nil {
  473. return "", err
  474. }
  475. // A zero length could be returned, return an empty string
  476. if numchars == 0 {
  477. return "", nil
  478. }
  479. return readUcs2(r, int(numchars))
  480. }
  481. func writeBVarChar(w io.Writer, s string) (err error) {
  482. buf := str2ucs2(s)
  483. var numchars int = len(buf) / 2
  484. if numchars > 0xff {
  485. panic("invalid size for B_VARCHAR")
  486. }
  487. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  488. if err != nil {
  489. return
  490. }
  491. _, err = w.Write(buf)
  492. return
  493. }
  494. func readBVarByte(r io.Reader) (res []byte, err error) {
  495. var length uint8
  496. err = binary.Read(r, binary.LittleEndian, &length)
  497. if err != nil {
  498. return
  499. }
  500. res = make([]byte, length)
  501. _, err = io.ReadFull(r, res)
  502. return
  503. }
  504. func readUshort(r io.Reader) (res uint16, err error) {
  505. err = binary.Read(r, binary.LittleEndian, &res)
  506. return
  507. }
  508. func readByte(r io.Reader) (res byte, err error) {
  509. var b [1]byte
  510. _, err = r.Read(b[:])
  511. res = b[0]
  512. return
  513. }
  514. // Packet Data Stream Headers
  515. // http://msdn.microsoft.com/en-us/library/dd304953.aspx
  //
  // headerStruct is one entry of the header list written by writeAllHeaders:
  // a 16-bit type code followed by an already encoded payload.
  516. type headerStruct struct {
  // hdrtype is written as a little-endian uint16; see the dataStmHdr*
  // constants.
  517. hdrtype uint16
  // data is written verbatim after the type code.
  518. data []byte
  519. }
  // Data stream header type codes.
  // NOTE(review): presumably the values stored in headerStruct.hdrtype —
  // confirm against callers.
  520. const (
  521. dataStmHdrQueryNotif = 1 // query notifications
  522. dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
  523. dataStmHdrTraceActivity = 3
  524. )
  525. // Query Notifications Header
  526. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  527. type queryNotifHdr struct {
  528. notifyId string
  529. ssbDeployment string
  530. notifyTimeout uint32
  531. }
  532. func (hdr queryNotifHdr) pack() (res []byte) {
  533. notifyId := str2ucs2(hdr.notifyId)
  534. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  535. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  536. b := res
  537. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  538. b = b[2:]
  539. copy(b, notifyId)
  540. b = b[len(notifyId):]
  541. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  542. b = b[2:]
  543. copy(b, ssbDeployment)
  544. b = b[len(ssbDeployment):]
  545. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  546. return res
  547. }
  548. // MARS Transaction Descriptor Header
  549. // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  550. type transDescrHdr struct {
  551. transDescr uint64 // transaction descriptor returned from ENVCHANGE
  552. outstandingReqCnt uint32 // outstanding request count
  553. }
  554. func (hdr transDescrHdr) pack() (res []byte) {
  555. res = make([]byte, 8+4)
  556. binary.LittleEndian.PutUint64(res, hdr.transDescr)
  557. binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
  558. return res
  559. }
  560. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  561. // Calculating total length.
  562. var totallen uint32 = 4
  563. for _, hdr := range headers {
  564. totallen += 4 + 2 + uint32(len(hdr.data))
  565. }
  566. // writing
  567. err = binary.Write(w, binary.LittleEndian, totallen)
  568. if err != nil {
  569. return err
  570. }
  571. for _, hdr := range headers {
  572. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  573. err = binary.Write(w, binary.LittleEndian, headerlen)
  574. if err != nil {
  575. return err
  576. }
  577. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  578. if err != nil {
  579. return err
  580. }
  581. _, err = w.Write(hdr.data)
  582. if err != nil {
  583. return err
  584. }
  585. }
  586. return nil
  587. }
  588. func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
  589. buf.BeginPacket(packSQLBatch, resetSession)
  590. if err = writeAllHeaders(buf, headers); err != nil {
  591. return
  592. }
  593. _, err = buf.Write(str2ucs2(sqltext))
  594. if err != nil {
  595. return
  596. }
  597. return buf.FinishPacket()
  598. }
  599. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  600. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  601. func sendAttention(buf *tdsBuffer) error {
  602. buf.BeginPacket(packAttention, false)
  603. return buf.FinishPacket()
  604. }
  // connectParams is the parsed form of a connection string, as produced by
  // parseConnectParams. Only the fields whose handling is visible in
  // parseConnectParams are annotated.
  605. type connectParams struct {
  // logFlags comes from the "log" parameter, parsed as a base-10 unsigned
  // integer.
  606. logFlags uint64
  // port defaults to 1433 and can be overridden with the "port" parameter.
  607. port uint64
  // host is the part of "server" before the first backslash; ".", "(local)"
  // and the empty string are replaced by "localhost".
  608. host string
  // instance is the part of "server" after the first backslash, if any.
  609. instance string
  // database, user and password come from the "database", "user id" and
  // "password" parameters.
  610. database string
  611. user string
  612. password string
  // dial_timeout defaults to 15 seconds ("dial timeout", in seconds).
  613. dial_timeout time.Duration
  // conn_timeout defaults to zero ("connection timeout", in seconds).
  614. conn_timeout time.Duration
  // keepAlive defaults to 30 seconds ("keepalive", in seconds).
  615. keepAlive time.Duration
  // encrypt holds the boolean value of the "encrypt" parameter;
  // disableEncryption is set instead when that parameter equals "disable"
  // (case-insensitive).
  616. encrypt bool
  617. disableEncryption bool
  618. trustServerCertificate bool
  619. certificate string
  620. hostInCertificate string
  621. hostInCertificateProvided bool
  622. serverSPN string
  623. workstation string
  624. appname string
  625. typeFlags uint8
  626. failOverPartner string
  627. failOverPort uint64
  // packetSize defaults to 4096; a "packet size" parameter is clamped to
  // the range 512..32767.
  628. packetSize uint16
  629. }
  630. func splitConnectionString(dsn string) (res map[string]string) {
  631. res = map[string]string{}
  632. parts := strings.Split(dsn, ";")
  633. for _, part := range parts {
  634. if len(part) == 0 {
  635. continue
  636. }
  637. lst := strings.SplitN(part, "=", 2)
  638. name := strings.TrimSpace(strings.ToLower(lst[0]))
  639. if len(name) == 0 {
  640. continue
  641. }
  642. var value string = ""
  643. if len(lst) > 1 {
  644. value = strings.TrimSpace(lst[1])
  645. }
  646. res[name] = value
  647. }
  648. return res
  649. }
  650. // Splits a URL in the ODBC format
  651. func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
  652. res := map[string]string{}
  653. type parserState int
  654. const (
  655. // Before the start of a key
  656. parserStateBeforeKey parserState = iota
  657. // Inside a key
  658. parserStateKey
  659. // Beginning of a value. May be bare or braced
  660. parserStateBeginValue
  661. // Inside a bare value
  662. parserStateBareValue
  663. // Inside a braced value
  664. parserStateBracedValue
  665. // A closing brace inside a braced value.
  666. // May be the end of the value or an escaped closing brace, depending on the next character
  667. parserStateBracedValueClosingBrace
  668. // After a value. Next character should be a semicolon or whitespace.
  669. parserStateEndValue
  670. )
  671. var state = parserStateBeforeKey
  672. var key string
  673. var value string
  674. for i, c := range dsn {
  675. switch state {
  676. case parserStateBeforeKey:
  677. switch {
  678. case c == '=':
  679. return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
  680. case !unicode.IsSpace(c) && c != ';':
  681. state = parserStateKey
  682. key += string(c)
  683. }
  684. case parserStateKey:
  685. switch c {
  686. case '=':
  687. key = normalizeOdbcKey(key)
  688. if len(key) == 0 {
  689. return res, fmt.Errorf("Unexpected end of key at index %d.", i)
  690. }
  691. state = parserStateBeginValue
  692. case ';':
  693. // Key without value
  694. key = normalizeOdbcKey(key)
  695. if len(key) == 0 {
  696. return res, fmt.Errorf("Unexpected end of key at index %d.", i)
  697. }
  698. res[key] = value
  699. key = ""
  700. value = ""
  701. state = parserStateBeforeKey
  702. default:
  703. key += string(c)
  704. }
  705. case parserStateBeginValue:
  706. switch {
  707. case c == '{':
  708. state = parserStateBracedValue
  709. case c == ';':
  710. // Empty value
  711. res[key] = value
  712. key = ""
  713. state = parserStateBeforeKey
  714. case unicode.IsSpace(c):
  715. // Ignore whitespace
  716. default:
  717. state = parserStateBareValue
  718. value += string(c)
  719. }
  720. case parserStateBareValue:
  721. if c == ';' {
  722. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  723. key = ""
  724. value = ""
  725. state = parserStateBeforeKey
  726. } else {
  727. value += string(c)
  728. }
  729. case parserStateBracedValue:
  730. if c == '}' {
  731. state = parserStateBracedValueClosingBrace
  732. } else {
  733. value += string(c)
  734. }
  735. case parserStateBracedValueClosingBrace:
  736. if c == '}' {
  737. // Escaped closing brace
  738. value += string(c)
  739. state = parserStateBracedValue
  740. continue
  741. }
  742. // End of braced value
  743. res[key] = value
  744. key = ""
  745. value = ""
  746. // This character is the first character past the end,
  747. // so it needs to be parsed like the parserStateEndValue state.
  748. state = parserStateEndValue
  749. switch {
  750. case c == ';':
  751. state = parserStateBeforeKey
  752. case unicode.IsSpace(c):
  753. // Ignore whitespace
  754. default:
  755. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  756. }
  757. case parserStateEndValue:
  758. switch {
  759. case c == ';':
  760. state = parserStateBeforeKey
  761. case unicode.IsSpace(c):
  762. // Ignore whitespace
  763. default:
  764. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  765. }
  766. }
  767. }
  768. switch state {
  769. case parserStateBeforeKey: // Okay
  770. case parserStateKey: // Unfinished key. Treat as key without value.
  771. key = normalizeOdbcKey(key)
  772. if len(key) == 0 {
  773. return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
  774. }
  775. res[key] = value
  776. case parserStateBeginValue: // Empty value
  777. res[key] = value
  778. case parserStateBareValue:
  779. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  780. case parserStateBracedValue:
  781. return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
  782. case parserStateBracedValueClosingBrace: // End of braced value
  783. res[key] = value
  784. case parserStateEndValue: // Okay
  785. }
  786. return res, nil
  787. }
  788. // Normalizes the given string as an ODBC-format key
  789. func normalizeOdbcKey(s string) string {
  790. return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
  791. }
  792. // Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
  793. func splitConnectionStringURL(dsn string) (map[string]string, error) {
  794. res := map[string]string{}
  795. u, err := url.Parse(dsn)
  796. if err != nil {
  797. return res, err
  798. }
  799. if u.Scheme != "sqlserver" {
  800. return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
  801. }
  802. if u.User != nil {
  803. res["user id"] = u.User.Username()
  804. p, exists := u.User.Password()
  805. if exists {
  806. res["password"] = p
  807. }
  808. }
  809. host, port, err := net.SplitHostPort(u.Host)
  810. if err != nil {
  811. host = u.Host
  812. }
  813. if len(u.Path) > 0 {
  814. res["server"] = host + "\\" + u.Path[1:]
  815. } else {
  816. res["server"] = host
  817. }
  818. if len(port) > 0 {
  819. res["port"] = port
  820. }
  821. query := u.Query()
  822. for k, v := range query {
  823. if len(v) > 1 {
  824. return res, fmt.Errorf("key %s provided more than once", k)
  825. }
  826. res[strings.ToLower(k)] = v[0]
  827. }
  828. return res, nil
  829. }
  830. func parseConnectParams(dsn string) (connectParams, error) {
  831. var p connectParams
  832. var params map[string]string
  833. if strings.HasPrefix(dsn, "odbc:") {
  834. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  835. if err != nil {
  836. return p, err
  837. }
  838. params = parameters
  839. } else if strings.HasPrefix(dsn, "sqlserver://") {
  840. parameters, err := splitConnectionStringURL(dsn)
  841. if err != nil {
  842. return p, err
  843. }
  844. params = parameters
  845. } else {
  846. params = splitConnectionString(dsn)
  847. }
  848. strlog, ok := params["log"]
  849. if ok {
  850. var err error
  851. p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
  852. if err != nil {
  853. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  854. }
  855. }
  856. server := params["server"]
  857. parts := strings.SplitN(server, `\`, 2)
  858. p.host = parts[0]
  859. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  860. p.host = "localhost"
  861. }
  862. if len(parts) > 1 {
  863. p.instance = parts[1]
  864. }
  865. p.database = params["database"]
  866. p.user = params["user id"]
  867. p.password = params["password"]
  868. p.port = 1433
  869. strport, ok := params["port"]
  870. if ok {
  871. var err error
  872. p.port, err = strconv.ParseUint(strport, 10, 16)
  873. if err != nil {
  874. f := "Invalid tcp port '%v': %v"
  875. return p, fmt.Errorf(f, strport, err.Error())
  876. }
  877. }
  878. // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
  879. // Default packet size remains at 4096 bytes
  880. p.packetSize = 4096
  881. strpsize, ok := params["packet size"]
  882. if ok {
  883. var err error
  884. psize, err := strconv.ParseUint(strpsize, 0, 16)
  885. if err != nil {
  886. f := "Invalid packet size '%v': %v"
  887. return p, fmt.Errorf(f, strpsize, err.Error())
  888. }
  889. // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
  890. // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
  891. // a higher packet size, the server will respond with an ENVCHANGE request to
  892. // alter the packet size to 16383 bytes.
  893. p.packetSize = uint16(psize)
  894. if p.packetSize < 512 {
  895. p.packetSize = 512
  896. } else if p.packetSize > 32767 {
  897. p.packetSize = 32767
  898. }
  899. }
  900. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  901. //
  902. // Do not set a connection timeout. Use Context to manage such things.
  903. // Default to zero, but still allow it to be set.
  904. if strconntimeout, ok := params["connection timeout"]; ok {
  905. timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
  906. if err != nil {
  907. f := "Invalid connection timeout '%v': %v"
  908. return p, fmt.Errorf(f, strconntimeout, err.Error())
  909. }
  910. p.conn_timeout = time.Duration(timeout) * time.Second
  911. }
  912. p.dial_timeout = 15 * time.Second
  913. if strdialtimeout, ok := params["dial timeout"]; ok {
  914. timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
  915. if err != nil {
  916. f := "Invalid dial timeout '%v': %v"
  917. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  918. }
  919. p.dial_timeout = time.Duration(timeout) * time.Second
  920. }
  921. // default keep alive should be 30 seconds according to spec:
  922. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  923. p.keepAlive = 30 * time.Second
  924. if keepAlive, ok := params["keepalive"]; ok {
  925. timeout, err := strconv.ParseUint(keepAlive, 10, 64)
  926. if err != nil {
  927. f := "Invalid keepAlive value '%s': %s"
  928. return p, fmt.Errorf(f, keepAlive, err.Error())
  929. }
  930. p.keepAlive = time.Duration(timeout) * time.Second
  931. }
  932. encrypt, ok := params["encrypt"]
  933. if ok {
  934. if strings.EqualFold(encrypt, "DISABLE") {
  935. p.disableEncryption = true
  936. } else {
  937. var err error
  938. p.encrypt, err = strconv.ParseBool(encrypt)
  939. if err != nil {
  940. f := "Invalid encrypt '%s': %s"
  941. return p, fmt.Errorf(f, encrypt, err.Error())
  942. }
  943. }
  944. } else {
  945. p.trustServerCertificate = true
  946. }
  947. trust, ok := params["trustservercertificate"]
  948. if ok {
  949. var err error
  950. p.trustServerCertificate, err = strconv.ParseBool(trust)
  951. if err != nil {
  952. f := "Invalid trust server certificate '%s': %s"
  953. return p, fmt.Errorf(f, trust, err.Error())
  954. }
  955. }
  956. p.certificate = params["certificate"]
  957. p.hostInCertificate, ok = params["hostnameincertificate"]
  958. if ok {
  959. p.hostInCertificateProvided = true
  960. } else {
  961. p.hostInCertificate = p.host
  962. p.hostInCertificateProvided = false
  963. }
  964. serverSPN, ok := params["serverspn"]
  965. if ok {
  966. p.serverSPN = serverSPN
  967. } else {
  968. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  969. }
  970. workstation, ok := params["workstation id"]
  971. if ok {
  972. p.workstation = workstation
  973. } else {
  974. workstation, err := os.Hostname()
  975. if err == nil {
  976. p.workstation = workstation
  977. }
  978. }
  979. appname, ok := params["app name"]
  980. if !ok {
  981. appname = "go-mssqldb"
  982. }
  983. p.appname = appname
  984. appintent, ok := params["applicationintent"]
  985. if ok {
  986. if appintent == "ReadOnly" {
  987. p.typeFlags |= fReadOnlyIntent
  988. }
  989. }
  990. failOverPartner, ok := params["failoverpartner"]
  991. if ok {
  992. p.failOverPartner = failOverPartner
  993. }
  994. failOverPort, ok := params["failoverport"]
  995. if ok {
  996. var err error
  997. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  998. if err != nil {
  999. f := "Invalid tcp port '%v': %v"
  1000. return p, fmt.Errorf(f, failOverPort, err.Error())
  1001. }
  1002. }
  1003. return p, nil
  1004. }
  // auth describes the token exchange used for integrated (SSPI)
  // authentication during login. The methods mirror the calls connect makes
  // on the value returned by getAuth.
  type auth interface {
  	// InitialBytes produces the first token; connect stores it in the
  	// SSPI field of the login packet.
  	InitialBytes() (token []byte, err error)
  	// NextBytes is given each SSPI message received from the server and
  	// produces the reply; connect only sends the reply when it is non-empty.
  	NextBytes(serverMsg []byte) (reply []byte, err error)
  	// Free is deferred by connect once InitialBytes has succeeded.
  	// Presumably it releases provider resources — confirm against the
  	// implementations.
  	Free()
  }
  1010. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  1011. // list of IP addresses. So if there is more than one, try them all and
  1012. // use the first one that allows a connection.
  1013. func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
  1014. var ips []net.IP
  1015. ips, err = net.LookupIP(p.host)
  1016. if err != nil {
  1017. ip := net.ParseIP(p.host)
  1018. if ip == nil {
  1019. return nil, err
  1020. }
  1021. ips = []net.IP{ip}
  1022. }
  1023. if len(ips) == 1 {
  1024. d := c.getDialer(&p)
  1025. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  1026. conn, err = d.DialContext(ctx, "tcp", addr)
  1027. } else {
  1028. //Try Dials in parallel to avoid waiting for timeouts.
  1029. connChan := make(chan net.Conn, len(ips))
  1030. errChan := make(chan error, len(ips))
  1031. portStr := strconv.Itoa(int(p.port))
  1032. for _, ip := range ips {
  1033. go func(ip net.IP) {
  1034. d := c.getDialer(&p)
  1035. addr := net.JoinHostPort(ip.String(), portStr)
  1036. conn, err := d.DialContext(ctx, "tcp", addr)
  1037. if err == nil {
  1038. connChan <- conn
  1039. } else {
  1040. errChan <- err
  1041. }
  1042. }(ip)
  1043. }
  1044. // Wait for either the *first* successful connection, or all the errors
  1045. wait_loop:
  1046. for i, _ := range ips {
  1047. select {
  1048. case conn = <-connChan:
  1049. // Got a connection to use, close any others
  1050. go func(n int) {
  1051. for i := 0; i < n; i++ {
  1052. select {
  1053. case conn := <-connChan:
  1054. conn.Close()
  1055. case <-errChan:
  1056. }
  1057. }
  1058. }(len(ips) - i - 1)
  1059. // Remove any earlier errors we may have collected
  1060. err = nil
  1061. break wait_loop
  1062. case err = <-errChan:
  1063. }
  1064. }
  1065. }
  1066. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  1067. if conn == nil {
  1068. f := "Unable to open tcp connection with host '%v:%v': %v"
  1069. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  1070. }
  1071. return conn, err
  1072. }
  1073. func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
  1074. dialCtx := ctx
  1075. if p.dial_timeout > 0 {
  1076. var cancel func()
  1077. dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
  1078. defer cancel()
  1079. }
  1080. // if instance is specified use instance resolution service
  1081. if p.instance != "" {
  1082. p.instance = strings.ToUpper(p.instance)
  1083. d := c.getDialer(&p)
  1084. instances, err := getInstances(dialCtx, d, p.host)
  1085. if err != nil {
  1086. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  1087. return nil, fmt.Errorf(f, p.host, err.Error())
  1088. }
  1089. strport, ok := instances[p.instance]["tcp"]
  1090. if !ok {
  1091. f := "No instance matching '%v' returned from host '%v'"
  1092. return nil, fmt.Errorf(f, p.instance, p.host)
  1093. }
  1094. p.port, err = strconv.ParseUint(strport, 0, 16)
  1095. if err != nil {
  1096. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  1097. return nil, fmt.Errorf(f, strport, err.Error())
  1098. }
  1099. }
  1100. initiate_connection:
  1101. conn, err := dialConnection(dialCtx, c, p)
  1102. if err != nil {
  1103. return nil, err
  1104. }
  1105. toconn := newTimeoutConn(conn, p.conn_timeout)
  1106. outbuf := newTdsBuffer(p.packetSize, toconn)
  1107. sess := tdsSession{
  1108. buf: outbuf,
  1109. log: log,
  1110. logFlags: p.logFlags,
  1111. }
  1112. instance_buf := []byte(p.instance)
  1113. instance_buf = append(instance_buf, 0) // zero terminate instance name
  1114. var encrypt byte
  1115. if p.disableEncryption {
  1116. encrypt = encryptNotSup
  1117. } else if p.encrypt {
  1118. encrypt = encryptOn
  1119. } else {
  1120. encrypt = encryptOff
  1121. }
  1122. fields := map[uint8][]byte{
  1123. preloginVERSION: {0, 0, 0, 0, 0, 0},
  1124. preloginENCRYPTION: {encrypt},
  1125. preloginINSTOPT: instance_buf,
  1126. preloginTHREADID: {0, 0, 0, 0},
  1127. preloginMARS: {0}, // MARS disabled
  1128. }
  1129. err = writePrelogin(outbuf, fields)
  1130. if err != nil {
  1131. return nil, err
  1132. }
  1133. fields, err = readPrelogin(outbuf)
  1134. if err != nil {
  1135. return nil, err
  1136. }
  1137. encryptBytes, ok := fields[preloginENCRYPTION]
  1138. if !ok {
  1139. return nil, fmt.Errorf("Encrypt negotiation failed")
  1140. }
  1141. encrypt = encryptBytes[0]
  1142. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  1143. return nil, fmt.Errorf("Server does not support encryption")
  1144. }
  1145. if encrypt != encryptNotSup {
  1146. var config tls.Config
  1147. if p.certificate != "" {
  1148. pem, err := ioutil.ReadFile(p.certificate)
  1149. if err != nil {
  1150. return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
  1151. }
  1152. certs := x509.NewCertPool()
  1153. certs.AppendCertsFromPEM(pem)
  1154. config.RootCAs = certs
  1155. }
  1156. if p.trustServerCertificate {
  1157. config.InsecureSkipVerify = true
  1158. }
  1159. config.ServerName = p.hostInCertificate
  1160. // fix for https://github.com/denisenkom/go-mssqldb/issues/166
  1161. // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
  1162. // while SQL Server seems to expect one TCP segment per encrypted TDS package.
  1163. // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
  1164. config.DynamicRecordSizingDisabled = true
  1165. outbuf.transport = conn
  1166. toconn.buf = outbuf
  1167. tlsConn := tls.Client(toconn, &config)
  1168. err = tlsConn.Handshake()
  1169. toconn.buf = nil
  1170. outbuf.transport = tlsConn
  1171. if err != nil {
  1172. return nil, fmt.Errorf("TLS Handshake failed: %v", err)
  1173. }
  1174. if encrypt == encryptOff {
  1175. outbuf.afterFirst = func() {
  1176. outbuf.transport = toconn
  1177. }
  1178. }
  1179. }
  1180. login := login{
  1181. TDSVersion: verTDS74,
  1182. PacketSize: uint32(outbuf.PackageSize()),
  1183. Database: p.database,
  1184. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  1185. HostName: p.workstation,
  1186. ServerName: p.host,
  1187. AppName: p.appname,
  1188. TypeFlags: p.typeFlags,
  1189. }
  1190. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  1191. if auth_ok {
  1192. login.SSPI, err = auth.InitialBytes()
  1193. if err != nil {
  1194. return nil, err
  1195. }
  1196. login.OptionFlags2 |= fIntSecurity
  1197. defer auth.Free()
  1198. } else {
  1199. login.UserName = p.user
  1200. login.Password = p.password
  1201. }
  1202. err = sendLogin(outbuf, login)
  1203. if err != nil {
  1204. return nil, err
  1205. }
  1206. // processing login response
  1207. success := false
  1208. for {
  1209. tokchan := make(chan tokenStruct, 5)
  1210. go processResponse(context.Background(), &sess, tokchan, nil)
  1211. for tok := range tokchan {
  1212. switch token := tok.(type) {
  1213. case sspiMsg:
  1214. sspi_msg, err := auth.NextBytes(token)
  1215. if err != nil {
  1216. return nil, err
  1217. }
  1218. if sspi_msg != nil && len(sspi_msg) > 0 {
  1219. outbuf.BeginPacket(packSSPIMessage, false)
  1220. _, err = outbuf.Write(sspi_msg)
  1221. if err != nil {
  1222. return nil, err
  1223. }
  1224. err = outbuf.FinishPacket()
  1225. if err != nil {
  1226. return nil, err
  1227. }
  1228. sspi_msg = nil
  1229. }
  1230. case loginAckStruct:
  1231. success = true
  1232. sess.loginAck = token
  1233. case error:
  1234. return nil, fmt.Errorf("Login error: %s", token.Error())
  1235. case doneStruct:
  1236. if token.isError() {
  1237. return nil, fmt.Errorf("Login error: %s", token.getError())
  1238. }
  1239. goto loginEnd
  1240. }
  1241. }
  1242. }
  1243. loginEnd:
  1244. if !success {
  1245. return nil, fmt.Errorf("Login failed")
  1246. }
  1247. if sess.routedServer != "" {
  1248. toconn.Close()
  1249. p.host = sess.routedServer
  1250. p.port = uint64(sess.routedPort)
  1251. if !p.hostInCertificateProvided {
  1252. p.hostInCertificate = sess.routedServer
  1253. }
  1254. goto initiate_connection
  1255. }
  1256. return &sess, nil
  1257. }