diff options
Diffstat (limited to 'vendor/github.com/mrjones/oauth/oauth.go')
-rw-r--r-- | vendor/github.com/mrjones/oauth/oauth.go | 1412 |
1 files changed, 1412 insertions, 0 deletions
diff --git a/vendor/github.com/mrjones/oauth/oauth.go b/vendor/github.com/mrjones/oauth/oauth.go new file mode 100644 index 0000000000..95eee64abd --- /dev/null +++ b/vendor/github.com/mrjones/oauth/oauth.go @@ -0,0 +1,1412 @@ +// OAuth 1.0 consumer implementation. +// See http://www.oauth.net and RFC 5849 +// +// There are typically three parties involved in an OAuth exchange: +// (1) The "Service Provider" (e.g. Google, Twitter, NetFlix) who operates the +// service where the data resides. +// (2) The "End User" who owns that data, and wants to grant access to a third-party. +// (3) That third-party who wants access to the data (after first being authorized by +// the user). This third-party is referred to as the "Consumer" in OAuth +// terminology. +// +// This library is designed to help implement the third-party consumer by handling the +// low-level authentication tasks, and allowing for authenticated requests to the +// service provider on behalf of the user. +// +// Caveats: +// - Currently only supports HMAC and RSA signatures. +// - Currently only supports SHA1 and SHA256 hashes. +// - Currently only supports OAuth 1.0 +// +// Overview of how to use this library: +// (1) First create a new Consumer instance with the NewConsumer function +// (2) Get a RequestToken, and "authorization url" from GetRequestTokenAndUrl() +// (3) Save the RequestToken, you will need it again in step 6. +// (4) Redirect the user to the "authorization url" from step 2, where they will +// authorize your access to the service provider. +// (5) Wait. You will be called back on the CallbackUrl that you provide, and you +// will recieve a "verification code". +// (6) Call AuthorizeToken() with the RequestToken from step 2 and the +// "verification code" from step 5. +// (7) You will get back an AccessToken. Save this for as long as you need access +// to the user's data, and treat it like a password; it is a secret. +// (8) You can now throw away the RequestToken from step 2, it is no longer +// necessary. +// (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an +// HTTP client which can access protected resources. +package oauth + +import ( + "bytes" + "crypto" + "crypto/hmac" + cryptoRand "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "mime/multipart" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +const ( + OAUTH_VERSION = "1.0" + SIGNATURE_METHOD_HMAC = "HMAC-" + SIGNATURE_METHOD_RSA = "RSA-" + + HTTP_AUTH_HEADER = "Authorization" + OAUTH_HEADER = "OAuth " + BODY_HASH_PARAM = "oauth_body_hash" + CALLBACK_PARAM = "oauth_callback" + CONSUMER_KEY_PARAM = "oauth_consumer_key" + NONCE_PARAM = "oauth_nonce" + SESSION_HANDLE_PARAM = "oauth_session_handle" + SIGNATURE_METHOD_PARAM = "oauth_signature_method" + SIGNATURE_PARAM = "oauth_signature" + TIMESTAMP_PARAM = "oauth_timestamp" + TOKEN_PARAM = "oauth_token" + TOKEN_SECRET_PARAM = "oauth_token_secret" + VERIFIER_PARAM = "oauth_verifier" + VERSION_PARAM = "oauth_version" +) + +var HASH_METHOD_MAP = map[crypto.Hash]string{ + crypto.SHA1: "SHA1", + crypto.SHA256: "SHA256", +} + +// TODO(mrjones) Do we definitely want separate "Request" and "Access" token classes? +// They're identical structurally, but used for different purposes. +type RequestToken struct { + Token string + Secret string +} + +type AccessToken struct { + Token string + Secret string + AdditionalData map[string]string +} + +type DataLocation int + +const ( + LOC_BODY DataLocation = iota + 1 + LOC_URL + LOC_MULTIPART + LOC_JSON + LOC_XML +) + +// Information about how to contact the service provider (see #1 above). +// You usually find all of these URLs by reading the documentation for the service +// that you're trying to connect to. +// Some common examples are: +// (1) Google, standard APIs: +// http://code.google.com/apis/accounts/docs/OAuth_ref.html +// - RequestTokenUrl: https://www.google.com/accounts/OAuthGetRequestToken +// - AuthorizeTokenUrl: https://www.google.com/accounts/OAuthAuthorizeToken +// - AccessTokenUrl: https://www.google.com/accounts/OAuthGetAccessToken +// Note: Some Google APIs (for example, Google Latitude) use different values for +// one or more of those URLs. +// (2) Twitter API: +// http://dev.twitter.com/pages/auth +// - RequestTokenUrl: http://api.twitter.com/oauth/request_token +// - AuthorizeTokenUrl: https://api.twitter.com/oauth/authorize +// - AccessTokenUrl: https://api.twitter.com/oauth/access_token +// (3) NetFlix API: +// http://developer.netflix.com/docs/Security +// - RequestTokenUrl: http://api.netflix.com/oauth/request_token +// - AuthroizeTokenUrl: https://api-user.netflix.com/oauth/login +// - AccessTokenUrl: http://api.netflix.com/oauth/access_token +// Set HttpMethod if the service provider requires a different HTTP method +// to be used for OAuth token requests +type ServiceProvider struct { + RequestTokenUrl string + AuthorizeTokenUrl string + AccessTokenUrl string + HttpMethod string + BodyHash bool + IgnoreTimestamp bool + + // Enables non spec-compliant behavior: + // Allow parameters to be passed in the query string rather + // than the body. + // See https://github.com/mrjones/oauth/pull/63 + SignQueryParams bool +} + +func (sp *ServiceProvider) httpMethod() string { + if sp.HttpMethod != "" { + return sp.HttpMethod + } + + return "GET" +} + +// lockedNonceGenerator wraps a non-reentrant random number generator with a +// lock +type lockedNonceGenerator struct { + nonceGenerator nonceGenerator + lock sync.Mutex +} + +func newLockedNonceGenerator(c clock) *lockedNonceGenerator { + return &lockedNonceGenerator{ + nonceGenerator: rand.New(rand.NewSource(c.Nanos())), + } +} + +func (n *lockedNonceGenerator) Int63() int64 { + n.lock.Lock() + r := n.nonceGenerator.Int63() + n.lock.Unlock() + return r +} + +// Consumers are stateless, you can call the various methods (GetRequestTokenAndUrl, +// AuthorizeToken, and Get) on various different instances of Consumers *as long as +// they were set up in the same way.* It is up to you, as the caller to persist the +// necessary state (RequestTokens and AccessTokens). +type Consumer struct { + // Some ServiceProviders require extra parameters to be passed for various reasons. + // For example Google APIs require you to set a scope= parameter to specify how much + // access is being granted. The proper values for scope= depend on the service: + // For more, see: http://code.google.com/apis/accounts/docs/OAuth.html#prepScope + AdditionalParams map[string]string + + // The rest of this class is configured via the NewConsumer function. + consumerKey string + serviceProvider ServiceProvider + + // Some APIs (e.g. Netflix) aren't quite standard OAuth, and require passing + // additional parameters when authorizing the request token. For most APIs + // this field can be ignored. For Netflix, do something like: + // consumer.AdditionalAuthorizationUrlParams = map[string]string{ + // "application_name": "YourAppName", + // "oauth_consumer_key": "YourConsumerKey", + // } + AdditionalAuthorizationUrlParams map[string]string + + debug bool + + // Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary + HttpClient HttpClient + + // Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with + // requests. (like "Accept" to specify the response type as XML or JSON) Note that this + // will only *add* headers, not set existing ones. + AdditionalHeaders map[string][]string + + // Private seams for mocking dependencies when testing + clock clock + // Seeded generators are not reentrant + nonceGenerator nonceGenerator + signer signer +} + +func newConsumer(consumerKey string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { + clock := &defaultClock{} + if httpClient == nil { + httpClient = &http.Client{} + } + return &Consumer{ + consumerKey: consumerKey, + serviceProvider: serviceProvider, + clock: clock, + HttpClient: httpClient, + nonceGenerator: newLockedNonceGenerator(clock), + + AdditionalParams: make(map[string]string), + AdditionalAuthorizationUrlParams: make(map[string]string), + } +} + +// Creates a new Consumer instance, with a HMAC-SHA1 signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +func NewConsumer(consumerKey string, consumerSecret string, + serviceProvider ServiceProvider) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, nil) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: crypto.SHA1, + } + + return consumer +} + +// Creates a new Consumer instance, with a HMAC-SHA1 signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. +// +func NewCustomHttpClientConsumer(consumerKey string, consumerSecret string, + serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: crypto.SHA1, + } + + return consumer +} + +// Creates a new Consumer instance, with a HMAC signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - hashFunc: +// the crypto.Hash to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. Can be nil for default. +// +func NewCustomConsumer(consumerKey string, consumerSecret string, + hashFunc crypto.Hash, serviceProvider ServiceProvider, + httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: hashFunc, + } + + return consumer +} + +// Creates a new Consumer instance, with a RSA-SHA1 signer +// - consumerKey: +// value you should obtain from the ServiceProvider when you register your +// application. +// +// - privateKey: +// the private key to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +func NewRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, + serviceProvider ServiceProvider) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, nil) + + consumer.signer = &RSASigner{ + privateKey: privateKey, + hashFunc: crypto.SHA1, + rand: cryptoRand.Reader, + } + + return consumer +} + +// Creates a new Consumer instance, with a RSA signer +// - consumerKey: +// value you should obtain from the ServiceProvider when you register your +// application. +// +// - privateKey: +// the private key to use for signatures +// +// - hashFunc: +// the crypto.Hash to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. Can be nil for default. +// +func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, + hashFunc crypto.Hash, serviceProvider ServiceProvider, + httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &RSASigner{ + privateKey: privateKey, + hashFunc: hashFunc, + rand: cryptoRand.Reader, + } + + return consumer +} + +// Kicks off the OAuth authorization process. +// - callbackUrl: +// Authorizing a token *requires* redirecting to the service provider. This is the +// URL which the service provider will redirect the user back to after that +// authorization is completed. The service provider will pass back a verification +// code which is necessary to complete the rest of the process (in AuthorizeToken). +// Notes on callbackUrl: +// - Some (all?) service providers allow for setting "oob" (for out-of-band) as a +// callback url. If this is set the service provider will present the +// verification code directly to the user, and you must provide a place for +// them to copy-and-paste it into. +// - Otherwise, the user will be redirected to callbackUrl in the browser, and +// will append a "oauth_verifier=<verifier>" parameter. +// +// This function returns: +// - rtoken: +// A temporary RequestToken, used during the authorization process. You must save +// this since it will be necessary later in the process when calling +// AuthorizeToken(). +// +// - url: +// A URL that you should redirect the user to in order that they may authorize you +// to the service provider. +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) { + return c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams) +} + +func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) { + params := c.baseParams(c.consumerKey, additionalParams) + if callbackUrl != "" { + params.Add(CALLBACK_PARAM, callbackUrl) + } + + req := &request{ + method: c.serviceProvider.httpMethod(), + url: c.serviceProvider.RequestTokenUrl, + oauthParams: params, + } + if _, err := c.signRequest(req, ""); err != nil { // We don't have a token secret for the key yet + return nil, "", err + } + + resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params) + if err != nil { + return nil, "", errors.New("getBody: " + err.Error()) + } + + requestToken, err := parseRequestToken(*resp) + if err != nil { + return nil, "", errors.New("parseRequestToken: " + err.Error()) + } + + loginParams := make(url.Values) + for k, v := range c.AdditionalAuthorizationUrlParams { + loginParams.Set(k, v) + } + loginParams.Set(TOKEN_PARAM, requestToken.Token) + + loginUrl = c.serviceProvider.AuthorizeTokenUrl + "?" + loginParams.Encode() + + return requestToken, loginUrl, nil +} + +// After the user has authorized you to the service provider, use this method to turn +// your temporary RequestToken into a permanent AccessToken. You must pass in two values: +// - rtoken: +// The RequestToken returned from GetRequestTokenAndUrl() +// +// - verificationCode: +// The string which passed back from the server, either as the oauth_verifier +// query param appended to callbackUrl *OR* a string manually entered by the user +// if callbackUrl is "oob" +// +// It will return: +// - atoken: +// A permanent AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) { + return c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams) +} + +func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) { + params := map[string]string{ + VERIFIER_PARAM: verificationCode, + TOKEN_PARAM: rtoken.Token, + } + return c.makeAccessTokenRequestWithParams(params, rtoken.Secret, additionalParams) +} + +// Use the service provider to refresh the AccessToken for a given session. +// Note that this is only supported for service providers that manage an +// authorization session (e.g. Yahoo). +// +// Most providers do not return the SESSION_HANDLE_PARAM needed to refresh +// the token. +// +// See http://oauth.googlecode.com/svn/spec/ext/session/1.0/drafts/1/spec.html +// for more information. +// - accessToken: +// The AccessToken returned from AuthorizeToken() +// +// It will return: +// - atoken: +// An AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to +// refresh the token, or if an error occurred when making the request. +func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) { + params := make(map[string]string) + sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM] + if !ok { + return nil, errors.New("Missing " + SESSION_HANDLE_PARAM + " in access token.") + } + params[SESSION_HANDLE_PARAM] = sessionHandle + params[TOKEN_PARAM] = accessToken.Token + + return c.makeAccessTokenRequest(params, accessToken.Secret) +} + +// Use the service provider to obtain an AccessToken for a given session +// - params: +// The access token request paramters. +// +// - secret: +// Secret key to use when signing the access token request. +// +// It will return: +// - atoken +// An AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) { + return c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams) +} + +func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) { + orderedParams := c.baseParams(c.consumerKey, additionalParams) + for key, value := range params { + orderedParams.Add(key, value) + } + + req := &request{ + method: c.serviceProvider.httpMethod(), + url: c.serviceProvider.AccessTokenUrl, + oauthParams: orderedParams, + } + if _, err := c.signRequest(req, secret); err != nil { + return nil, err + } + + resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams) + if err != nil { + return nil, err + } + + return parseAccessToken(*resp) +} + +type RoundTripper struct { + consumer *Consumer + token *AccessToken +} + +func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) { + return &RoundTripper{consumer: c, token: token}, nil +} + +func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) { + return &http.Client{ + Transport: &RoundTripper{consumer: c, token: token}, + }, nil +} + +// ** DEPRECATED ** +// Please call Get on the http client returned by MakeHttpClient instead! +// +// Executes an HTTP Get, authorized via the AccessToken. +// - url: +// The base url, without any query params, which is being accessed +// +// - userParams: +// Any key=value params to be included in the query string +// +// - token: +// The AccessToken returned by AuthorizeToken() +// +// This method returns: +// - resp: +// The HTTP Response resulting from making this request. +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token) +} + +func encodeUserParams(userParams map[string]string) string { + data := url.Values{} + for k, v := range userParams { + data.Add(k, v) + } + return data.Encode() +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.PostWithBody(url, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.PostWithBody(url, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and set the "Content-Type" header explicitly in the http.Request) +func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and set the "Content-Type" header explicitly in the http.Request) +func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and setup the multipart data explicitly in the http.Request) +func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token) +} + +// ** DEPRECATED ** +// Please call "Delete" on the http client returned by MakeHttpClient instead +func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Put" on the http client returned by MakeHttpClient instead +func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token) +} + +func (c *Consumer) Debug(enabled bool) { + c.debug = enabled + c.signer.Debug(enabled) +} + +type pair struct { + key string + value string +} + +type pairs []pair + +func (p pairs) Len() int { return len(p) } +func (p pairs) Less(i, j int) bool { return p[i].key < p[j].key } +func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// This function has basically turned into a backwards compatibility layer +// between the old API (where clients explicitly called consumer.Get() +// consumer.Post() etc), and the new API (which takes actual http.Requests) +// +// So, here we construct the appropriate HTTP request for the inputs. +func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + urlObject, err := url.Parse(urlString) + if err != nil { + return nil, err + } + + request := &http.Request{ + Method: method, + URL: urlObject, + Header: http.Header{}, + Body: body, + ContentLength: int64(contentLength), + } + + vals := url.Values{} + for k, v := range userParams { + vals.Add(k, v) + } + + if dataLocation != LOC_BODY { + request.URL.RawQuery = vals.Encode() + request.URL.RawQuery = strings.Replace( + request.URL.RawQuery, ";", "%3B", -1) + + } else { + // TODO(mrjones): validate that we're not overrideing an exising body? + request.Body = ioutil.NopCloser(strings.NewReader(vals.Encode())) + request.ContentLength = int64(len(vals.Encode())) + } + + for k, vs := range c.AdditionalHeaders { + for _, v := range vs { + request.Header.Set(k, v) + } + } + + if dataLocation == LOC_BODY { + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + + if dataLocation == LOC_JSON { + request.Header.Set("Content-Type", "application/json") + } + + if dataLocation == LOC_XML { + request.Header.Set("Content-Type", "application/xml") + } + + if dataLocation == LOC_MULTIPART { + pipeReader, pipeWriter := io.Pipe() + writer := multipart.NewWriter(pipeWriter) + if request.URL.Host == "www.mrjon.es" && + request.URL.Path == "/unittest" { + writer.SetBoundary("UNITTESTBOUNDARY") + } + go func(body io.Reader) { + part, err := writer.CreateFormFile(multipartName, "/no/matter") + if err != nil { + writer.Close() + pipeWriter.CloseWithError(err) + return + } + _, err = io.Copy(part, body) + if err != nil { + writer.Close() + pipeWriter.CloseWithError(err) + return + } + writer.Close() + pipeWriter.Close() + }(body) + request.Body = pipeReader + request.Header.Set("Content-Type", writer.FormDataContentType()) + } + + rt := RoundTripper{consumer: c, token: token} + + resp, err = rt.RoundTrip(request) + if err != nil { + return resp, err + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + defer resp.Body.Close() + bytes, _ := ioutil.ReadAll(resp.Body) + + return resp, HTTPExecuteError{ + RequestHeaders: "", + ResponseBodyBytes: bytes, + Status: resp.Status, + StatusCode: resp.StatusCode, + } + } + + return resp, nil +} + +// cloneReq clones the src http.Request, making deep copies of the Header and +// the URL but shallow copies of everything else +func cloneReq(src *http.Request) *http.Request { + dst := &http.Request{} + *dst = *src + + dst.Header = make(http.Header, len(src.Header)) + for k, s := range src.Header { + dst.Header[k] = append([]string(nil), s...) + } + + if src.URL != nil { + dst.URL = cloneURL(src.URL) + } + + return dst +} + +// cloneURL shallow clones the src *url.URL +func cloneURL(src *url.URL) *url.URL { + dst := &url.URL{} + *dst = *src + + return dst +} + +func canonicalizeUrl(u *url.URL) string { + var buf bytes.Buffer + buf.WriteString(u.Scheme) + buf.WriteString("://") + buf.WriteString(u.Host) + buf.WriteString(u.Path) + + return buf.String() +} + +func parseBody(request *http.Request) (map[string]string, error) { + userParams := map[string]string{} + + // TODO(mrjones): factor parameter extraction into a separate method + if request.Header.Get("Content-Type") != + "application/x-www-form-urlencoded" { + // Most of the time we get parameters from the query string: + for k, vs := range request.URL.Query() { + if len(vs) != 1 { + return nil, fmt.Errorf("Must have exactly one value per param") + } + + userParams[k] = vs[0] + } + } else { + // x-www-form-urlencoded parameters come from the body instead: + defer request.Body.Close() + originalBody, err := ioutil.ReadAll(request.Body) + if err != nil { + return nil, err + } + + // If there was a body, we have to re-install it + // (because we've ruined it by reading it). + request.Body = ioutil.NopCloser(bytes.NewReader(originalBody)) + + params, err := url.ParseQuery(string(originalBody)) + if err != nil { + return nil, err + } + + for k, vs := range params { + if len(vs) != 1 { + return nil, fmt.Errorf("Must have exactly one value per param") + } + + userParams[k] = vs[0] + } + } + + return userParams, nil +} + +func paramsToSortedPairs(params map[string]string) pairs { + // Sort parameters alphabetically + paramPairs := make(pairs, len(params)) + i := 0 + for key, value := range params { + paramPairs[i] = pair{key: key, value: value} + i++ + } + sort.Sort(paramPairs) + + return paramPairs +} + +func calculateBodyHash(request *http.Request, s signer) (string, error) { + if request.Header.Get("Content-Type") == + "application/x-www-form-urlencoded" { + return "", nil + } + + var originalBody []byte + + if request.Body != nil { + var err error + + defer request.Body.Close() + originalBody, err = ioutil.ReadAll(request.Body) + if err != nil { + return "", err + } + + // If there was a body, we have to re-install it + // (because we've ruined it by reading it). + request.Body = ioutil.NopCloser(bytes.NewReader(originalBody)) + } + + h := s.HashFunc().New() + h.Write(originalBody) + rawSignature := h.Sum(nil) + + return base64.StdEncoding.EncodeToString(rawSignature), nil +} + +func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) { + serverRequest := cloneReq(userRequest) + + allParams := rt.consumer.baseParams( + rt.consumer.consumerKey, rt.consumer.AdditionalParams) + + // Do not add the "oauth_token" parameter, if the access token has not been + // specified. By omitting this parameter when it is not specified, allows + // two-legged OAuth calls. + if len(rt.token.Token) > 0 { + allParams.Add(TOKEN_PARAM, rt.token.Token) + } + + if rt.consumer.serviceProvider.BodyHash { + bodyHash, err := calculateBodyHash(serverRequest, rt.consumer.signer) + if err != nil { + return nil, err + } + + if bodyHash != "" { + allParams.Add(BODY_HASH_PARAM, bodyHash) + } + } + + authParams := allParams.Clone() + + // TODO(mrjones): put these directly into the paramPairs below? + userParams, err := parseBody(serverRequest) + if err != nil { + return nil, err + } + paramPairs := paramsToSortedPairs(userParams) + + for i := range paramPairs { + allParams.Add(paramPairs[i].key, paramPairs[i].value) + } + + signingURL := cloneURL(serverRequest.URL) + if host := serverRequest.Host; host != "" { + signingURL.Host = host + } + baseString := rt.consumer.requestString(serverRequest.Method, canonicalizeUrl(signingURL), allParams) + + signature, err := rt.consumer.signer.Sign(baseString, rt.token.Secret) + if err != nil { + return nil, err + } + + authParams.Add(SIGNATURE_PARAM, signature) + + // Set auth header. + oauthHdr := OAUTH_HEADER + for pos, key := range authParams.Keys() { + for innerPos, value := range authParams.Get(key) { + if pos+innerPos > 0 { + oauthHdr += "," + } + oauthHdr += key + "=\"" + value + "\"" + } + } + serverRequest.Header.Add(HTTP_AUTH_HEADER, oauthHdr) + + if rt.consumer.debug { + fmt.Printf("Request: %v\n", serverRequest) + } + + resp, err := rt.consumer.HttpClient.Do(serverRequest) + + if err != nil { + return resp, err + } + + return resp, nil +} + +func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token) +} + +type request struct { + method string + url string + oauthParams *OrderedParams + userParams map[string]string +} + +type HttpClient interface { + Do(req *http.Request) (resp *http.Response, err error) +} + +type clock interface { + Seconds() int64 + Nanos() int64 +} + +type nonceGenerator interface { + Int63() int64 +} + +type key interface { + String() string +} + +type signer interface { + Sign(message string, tokenSecret string) (string, error) + Verify(message string, signature string) error + SignatureMethod() string + HashFunc() crypto.Hash + Debug(enabled bool) +} + +type defaultClock struct{} + +func (*defaultClock) Seconds() int64 { + return time.Now().Unix() +} + +func (*defaultClock) Nanos() int64 { + return time.Now().UnixNano() +} + +func (c *Consumer) signRequest(req *request, tokenSecret string) (*request, error) { + baseString := c.requestString(req.method, req.url, req.oauthParams) + + signature, err := c.signer.Sign(baseString, tokenSecret) + if err != nil { + return nil, err + } + + req.oauthParams.Add(SIGNATURE_PARAM, signature) + return req, nil +} + +// Obtains an AccessToken from the response of a service provider. +// - data: +// The response body. +// +// This method returns: +// - atoken: +// The AccessToken generated from the response body. +// +// - err: +// Set if an AccessToken could not be parsed from the given input. +func parseAccessToken(data string) (atoken *AccessToken, err error) { + parts, err := url.ParseQuery(data) + if err != nil { + return nil, err + } + + tokenParam := parts[TOKEN_PARAM] + parts.Del(TOKEN_PARAM) + if len(tokenParam) < 1 { + return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " + + "Full response body: '" + data + "'") + } + tokenSecretParam := parts[TOKEN_SECRET_PARAM] + parts.Del(TOKEN_SECRET_PARAM) + if len(tokenSecretParam) < 1 { + return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." + + "Full response body: '" + data + "'") + } + + additionalData := parseAdditionalData(parts) + + return &AccessToken{tokenParam[0], tokenSecretParam[0], additionalData}, nil +} + +func parseRequestToken(data string) (*RequestToken, error) { + parts, err := url.ParseQuery(data) + if err != nil { + return nil, err + } + + tokenParam := parts[TOKEN_PARAM] + if len(tokenParam) < 1 { + return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " + + "Full response body: '" + data + "'") + } + tokenSecretParam := parts[TOKEN_SECRET_PARAM] + if len(tokenSecretParam) < 1 { + return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." + + "Full response body: '" + data + "'") + } + return &RequestToken{tokenParam[0], tokenSecretParam[0]}, nil +} + +func (c *Consumer) baseParams(consumerKey string, additionalParams map[string]string) *OrderedParams { + params := NewOrderedParams() + params.Add(VERSION_PARAM, OAUTH_VERSION) + params.Add(SIGNATURE_METHOD_PARAM, c.signer.SignatureMethod()) + params.Add(TIMESTAMP_PARAM, strconv.FormatInt(c.clock.Seconds(), 10)) + params.Add(NONCE_PARAM, strconv.FormatInt(c.nonceGenerator.Int63(), 10)) + params.Add(CONSUMER_KEY_PARAM, consumerKey) + for key, value := range additionalParams { + params.Add(key, value) + } + return params +} + +func parseAdditionalData(parts url.Values) map[string]string { + params := make(map[string]string) + for key, value := range parts { + if len(value) > 0 { + params[key] = value[0] + } + } + return params +} + +type HMACSigner struct { + consumerSecret string + hashFunc crypto.Hash + debug bool +} + +func (s *HMACSigner) Debug(enabled bool) { + s.debug = enabled +} + +func (s *HMACSigner) Sign(message string, tokenSecret string) (string, error) { + key := escape(s.consumerSecret) + "&" + escape(tokenSecret) + if s.debug { + fmt.Println("Signing:", message) + fmt.Println("Key:", key) + } + + h := hmac.New(s.HashFunc().New, []byte(key)) + h.Write([]byte(message)) + rawSignature := h.Sum(nil) + + base64signature := base64.StdEncoding.EncodeToString(rawSignature) + if s.debug { + fmt.Println("Base64 signature:", base64signature) + } + return base64signature, nil +} + +func (s *HMACSigner) Verify(message string, signature string) error { + if s.debug { + fmt.Println("Verifying Base64 signature:", signature) + } + validSignature, err := s.Sign(message, "") + if err != nil { + return err + } + + if validSignature != signature { + decodedSigniture, _ := url.QueryUnescape(signature) + if validSignature != decodedSigniture { + return fmt.Errorf("signature did not match") + } + } + + return nil +} + +func (s *HMACSigner) SignatureMethod() string { + return SIGNATURE_METHOD_HMAC + HASH_METHOD_MAP[s.HashFunc()] +} + +func (s *HMACSigner) HashFunc() crypto.Hash { + return s.hashFunc +} + +type RSASigner struct { + debug bool + rand io.Reader + privateKey *rsa.PrivateKey + hashFunc crypto.Hash +} + +func (s *RSASigner) Debug(enabled bool) { + s.debug = enabled +} + +func (s *RSASigner) Sign(message string, tokenSecret string) (string, error) { + if s.debug { + fmt.Println("Signing:", message) + } + + h := s.HashFunc().New() + h.Write([]byte(message)) + digest := h.Sum(nil) + + signature, err := rsa.SignPKCS1v15(s.rand, s.privateKey, s.HashFunc(), digest) + if err != nil { + return "", nil + } + + base64signature := base64.StdEncoding.EncodeToString(signature) + if s.debug { + fmt.Println("Base64 signature:", base64signature) + } + + return base64signature, nil +} + +func (s *RSASigner) Verify(message string, base64signature string) error { + if s.debug { + fmt.Println("Verifying:", message) + fmt.Println("Verifying Base64 signature:", base64signature) + } + + h := s.HashFunc().New() + h.Write([]byte(message)) + digest := h.Sum(nil) + + signature, err := base64.StdEncoding.DecodeString(base64signature) + if err != nil { + return err + } + + return rsa.VerifyPKCS1v15(&s.privateKey.PublicKey, s.HashFunc(), digest, signature) +} + +func (s *RSASigner) SignatureMethod() string { + return SIGNATURE_METHOD_RSA + HASH_METHOD_MAP[s.HashFunc()] +} + +func (s *RSASigner) HashFunc() crypto.Hash { + return s.hashFunc +} + +func escape(s string) string { + t := make([]byte, 0, 3*len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if isEscapable(c) { + t = append(t, '%') + t = append(t, "0123456789ABCDEF"[c>>4]) + t = append(t, "0123456789ABCDEF"[c&15]) + } else { + t = append(t, s[i]) + } + } + return string(t) +} + +func isEscapable(b byte) bool { + return !('A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' || b == '-' || b == '.' || b == '_' || b == '~') + +} + +func (c *Consumer) requestString(method string, url string, params *OrderedParams) string { + result := method + "&" + escape(url) + for pos, key := range params.Keys() { + for innerPos, value := range params.Get(key) { + if pos+innerPos == 0 { + result += "&" + } else { + result += escape("&") + } + result += escape(fmt.Sprintf("%s=%s", key, value)) + } + } + return result +} + +func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) { + resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams) + if err != nil { + return nil, errors.New("httpExecute: " + err.Error()) + } + bodyBytes, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, errors.New("ReadAll: " + err.Error()) + } + bodyStr := string(bodyBytes) + if c.debug { + fmt.Printf("STATUS: %d %s\n", resp.StatusCode, resp.Status) + fmt.Println("BODY RESPONSE: " + bodyStr) + } + return &bodyStr, nil +} + +// HTTPExecuteError signals that a call to httpExecute failed. +type HTTPExecuteError struct { + // RequestHeaders provides a stringified listing of request headers. + RequestHeaders string + // ResponseBodyBytes is the response read into a byte slice. + ResponseBodyBytes []byte + // Status is the status code string response. + Status string + // StatusCode is the parsed status code. + StatusCode int +} + +// Error provides a printable string description of an HTTPExecuteError. +func (e HTTPExecuteError) Error() string { + return "HTTP response is not 200/OK as expected. Actual response: \n" + + "\tResponse Status: '" + e.Status + "'\n" + + "\tResponse Code: " + strconv.Itoa(e.StatusCode) + "\n" + + "\tResponse Body: " + string(e.ResponseBodyBytes) + "\n" + + "\tRequest Headers: " + e.RequestHeaders +} + +func (c *Consumer) httpExecute( + method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) { + // Create base request. + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + return nil, errors.New("NewRequest failed: " + err.Error()) + } + + // Set auth header. + req.Header = http.Header{} + oauthHdr := "OAuth " + for pos, key := range oauthParams.Keys() { + for innerPos, value := range oauthParams.Get(key) { + if pos+innerPos > 0 { + oauthHdr += "," + } + oauthHdr += key + "=\"" + value + "\"" + } + } + req.Header.Add("Authorization", oauthHdr) + + // Add additional custom headers + for key, vals := range c.AdditionalHeaders { + for _, val := range vals { + req.Header.Add(key, val) + } + } + + // Set contentType if passed. + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + + // Set contentLength if passed. + if contentLength > 0 { + req.Header.Set("Content-Length", strconv.Itoa(contentLength)) + } + + if c.debug { + fmt.Printf("Request: %v\n", req) + } + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, errors.New("Do: " + err.Error()) + } + + debugHeader := "" + for k, vals := range req.Header { + for _, val := range vals { + debugHeader += "[key: " + k + ", val: " + val + "]" + } + } + + // StatusMultipleChoices is 300, any 2xx response should be treated as success + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + defer resp.Body.Close() + bytes, _ := ioutil.ReadAll(resp.Body) + + return resp, HTTPExecuteError{ + RequestHeaders: debugHeader, + ResponseBodyBytes: bytes, + Status: resp.Status, + StatusCode: resp.StatusCode, + } + } + return resp, err +} + +// +// String Sorting helpers +// + +type ByValue []string + +func (a ByValue) Len() int { + return len(a) +} + +func (a ByValue) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func (a ByValue) Less(i, j int) bool { + return a[i] < a[j] +} + +// +// ORDERED PARAMS +// + +type OrderedParams struct { + allParams map[string][]string + keyOrdering []string +} + +func NewOrderedParams() *OrderedParams { + return &OrderedParams{ + allParams: make(map[string][]string), + keyOrdering: make([]string, 0), + } +} + +func (o *OrderedParams) Get(key string) []string { + sort.Sort(ByValue(o.allParams[key])) + return o.allParams[key] +} + +func (o *OrderedParams) Keys() []string { + sort.Sort(o) + return o.keyOrdering +} + +func (o *OrderedParams) Add(key, value string) { + o.AddUnescaped(key, escape(value)) +} + +func (o *OrderedParams) AddUnescaped(key, value string) { + if _, exists := o.allParams[key]; !exists { + o.keyOrdering = append(o.keyOrdering, key) + o.allParams[key] = make([]string, 1) + o.allParams[key][0] = value + } else { + o.allParams[key] = append(o.allParams[key], value) + } +} + +func (o *OrderedParams) Len() int { + return len(o.keyOrdering) +} + +func (o *OrderedParams) Less(i int, j int) bool { + return o.keyOrdering[i] < o.keyOrdering[j] +} + +func (o *OrderedParams) Swap(i int, j int) { + o.keyOrdering[i], o.keyOrdering[j] = o.keyOrdering[j], o.keyOrdering[i] +} + +func (o *OrderedParams) Clone() *OrderedParams { + clone := NewOrderedParams() + for _, key := range o.Keys() { + for _, value := range o.Get(key) { + clone.AddUnescaped(key, value) + } + } + return clone +} |