You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

negotiator.go 3.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package ntlmssp
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "io"
  6. "io/ioutil"
  7. "net/http"
  8. "strings"
  9. )
  10. // GetDomain : parse domain name from based on slashes in the input
  11. func GetDomain(user string) (string, string) {
  12. domain := ""
  13. if strings.Contains(user, "\\") {
  14. ucomponents := strings.SplitN(user, "\\", 2)
  15. domain = ucomponents[0]
  16. user = ucomponents[1]
  17. }
  18. return user, domain
  19. }
  20. //Negotiator is a http.Roundtripper decorator that automatically
  21. //converts basic authentication to NTLM/Negotiate authentication when appropriate.
  22. type Negotiator struct{ http.RoundTripper }
  23. //RoundTrip sends the request to the server, handling any authentication
  24. //re-sends as needed.
  25. func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
  26. // Use default round tripper if not provided
  27. rt := l.RoundTripper
  28. if rt == nil {
  29. rt = http.DefaultTransport
  30. }
  31. // If it is not basic auth, just round trip the request as usual
  32. reqauth := authheader(req.Header.Get("Authorization"))
  33. if !reqauth.IsBasic() {
  34. return rt.RoundTrip(req)
  35. }
  36. // Save request body
  37. body := bytes.Buffer{}
  38. if req.Body != nil {
  39. _, err = body.ReadFrom(req.Body)
  40. if err != nil {
  41. return nil, err
  42. }
  43. req.Body.Close()
  44. req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
  45. }
  46. // first try anonymous, in case the server still finds us
  47. // authenticated from previous traffic
  48. req.Header.Del("Authorization")
  49. res, err = rt.RoundTrip(req)
  50. if err != nil {
  51. return nil, err
  52. }
  53. if res.StatusCode != http.StatusUnauthorized {
  54. return res, err
  55. }
  56. resauth := authheader(res.Header.Get("Www-Authenticate"))
  57. if !resauth.IsNegotiate() && !resauth.IsNTLM() {
  58. // Unauthorized, Negotiate not requested, let's try with basic auth
  59. req.Header.Set("Authorization", string(reqauth))
  60. io.Copy(ioutil.Discard, res.Body)
  61. res.Body.Close()
  62. req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
  63. res, err = rt.RoundTrip(req)
  64. if err != nil {
  65. return nil, err
  66. }
  67. if res.StatusCode != http.StatusUnauthorized {
  68. return res, err
  69. }
  70. resauth = authheader(res.Header.Get("Www-Authenticate"))
  71. }
  72. if resauth.IsNegotiate() || resauth.IsNTLM() {
  73. // 401 with request:Basic and response:Negotiate
  74. io.Copy(ioutil.Discard, res.Body)
  75. res.Body.Close()
  76. // recycle credentials
  77. u, p, err := reqauth.GetBasicCreds()
  78. if err != nil {
  79. return nil, err
  80. }
  81. // get domain from username
  82. domain := ""
  83. u, domain = GetDomain(u)
  84. // send negotiate
  85. negotiateMessage, err := NewNegotiateMessage(domain, "")
  86. if err != nil {
  87. return nil, err
  88. }
  89. if resauth.IsNTLM() {
  90. req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
  91. } else {
  92. req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
  93. }
  94. req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
  95. res, err = rt.RoundTrip(req)
  96. if err != nil {
  97. return nil, err
  98. }
  99. // receive challenge?
  100. resauth = authheader(res.Header.Get("Www-Authenticate"))
  101. challengeMessage, err := resauth.GetData()
  102. if err != nil {
  103. return nil, err
  104. }
  105. if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
  106. // Negotiation failed, let client deal with response
  107. return res, nil
  108. }
  109. io.Copy(ioutil.Discard, res.Body)
  110. res.Body.Close()
  111. // send authenticate
  112. authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
  113. if err != nil {
  114. return nil, err
  115. }
  116. if resauth.IsNTLM() {
  117. req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
  118. } else {
  119. req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
  120. }
  121. req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
  122. return rt.RoundTrip(req)
  123. }
  124. return res, err
  125. }