You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

conn_str.go 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package mssql
  2. import (
  3. "fmt"
  4. "net"
  5. "net/url"
  6. "os"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "unicode"
  11. )
  12. const defaultServerPort = 1433
  13. type connectParams struct {
  14. logFlags uint64
  15. port uint64
  16. host string
  17. instance string
  18. database string
  19. user string
  20. password string
  21. dial_timeout time.Duration
  22. conn_timeout time.Duration
  23. keepAlive time.Duration
  24. encrypt bool
  25. disableEncryption bool
  26. trustServerCertificate bool
  27. certificate string
  28. hostInCertificate string
  29. hostInCertificateProvided bool
  30. serverSPN string
  31. workstation string
  32. appname string
  33. typeFlags uint8
  34. failOverPartner string
  35. failOverPort uint64
  36. packetSize uint16
  37. fedAuthAccessToken string
  38. }
  39. func parseConnectParams(dsn string) (connectParams, error) {
  40. var p connectParams
  41. var params map[string]string
  42. if strings.HasPrefix(dsn, "odbc:") {
  43. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  44. if err != nil {
  45. return p, err
  46. }
  47. params = parameters
  48. } else if strings.HasPrefix(dsn, "sqlserver://") {
  49. parameters, err := splitConnectionStringURL(dsn)
  50. if err != nil {
  51. return p, err
  52. }
  53. params = parameters
  54. } else {
  55. params = splitConnectionString(dsn)
  56. }
  57. strlog, ok := params["log"]
  58. if ok {
  59. var err error
  60. p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
  61. if err != nil {
  62. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  63. }
  64. }
  65. server := params["server"]
  66. parts := strings.SplitN(server, `\`, 2)
  67. p.host = parts[0]
  68. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  69. p.host = "localhost"
  70. }
  71. if len(parts) > 1 {
  72. p.instance = parts[1]
  73. }
  74. p.database = params["database"]
  75. p.user = params["user id"]
  76. p.password = params["password"]
  77. p.port = 0
  78. strport, ok := params["port"]
  79. if ok {
  80. var err error
  81. p.port, err = strconv.ParseUint(strport, 10, 16)
  82. if err != nil {
  83. f := "Invalid tcp port '%v': %v"
  84. return p, fmt.Errorf(f, strport, err.Error())
  85. }
  86. }
  87. // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
  88. // Default packet size remains at 4096 bytes
  89. p.packetSize = 4096
  90. strpsize, ok := params["packet size"]
  91. if ok {
  92. var err error
  93. psize, err := strconv.ParseUint(strpsize, 0, 16)
  94. if err != nil {
  95. f := "Invalid packet size '%v': %v"
  96. return p, fmt.Errorf(f, strpsize, err.Error())
  97. }
  98. // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
  99. // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
  100. // a higher packet size, the server will respond with an ENVCHANGE request to
  101. // alter the packet size to 16383 bytes.
  102. p.packetSize = uint16(psize)
  103. if p.packetSize < 512 {
  104. p.packetSize = 512
  105. } else if p.packetSize > 32767 {
  106. p.packetSize = 32767
  107. }
  108. }
  109. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  110. //
  111. // Do not set a connection timeout. Use Context to manage such things.
  112. // Default to zero, but still allow it to be set.
  113. if strconntimeout, ok := params["connection timeout"]; ok {
  114. timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
  115. if err != nil {
  116. f := "Invalid connection timeout '%v': %v"
  117. return p, fmt.Errorf(f, strconntimeout, err.Error())
  118. }
  119. p.conn_timeout = time.Duration(timeout) * time.Second
  120. }
  121. p.dial_timeout = 15 * time.Second
  122. if strdialtimeout, ok := params["dial timeout"]; ok {
  123. timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
  124. if err != nil {
  125. f := "Invalid dial timeout '%v': %v"
  126. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  127. }
  128. p.dial_timeout = time.Duration(timeout) * time.Second
  129. }
  130. // default keep alive should be 30 seconds according to spec:
  131. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  132. p.keepAlive = 30 * time.Second
  133. if keepAlive, ok := params["keepalive"]; ok {
  134. timeout, err := strconv.ParseUint(keepAlive, 10, 64)
  135. if err != nil {
  136. f := "Invalid keepAlive value '%s': %s"
  137. return p, fmt.Errorf(f, keepAlive, err.Error())
  138. }
  139. p.keepAlive = time.Duration(timeout) * time.Second
  140. }
  141. encrypt, ok := params["encrypt"]
  142. if ok {
  143. if strings.EqualFold(encrypt, "DISABLE") {
  144. p.disableEncryption = true
  145. } else {
  146. var err error
  147. p.encrypt, err = strconv.ParseBool(encrypt)
  148. if err != nil {
  149. f := "Invalid encrypt '%s': %s"
  150. return p, fmt.Errorf(f, encrypt, err.Error())
  151. }
  152. }
  153. } else {
  154. p.trustServerCertificate = true
  155. }
  156. trust, ok := params["trustservercertificate"]
  157. if ok {
  158. var err error
  159. p.trustServerCertificate, err = strconv.ParseBool(trust)
  160. if err != nil {
  161. f := "Invalid trust server certificate '%s': %s"
  162. return p, fmt.Errorf(f, trust, err.Error())
  163. }
  164. }
  165. p.certificate = params["certificate"]
  166. p.hostInCertificate, ok = params["hostnameincertificate"]
  167. if ok {
  168. p.hostInCertificateProvided = true
  169. } else {
  170. p.hostInCertificate = p.host
  171. p.hostInCertificateProvided = false
  172. }
  173. serverSPN, ok := params["serverspn"]
  174. if ok {
  175. p.serverSPN = serverSPN
  176. } else {
  177. p.serverSPN = generateSpn(p.host, resolveServerPort(p.port))
  178. }
  179. workstation, ok := params["workstation id"]
  180. if ok {
  181. p.workstation = workstation
  182. } else {
  183. workstation, err := os.Hostname()
  184. if err == nil {
  185. p.workstation = workstation
  186. }
  187. }
  188. appname, ok := params["app name"]
  189. if !ok {
  190. appname = "go-mssqldb"
  191. }
  192. p.appname = appname
  193. appintent, ok := params["applicationintent"]
  194. if ok {
  195. if appintent == "ReadOnly" {
  196. if p.database == "" {
  197. return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly")
  198. }
  199. p.typeFlags |= fReadOnlyIntent
  200. }
  201. }
  202. failOverPartner, ok := params["failoverpartner"]
  203. if ok {
  204. p.failOverPartner = failOverPartner
  205. }
  206. failOverPort, ok := params["failoverport"]
  207. if ok {
  208. var err error
  209. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  210. if err != nil {
  211. f := "Invalid tcp port '%v': %v"
  212. return p, fmt.Errorf(f, failOverPort, err.Error())
  213. }
  214. }
  215. return p, nil
  216. }
  217. func splitConnectionString(dsn string) (res map[string]string) {
  218. res = map[string]string{}
  219. parts := strings.Split(dsn, ";")
  220. for _, part := range parts {
  221. if len(part) == 0 {
  222. continue
  223. }
  224. lst := strings.SplitN(part, "=", 2)
  225. name := strings.TrimSpace(strings.ToLower(lst[0]))
  226. if len(name) == 0 {
  227. continue
  228. }
  229. var value string = ""
  230. if len(lst) > 1 {
  231. value = strings.TrimSpace(lst[1])
  232. }
  233. res[name] = value
  234. }
  235. return res
  236. }
  237. // Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
  238. func splitConnectionStringURL(dsn string) (map[string]string, error) {
  239. res := map[string]string{}
  240. u, err := url.Parse(dsn)
  241. if err != nil {
  242. return res, err
  243. }
  244. if u.Scheme != "sqlserver" {
  245. return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
  246. }
  247. if u.User != nil {
  248. res["user id"] = u.User.Username()
  249. p, exists := u.User.Password()
  250. if exists {
  251. res["password"] = p
  252. }
  253. }
  254. host, port, err := net.SplitHostPort(u.Host)
  255. if err != nil {
  256. host = u.Host
  257. }
  258. if len(u.Path) > 0 {
  259. res["server"] = host + "\\" + u.Path[1:]
  260. } else {
  261. res["server"] = host
  262. }
  263. if len(port) > 0 {
  264. res["port"] = port
  265. }
  266. query := u.Query()
  267. for k, v := range query {
  268. if len(v) > 1 {
  269. return res, fmt.Errorf("key %s provided more than once", k)
  270. }
  271. res[strings.ToLower(k)] = v[0]
  272. }
  273. return res, nil
  274. }
  275. // Splits a URL in the ODBC format
  276. func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
  277. res := map[string]string{}
  278. type parserState int
  279. const (
  280. // Before the start of a key
  281. parserStateBeforeKey parserState = iota
  282. // Inside a key
  283. parserStateKey
  284. // Beginning of a value. May be bare or braced
  285. parserStateBeginValue
  286. // Inside a bare value
  287. parserStateBareValue
  288. // Inside a braced value
  289. parserStateBracedValue
  290. // A closing brace inside a braced value.
  291. // May be the end of the value or an escaped closing brace, depending on the next character
  292. parserStateBracedValueClosingBrace
  293. // After a value. Next character should be a semicolon or whitespace.
  294. parserStateEndValue
  295. )
  296. var state = parserStateBeforeKey
  297. var key string
  298. var value string
  299. for i, c := range dsn {
  300. switch state {
  301. case parserStateBeforeKey:
  302. switch {
  303. case c == '=':
  304. return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
  305. case !unicode.IsSpace(c) && c != ';':
  306. state = parserStateKey
  307. key += string(c)
  308. }
  309. case parserStateKey:
  310. switch c {
  311. case '=':
  312. key = normalizeOdbcKey(key)
  313. state = parserStateBeginValue
  314. case ';':
  315. // Key without value
  316. key = normalizeOdbcKey(key)
  317. res[key] = value
  318. key = ""
  319. value = ""
  320. state = parserStateBeforeKey
  321. default:
  322. key += string(c)
  323. }
  324. case parserStateBeginValue:
  325. switch {
  326. case c == '{':
  327. state = parserStateBracedValue
  328. case c == ';':
  329. // Empty value
  330. res[key] = value
  331. key = ""
  332. state = parserStateBeforeKey
  333. case unicode.IsSpace(c):
  334. // Ignore whitespace
  335. default:
  336. state = parserStateBareValue
  337. value += string(c)
  338. }
  339. case parserStateBareValue:
  340. if c == ';' {
  341. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  342. key = ""
  343. value = ""
  344. state = parserStateBeforeKey
  345. } else {
  346. value += string(c)
  347. }
  348. case parserStateBracedValue:
  349. if c == '}' {
  350. state = parserStateBracedValueClosingBrace
  351. } else {
  352. value += string(c)
  353. }
  354. case parserStateBracedValueClosingBrace:
  355. if c == '}' {
  356. // Escaped closing brace
  357. value += string(c)
  358. state = parserStateBracedValue
  359. continue
  360. }
  361. // End of braced value
  362. res[key] = value
  363. key = ""
  364. value = ""
  365. // This character is the first character past the end,
  366. // so it needs to be parsed like the parserStateEndValue state.
  367. state = parserStateEndValue
  368. switch {
  369. case c == ';':
  370. state = parserStateBeforeKey
  371. case unicode.IsSpace(c):
  372. // Ignore whitespace
  373. default:
  374. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  375. }
  376. case parserStateEndValue:
  377. switch {
  378. case c == ';':
  379. state = parserStateBeforeKey
  380. case unicode.IsSpace(c):
  381. // Ignore whitespace
  382. default:
  383. return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
  384. }
  385. }
  386. }
  387. switch state {
  388. case parserStateBeforeKey: // Okay
  389. case parserStateKey: // Unfinished key. Treat as key without value.
  390. key = normalizeOdbcKey(key)
  391. res[key] = value
  392. case parserStateBeginValue: // Empty value
  393. res[key] = value
  394. case parserStateBareValue:
  395. res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
  396. case parserStateBracedValue:
  397. return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
  398. case parserStateBracedValueClosingBrace: // End of braced value
  399. res[key] = value
  400. case parserStateEndValue: // Okay
  401. }
  402. return res, nil
  403. }
  404. // Normalizes the given string as an ODBC-format key
  405. func normalizeOdbcKey(s string) string {
  406. return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
  407. }
  408. func resolveServerPort(port uint64) uint64 {
  409. if port == 0 {
  410. return defaultServerPort
  411. }
  412. return port
  413. }
  414. func generateSpn(host string, port uint64) string {
  415. return fmt.Sprintf("MSSQLSvc/%s:%d", host, port)
  416. }