summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/mrjones/oauth/oauth.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/mrjones/oauth/oauth.go')
-rw-r--r--vendor/github.com/mrjones/oauth/oauth.go1412
1 files changed, 1412 insertions, 0 deletions
diff --git a/vendor/github.com/mrjones/oauth/oauth.go b/vendor/github.com/mrjones/oauth/oauth.go
new file mode 100644
index 0000000000..95eee64abd
--- /dev/null
+++ b/vendor/github.com/mrjones/oauth/oauth.go
@@ -0,0 +1,1412 @@
+// OAuth 1.0 consumer implementation.
+// See http://www.oauth.net and RFC 5849
+//
+// There are typically three parties involved in an OAuth exchange:
+// (1) The "Service Provider" (e.g. Google, Twitter, NetFlix) who operates the
+// service where the data resides.
+// (2) The "End User" who owns that data, and wants to grant access to a third-party.
+// (3) That third-party who wants access to the data (after first being authorized by
+// the user). This third-party is referred to as the "Consumer" in OAuth
+// terminology.
+//
+// This library is designed to help implement the third-party consumer by handling the
+// low-level authentication tasks, and allowing for authenticated requests to the
+// service provider on behalf of the user.
+//
+// Caveats:
+// - Currently only supports HMAC and RSA signatures.
+// - Currently only supports SHA1 and SHA256 hashes.
+// - Currently only supports OAuth 1.0
+//
+// Overview of how to use this library:
+// (1) First create a new Consumer instance with the NewConsumer function
+// (2) Get a RequestToken, and "authorization url" from GetRequestTokenAndUrl()
+// (3) Save the RequestToken, you will need it again in step 6.
+// (4) Redirect the user to the "authorization url" from step 2, where they will
+// authorize your access to the service provider.
+// (5) Wait. You will be called back on the CallbackUrl that you provide, and you
+// will recieve a "verification code".
+// (6) Call AuthorizeToken() with the RequestToken from step 2 and the
+// "verification code" from step 5.
+// (7) You will get back an AccessToken. Save this for as long as you need access
+// to the user's data, and treat it like a password; it is a secret.
+// (8) You can now throw away the RequestToken from step 2, it is no longer
+// necessary.
+// (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an
+// HTTP client which can access protected resources.
+package oauth
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/hmac"
+ cryptoRand "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "mime/multipart"
+ "net/http"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+const (
+ OAUTH_VERSION = "1.0"
+ SIGNATURE_METHOD_HMAC = "HMAC-"
+ SIGNATURE_METHOD_RSA = "RSA-"
+
+ HTTP_AUTH_HEADER = "Authorization"
+ OAUTH_HEADER = "OAuth "
+ BODY_HASH_PARAM = "oauth_body_hash"
+ CALLBACK_PARAM = "oauth_callback"
+ CONSUMER_KEY_PARAM = "oauth_consumer_key"
+ NONCE_PARAM = "oauth_nonce"
+ SESSION_HANDLE_PARAM = "oauth_session_handle"
+ SIGNATURE_METHOD_PARAM = "oauth_signature_method"
+ SIGNATURE_PARAM = "oauth_signature"
+ TIMESTAMP_PARAM = "oauth_timestamp"
+ TOKEN_PARAM = "oauth_token"
+ TOKEN_SECRET_PARAM = "oauth_token_secret"
+ VERIFIER_PARAM = "oauth_verifier"
+ VERSION_PARAM = "oauth_version"
+)
+
+var HASH_METHOD_MAP = map[crypto.Hash]string{
+ crypto.SHA1: "SHA1",
+ crypto.SHA256: "SHA256",
+}
+
+// TODO(mrjones) Do we definitely want separate "Request" and "Access" token classes?
+// They're identical structurally, but used for different purposes.
+type RequestToken struct {
+ Token string
+ Secret string
+}
+
+type AccessToken struct {
+ Token string
+ Secret string
+ AdditionalData map[string]string
+}
+
+type DataLocation int
+
+const (
+ LOC_BODY DataLocation = iota + 1
+ LOC_URL
+ LOC_MULTIPART
+ LOC_JSON
+ LOC_XML
+)
+
+// Information about how to contact the service provider (see #1 above).
+// You usually find all of these URLs by reading the documentation for the service
+// that you're trying to connect to.
+// Some common examples are:
+// (1) Google, standard APIs:
+// http://code.google.com/apis/accounts/docs/OAuth_ref.html
+// - RequestTokenUrl: https://www.google.com/accounts/OAuthGetRequestToken
+// - AuthorizeTokenUrl: https://www.google.com/accounts/OAuthAuthorizeToken
+// - AccessTokenUrl: https://www.google.com/accounts/OAuthGetAccessToken
+// Note: Some Google APIs (for example, Google Latitude) use different values for
+// one or more of those URLs.
+// (2) Twitter API:
+// http://dev.twitter.com/pages/auth
+// - RequestTokenUrl: http://api.twitter.com/oauth/request_token
+// - AuthorizeTokenUrl: https://api.twitter.com/oauth/authorize
+// - AccessTokenUrl: https://api.twitter.com/oauth/access_token
+// (3) NetFlix API:
+// http://developer.netflix.com/docs/Security
+// - RequestTokenUrl: http://api.netflix.com/oauth/request_token
+// - AuthroizeTokenUrl: https://api-user.netflix.com/oauth/login
+// - AccessTokenUrl: http://api.netflix.com/oauth/access_token
+// Set HttpMethod if the service provider requires a different HTTP method
+// to be used for OAuth token requests
+type ServiceProvider struct {
+ RequestTokenUrl string
+ AuthorizeTokenUrl string
+ AccessTokenUrl string
+ HttpMethod string
+ BodyHash bool
+ IgnoreTimestamp bool
+
+ // Enables non spec-compliant behavior:
+ // Allow parameters to be passed in the query string rather
+ // than the body.
+ // See https://github.com/mrjones/oauth/pull/63
+ SignQueryParams bool
+}
+
+func (sp *ServiceProvider) httpMethod() string {
+ if sp.HttpMethod != "" {
+ return sp.HttpMethod
+ }
+
+ return "GET"
+}
+
+// lockedNonceGenerator wraps a non-reentrant random number generator with a
+// lock
+type lockedNonceGenerator struct {
+ nonceGenerator nonceGenerator
+ lock sync.Mutex
+}
+
+func newLockedNonceGenerator(c clock) *lockedNonceGenerator {
+ return &lockedNonceGenerator{
+ nonceGenerator: rand.New(rand.NewSource(c.Nanos())),
+ }
+}
+
+func (n *lockedNonceGenerator) Int63() int64 {
+ n.lock.Lock()
+ r := n.nonceGenerator.Int63()
+ n.lock.Unlock()
+ return r
+}
+
+// Consumers are stateless, you can call the various methods (GetRequestTokenAndUrl,
+// AuthorizeToken, and Get) on various different instances of Consumers *as long as
+// they were set up in the same way.* It is up to you, as the caller to persist the
+// necessary state (RequestTokens and AccessTokens).
+type Consumer struct {
+ // Some ServiceProviders require extra parameters to be passed for various reasons.
+ // For example Google APIs require you to set a scope= parameter to specify how much
+ // access is being granted. The proper values for scope= depend on the service:
+ // For more, see: http://code.google.com/apis/accounts/docs/OAuth.html#prepScope
+ AdditionalParams map[string]string
+
+ // The rest of this class is configured via the NewConsumer function.
+ consumerKey string
+ serviceProvider ServiceProvider
+
+ // Some APIs (e.g. Netflix) aren't quite standard OAuth, and require passing
+ // additional parameters when authorizing the request token. For most APIs
+ // this field can be ignored. For Netflix, do something like:
+ // consumer.AdditionalAuthorizationUrlParams = map[string]string{
+ // "application_name": "YourAppName",
+ // "oauth_consumer_key": "YourConsumerKey",
+ // }
+ AdditionalAuthorizationUrlParams map[string]string
+
+ debug bool
+
+ // Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary
+ HttpClient HttpClient
+
+ // Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with
+ // requests. (like "Accept" to specify the response type as XML or JSON) Note that this
+ // will only *add* headers, not set existing ones.
+ AdditionalHeaders map[string][]string
+
+ // Private seams for mocking dependencies when testing
+ clock clock
+ // Seeded generators are not reentrant
+ nonceGenerator nonceGenerator
+ signer signer
+}
+
+func newConsumer(consumerKey string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer {
+ clock := &defaultClock{}
+ if httpClient == nil {
+ httpClient = &http.Client{}
+ }
+ return &Consumer{
+ consumerKey: consumerKey,
+ serviceProvider: serviceProvider,
+ clock: clock,
+ HttpClient: httpClient,
+ nonceGenerator: newLockedNonceGenerator(clock),
+
+ AdditionalParams: make(map[string]string),
+ AdditionalAuthorizationUrlParams: make(map[string]string),
+ }
+}
+
+// Creates a new Consumer instance, with a HMAC-SHA1 signer
+// - consumerKey and consumerSecret:
+// values you should obtain from the ServiceProvider when you register your
+// application.
+//
+// - serviceProvider:
+// see the documentation for ServiceProvider for how to create this.
+//
+func NewConsumer(consumerKey string, consumerSecret string,
+ serviceProvider ServiceProvider) *Consumer {
+ consumer := newConsumer(consumerKey, serviceProvider, nil)
+
+ consumer.signer = &HMACSigner{
+ consumerSecret: consumerSecret,
+ hashFunc: crypto.SHA1,
+ }
+
+ return consumer
+}
+
+// Creates a new Consumer instance, with a HMAC-SHA1 signer
+// - consumerKey and consumerSecret:
+// values you should obtain from the ServiceProvider when you register your
+// application.
+//
+// - serviceProvider:
+// see the documentation for ServiceProvider for how to create this.
+//
+// - httpClient:
+// Provides a custom implementation of the httpClient used under the hood
+// to make the request. This is especially useful if you want to use
+// Google App Engine.
+//
+func NewCustomHttpClientConsumer(consumerKey string, consumerSecret string,
+ serviceProvider ServiceProvider, httpClient *http.Client) *Consumer {
+ consumer := newConsumer(consumerKey, serviceProvider, httpClient)
+
+ consumer.signer = &HMACSigner{
+ consumerSecret: consumerSecret,
+ hashFunc: crypto.SHA1,
+ }
+
+ return consumer
+}
+
+// Creates a new Consumer instance, with a HMAC signer
+// - consumerKey and consumerSecret:
+// values you should obtain from the ServiceProvider when you register your
+// application.
+//
+// - hashFunc:
+// the crypto.Hash to use for signatures
+//
+// - serviceProvider:
+// see the documentation for ServiceProvider for how to create this.
+//
+// - httpClient:
+// Provides a custom implementation of the httpClient used under the hood
+// to make the request. This is especially useful if you want to use
+// Google App Engine. Can be nil for default.
+//
+func NewCustomConsumer(consumerKey string, consumerSecret string,
+ hashFunc crypto.Hash, serviceProvider ServiceProvider,
+ httpClient *http.Client) *Consumer {
+ consumer := newConsumer(consumerKey, serviceProvider, httpClient)
+
+ consumer.signer = &HMACSigner{
+ consumerSecret: consumerSecret,
+ hashFunc: hashFunc,
+ }
+
+ return consumer
+}
+
+// Creates a new Consumer instance, with a RSA-SHA1 signer
+// - consumerKey:
+// value you should obtain from the ServiceProvider when you register your
+// application.
+//
+// - privateKey:
+// the private key to use for signatures
+//
+// - serviceProvider:
+// see the documentation for ServiceProvider for how to create this.
+//
+func NewRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey,
+ serviceProvider ServiceProvider) *Consumer {
+ consumer := newConsumer(consumerKey, serviceProvider, nil)
+
+ consumer.signer = &RSASigner{
+ privateKey: privateKey,
+ hashFunc: crypto.SHA1,
+ rand: cryptoRand.Reader,
+ }
+
+ return consumer
+}
+
+// Creates a new Consumer instance, with a RSA signer
+// - consumerKey:
+// value you should obtain from the ServiceProvider when you register your
+// application.
+//
+// - privateKey:
+// the private key to use for signatures
+//
+// - hashFunc:
+// the crypto.Hash to use for signatures
+//
+// - serviceProvider:
+// see the documentation for ServiceProvider for how to create this.
+//
+// - httpClient:
+// Provides a custom implementation of the httpClient used under the hood
+// to make the request. This is especially useful if you want to use
+// Google App Engine. Can be nil for default.
+//
+func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey,
+ hashFunc crypto.Hash, serviceProvider ServiceProvider,
+ httpClient *http.Client) *Consumer {
+ consumer := newConsumer(consumerKey, serviceProvider, httpClient)
+
+ consumer.signer = &RSASigner{
+ privateKey: privateKey,
+ hashFunc: hashFunc,
+ rand: cryptoRand.Reader,
+ }
+
+ return consumer
+}
+
+// Kicks off the OAuth authorization process.
+// - callbackUrl:
+// Authorizing a token *requires* redirecting to the service provider. This is the
+// URL which the service provider will redirect the user back to after that
+// authorization is completed. The service provider will pass back a verification
+// code which is necessary to complete the rest of the process (in AuthorizeToken).
+// Notes on callbackUrl:
+// - Some (all?) service providers allow for setting "oob" (for out-of-band) as a
+// callback url. If this is set the service provider will present the
+// verification code directly to the user, and you must provide a place for
+// them to copy-and-paste it into.
+// - Otherwise, the user will be redirected to callbackUrl in the browser, and
+// will append a "oauth_verifier=<verifier>" parameter.
+//
+// This function returns:
+// - rtoken:
+// A temporary RequestToken, used during the authorization process. You must save
+// this since it will be necessary later in the process when calling
+// AuthorizeToken().
+//
+// - url:
+// A URL that you should redirect the user to in order that they may authorize you
+// to the service provider.
+//
+// - err:
+// Set only if there was an error, nil otherwise.
+func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) {
+ return c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams)
+}
+
+func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) {
+ params := c.baseParams(c.consumerKey, additionalParams)
+ if callbackUrl != "" {
+ params.Add(CALLBACK_PARAM, callbackUrl)
+ }
+
+ req := &request{
+ method: c.serviceProvider.httpMethod(),
+ url: c.serviceProvider.RequestTokenUrl,
+ oauthParams: params,
+ }
+ if _, err := c.signRequest(req, ""); err != nil { // We don't have a token secret for the key yet
+ return nil, "", err
+ }
+
+ resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params)
+ if err != nil {
+ return nil, "", errors.New("getBody: " + err.Error())
+ }
+
+ requestToken, err := parseRequestToken(*resp)
+ if err != nil {
+ return nil, "", errors.New("parseRequestToken: " + err.Error())
+ }
+
+ loginParams := make(url.Values)
+ for k, v := range c.AdditionalAuthorizationUrlParams {
+ loginParams.Set(k, v)
+ }
+ loginParams.Set(TOKEN_PARAM, requestToken.Token)
+
+ loginUrl = c.serviceProvider.AuthorizeTokenUrl + "?" + loginParams.Encode()
+
+ return requestToken, loginUrl, nil
+}
+
+// After the user has authorized you to the service provider, use this method to turn
+// your temporary RequestToken into a permanent AccessToken. You must pass in two values:
+// - rtoken:
+// The RequestToken returned from GetRequestTokenAndUrl()
+//
+// - verificationCode:
+// The string which passed back from the server, either as the oauth_verifier
+// query param appended to callbackUrl *OR* a string manually entered by the user
+// if callbackUrl is "oob"
+//
+// It will return:
+// - atoken:
+// A permanent AccessToken which can be used to access the user's data (until it is
+// revoked by the user or the service provider).
+//
+// - err:
+// Set only if there was an error, nil otherwise.
+func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) {
+ return c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams)
+}
+
+func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) {
+ params := map[string]string{
+ VERIFIER_PARAM: verificationCode,
+ TOKEN_PARAM: rtoken.Token,
+ }
+ return c.makeAccessTokenRequestWithParams(params, rtoken.Secret, additionalParams)
+}
+
+// Use the service provider to refresh the AccessToken for a given session.
+// Note that this is only supported for service providers that manage an
+// authorization session (e.g. Yahoo).
+//
+// Most providers do not return the SESSION_HANDLE_PARAM needed to refresh
+// the token.
+//
+// See http://oauth.googlecode.com/svn/spec/ext/session/1.0/drafts/1/spec.html
+// for more information.
+// - accessToken:
+// The AccessToken returned from AuthorizeToken()
+//
+// It will return:
+// - atoken:
+// An AccessToken which can be used to access the user's data (until it is
+// revoked by the user or the service provider).
+//
+// - err:
+// Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to
+// refresh the token, or if an error occurred when making the request.
+func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) {
+ params := make(map[string]string)
+ sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM]
+ if !ok {
+ return nil, errors.New("Missing " + SESSION_HANDLE_PARAM + " in access token.")
+ }
+ params[SESSION_HANDLE_PARAM] = sessionHandle
+ params[TOKEN_PARAM] = accessToken.Token
+
+ return c.makeAccessTokenRequest(params, accessToken.Secret)
+}
+
+// Use the service provider to obtain an AccessToken for a given session
+// - params:
+// The access token request paramters.
+//
+// - secret:
+// Secret key to use when signing the access token request.
+//
+// It will return:
+// - atoken
+// An AccessToken which can be used to access the user's data (until it is
+// revoked by the user or the service provider).
+//
+// - err:
+// Set only if there was an error, nil otherwise.
+func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) {
+ return c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams)
+}
+
+func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) {
+ orderedParams := c.baseParams(c.consumerKey, additionalParams)
+ for key, value := range params {
+ orderedParams.Add(key, value)
+ }
+
+ req := &request{
+ method: c.serviceProvider.httpMethod(),
+ url: c.serviceProvider.AccessTokenUrl,
+ oauthParams: orderedParams,
+ }
+ if _, err := c.signRequest(req, secret); err != nil {
+ return nil, err
+ }
+
+ resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams)
+ if err != nil {
+ return nil, err
+ }
+
+ return parseAccessToken(*resp)
+}
+
+type RoundTripper struct {
+ consumer *Consumer
+ token *AccessToken
+}
+
+func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) {
+ return &RoundTripper{consumer: c, token: token}, nil
+}
+
+func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) {
+ return &http.Client{
+ Transport: &RoundTripper{consumer: c, token: token},
+ }, nil
+}
+
+// ** DEPRECATED **
+// Please call Get on the http client returned by MakeHttpClient instead!
+//
+// Executes an HTTP Get, authorized via the AccessToken.
+// - url:
+// The base url, without any query params, which is being accessed
+//
+// - userParams:
+// Any key=value params to be included in the query string
+//
+// - token:
+// The AccessToken returned by AuthorizeToken()
+//
+// This method returns:
+// - resp:
+// The HTTP Response resulting from making this request.
+//
+// - err:
+// Set only if there was an error, nil otherwise.
+func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token)
+}
+
+func encodeUserParams(userParams map[string]string) string {
+ data := url.Values{}
+ for k, v := range userParams {
+ data.Add(k, v)
+ }
+ return data.Encode()
+}
+
+// ** DEPRECATED **
+// Please call "Post" on the http client returned by MakeHttpClient instead
+func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.PostWithBody(url, "", userParams, token)
+}
+
+// ** DEPRECATED **
+// Please call "Post" on the http client returned by MakeHttpClient instead
+func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.PostWithBody(url, "", userParams, token)
+}
+
+// ** DEPRECATED **
+// Please call "Post" on the http client returned by MakeHttpClient instead
+func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token)
+}
+
+// ** DEPRECATED **
+// Please call "Do" on the http client returned by MakeHttpClient instead
+// (and set the "Content-Type" header explicitly in the http.Request)
+func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token)
+}
+
+// ** DEPRECATED **
+// Please call "Do" on the http client returned by MakeHttpClient instead
+// (and set the "Content-Type" header explicitly in the http.Request)
+func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token)
+}
+
+// ** DEPRECATED **
+// Please call "Do" on the http client returned by MakeHttpClient instead
+// (and setup the multipart data explicitly in the http.Request)
+func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token)
+}
+
+// ** DEPRECATED **
+// Please call "Delete" on the http client returned by MakeHttpClient instead
+func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token)
+}
+
+// ** DEPRECATED **
+// Please call "Put" on the http client returned by MakeHttpClient instead
+func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token)
+}
+
+func (c *Consumer) Debug(enabled bool) {
+ c.debug = enabled
+ c.signer.Debug(enabled)
+}
+
+type pair struct {
+ key string
+ value string
+}
+
+type pairs []pair
+
+func (p pairs) Len() int { return len(p) }
+func (p pairs) Less(i, j int) bool { return p[i].key < p[j].key }
+func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+// This function has basically turned into a backwards compatibility layer
+// between the old API (where clients explicitly called consumer.Get()
+// consumer.Post() etc), and the new API (which takes actual http.Requests)
+//
+// So, here we construct the appropriate HTTP request for the inputs.
+func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ urlObject, err := url.Parse(urlString)
+ if err != nil {
+ return nil, err
+ }
+
+ request := &http.Request{
+ Method: method,
+ URL: urlObject,
+ Header: http.Header{},
+ Body: body,
+ ContentLength: int64(contentLength),
+ }
+
+ vals := url.Values{}
+ for k, v := range userParams {
+ vals.Add(k, v)
+ }
+
+ if dataLocation != LOC_BODY {
+ request.URL.RawQuery = vals.Encode()
+ request.URL.RawQuery = strings.Replace(
+ request.URL.RawQuery, ";", "%3B", -1)
+
+ } else {
+ // TODO(mrjones): validate that we're not overrideing an exising body?
+ request.Body = ioutil.NopCloser(strings.NewReader(vals.Encode()))
+ request.ContentLength = int64(len(vals.Encode()))
+ }
+
+ for k, vs := range c.AdditionalHeaders {
+ for _, v := range vs {
+ request.Header.Set(k, v)
+ }
+ }
+
+ if dataLocation == LOC_BODY {
+ request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ }
+
+ if dataLocation == LOC_JSON {
+ request.Header.Set("Content-Type", "application/json")
+ }
+
+ if dataLocation == LOC_XML {
+ request.Header.Set("Content-Type", "application/xml")
+ }
+
+ if dataLocation == LOC_MULTIPART {
+ pipeReader, pipeWriter := io.Pipe()
+ writer := multipart.NewWriter(pipeWriter)
+ if request.URL.Host == "www.mrjon.es" &&
+ request.URL.Path == "/unittest" {
+ writer.SetBoundary("UNITTESTBOUNDARY")
+ }
+ go func(body io.Reader) {
+ part, err := writer.CreateFormFile(multipartName, "/no/matter")
+ if err != nil {
+ writer.Close()
+ pipeWriter.CloseWithError(err)
+ return
+ }
+ _, err = io.Copy(part, body)
+ if err != nil {
+ writer.Close()
+ pipeWriter.CloseWithError(err)
+ return
+ }
+ writer.Close()
+ pipeWriter.Close()
+ }(body)
+ request.Body = pipeReader
+ request.Header.Set("Content-Type", writer.FormDataContentType())
+ }
+
+ rt := RoundTripper{consumer: c, token: token}
+
+ resp, err = rt.RoundTrip(request)
+ if err != nil {
+ return resp, err
+ }
+
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ defer resp.Body.Close()
+ bytes, _ := ioutil.ReadAll(resp.Body)
+
+ return resp, HTTPExecuteError{
+ RequestHeaders: "",
+ ResponseBodyBytes: bytes,
+ Status: resp.Status,
+ StatusCode: resp.StatusCode,
+ }
+ }
+
+ return resp, nil
+}
+
+// cloneReq clones the src http.Request, making deep copies of the Header and
+// the URL but shallow copies of everything else
+func cloneReq(src *http.Request) *http.Request {
+ dst := &http.Request{}
+ *dst = *src
+
+ dst.Header = make(http.Header, len(src.Header))
+ for k, s := range src.Header {
+ dst.Header[k] = append([]string(nil), s...)
+ }
+
+ if src.URL != nil {
+ dst.URL = cloneURL(src.URL)
+ }
+
+ return dst
+}
+
+// cloneURL shallow clones the src *url.URL
+func cloneURL(src *url.URL) *url.URL {
+ dst := &url.URL{}
+ *dst = *src
+
+ return dst
+}
+
+func canonicalizeUrl(u *url.URL) string {
+ var buf bytes.Buffer
+ buf.WriteString(u.Scheme)
+ buf.WriteString("://")
+ buf.WriteString(u.Host)
+ buf.WriteString(u.Path)
+
+ return buf.String()
+}
+
+func parseBody(request *http.Request) (map[string]string, error) {
+ userParams := map[string]string{}
+
+ // TODO(mrjones): factor parameter extraction into a separate method
+ if request.Header.Get("Content-Type") !=
+ "application/x-www-form-urlencoded" {
+ // Most of the time we get parameters from the query string:
+ for k, vs := range request.URL.Query() {
+ if len(vs) != 1 {
+ return nil, fmt.Errorf("Must have exactly one value per param")
+ }
+
+ userParams[k] = vs[0]
+ }
+ } else {
+ // x-www-form-urlencoded parameters come from the body instead:
+ defer request.Body.Close()
+ originalBody, err := ioutil.ReadAll(request.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // If there was a body, we have to re-install it
+ // (because we've ruined it by reading it).
+ request.Body = ioutil.NopCloser(bytes.NewReader(originalBody))
+
+ params, err := url.ParseQuery(string(originalBody))
+ if err != nil {
+ return nil, err
+ }
+
+ for k, vs := range params {
+ if len(vs) != 1 {
+ return nil, fmt.Errorf("Must have exactly one value per param")
+ }
+
+ userParams[k] = vs[0]
+ }
+ }
+
+ return userParams, nil
+}
+
+func paramsToSortedPairs(params map[string]string) pairs {
+ // Sort parameters alphabetically
+ paramPairs := make(pairs, len(params))
+ i := 0
+ for key, value := range params {
+ paramPairs[i] = pair{key: key, value: value}
+ i++
+ }
+ sort.Sort(paramPairs)
+
+ return paramPairs
+}
+
+func calculateBodyHash(request *http.Request, s signer) (string, error) {
+ if request.Header.Get("Content-Type") ==
+ "application/x-www-form-urlencoded" {
+ return "", nil
+ }
+
+ var originalBody []byte
+
+ if request.Body != nil {
+ var err error
+
+ defer request.Body.Close()
+ originalBody, err = ioutil.ReadAll(request.Body)
+ if err != nil {
+ return "", err
+ }
+
+ // If there was a body, we have to re-install it
+ // (because we've ruined it by reading it).
+ request.Body = ioutil.NopCloser(bytes.NewReader(originalBody))
+ }
+
+ h := s.HashFunc().New()
+ h.Write(originalBody)
+ rawSignature := h.Sum(nil)
+
+ return base64.StdEncoding.EncodeToString(rawSignature), nil
+}
+
+func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) {
+ serverRequest := cloneReq(userRequest)
+
+ allParams := rt.consumer.baseParams(
+ rt.consumer.consumerKey, rt.consumer.AdditionalParams)
+
+ // Do not add the "oauth_token" parameter, if the access token has not been
+ // specified. By omitting this parameter when it is not specified, allows
+ // two-legged OAuth calls.
+ if len(rt.token.Token) > 0 {
+ allParams.Add(TOKEN_PARAM, rt.token.Token)
+ }
+
+ if rt.consumer.serviceProvider.BodyHash {
+ bodyHash, err := calculateBodyHash(serverRequest, rt.consumer.signer)
+ if err != nil {
+ return nil, err
+ }
+
+ if bodyHash != "" {
+ allParams.Add(BODY_HASH_PARAM, bodyHash)
+ }
+ }
+
+ authParams := allParams.Clone()
+
+ // TODO(mrjones): put these directly into the paramPairs below?
+ userParams, err := parseBody(serverRequest)
+ if err != nil {
+ return nil, err
+ }
+ paramPairs := paramsToSortedPairs(userParams)
+
+ for i := range paramPairs {
+ allParams.Add(paramPairs[i].key, paramPairs[i].value)
+ }
+
+ signingURL := cloneURL(serverRequest.URL)
+ if host := serverRequest.Host; host != "" {
+ signingURL.Host = host
+ }
+ baseString := rt.consumer.requestString(serverRequest.Method, canonicalizeUrl(signingURL), allParams)
+
+ signature, err := rt.consumer.signer.Sign(baseString, rt.token.Secret)
+ if err != nil {
+ return nil, err
+ }
+
+ authParams.Add(SIGNATURE_PARAM, signature)
+
+ // Set auth header.
+ oauthHdr := OAUTH_HEADER
+ for pos, key := range authParams.Keys() {
+ for innerPos, value := range authParams.Get(key) {
+ if pos+innerPos > 0 {
+ oauthHdr += ","
+ }
+ oauthHdr += key + "=\"" + value + "\""
+ }
+ }
+ serverRequest.Header.Add(HTTP_AUTH_HEADER, oauthHdr)
+
+ if rt.consumer.debug {
+ fmt.Printf("Request: %v\n", serverRequest)
+ }
+
+ resp, err := rt.consumer.HttpClient.Do(serverRequest)
+
+ if err != nil {
+ return resp, err
+ }
+
+ return resp, nil
+}
+
+func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
+ return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token)
+}
+
+type request struct {
+ method string
+ url string
+ oauthParams *OrderedParams
+ userParams map[string]string
+}
+
+type HttpClient interface {
+ Do(req *http.Request) (resp *http.Response, err error)
+}
+
+type clock interface {
+ Seconds() int64
+ Nanos() int64
+}
+
+type nonceGenerator interface {
+ Int63() int64
+}
+
+type key interface {
+ String() string
+}
+
+type signer interface {
+ Sign(message string, tokenSecret string) (string, error)
+ Verify(message string, signature string) error
+ SignatureMethod() string
+ HashFunc() crypto.Hash
+ Debug(enabled bool)
+}
+
+type defaultClock struct{}
+
+func (*defaultClock) Seconds() int64 {
+ return time.Now().Unix()
+}
+
+func (*defaultClock) Nanos() int64 {
+ return time.Now().UnixNano()
+}
+
+func (c *Consumer) signRequest(req *request, tokenSecret string) (*request, error) {
+ baseString := c.requestString(req.method, req.url, req.oauthParams)
+
+ signature, err := c.signer.Sign(baseString, tokenSecret)
+ if err != nil {
+ return nil, err
+ }
+
+ req.oauthParams.Add(SIGNATURE_PARAM, signature)
+ return req, nil
+}
+
+// Obtains an AccessToken from the response of a service provider.
+// - data:
+// The response body.
+//
+// This method returns:
+// - atoken:
+// The AccessToken generated from the response body.
+//
+// - err:
+// Set if an AccessToken could not be parsed from the given input.
+func parseAccessToken(data string) (atoken *AccessToken, err error) {
+ parts, err := url.ParseQuery(data)
+ if err != nil {
+ return nil, err
+ }
+
+ tokenParam := parts[TOKEN_PARAM]
+ parts.Del(TOKEN_PARAM)
+ if len(tokenParam) < 1 {
+ return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " +
+ "Full response body: '" + data + "'")
+ }
+ tokenSecretParam := parts[TOKEN_SECRET_PARAM]
+ parts.Del(TOKEN_SECRET_PARAM)
+ if len(tokenSecretParam) < 1 {
+ return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." +
+ "Full response body: '" + data + "'")
+ }
+
+ additionalData := parseAdditionalData(parts)
+
+ return &AccessToken{tokenParam[0], tokenSecretParam[0], additionalData}, nil
+}
+
+func parseRequestToken(data string) (*RequestToken, error) {
+ parts, err := url.ParseQuery(data)
+ if err != nil {
+ return nil, err
+ }
+
+ tokenParam := parts[TOKEN_PARAM]
+ if len(tokenParam) < 1 {
+ return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " +
+ "Full response body: '" + data + "'")
+ }
+ tokenSecretParam := parts[TOKEN_SECRET_PARAM]
+ if len(tokenSecretParam) < 1 {
+ return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." +
+ "Full response body: '" + data + "'")
+ }
+ return &RequestToken{tokenParam[0], tokenSecretParam[0]}, nil
+}
+
+func (c *Consumer) baseParams(consumerKey string, additionalParams map[string]string) *OrderedParams {
+ params := NewOrderedParams()
+ params.Add(VERSION_PARAM, OAUTH_VERSION)
+ params.Add(SIGNATURE_METHOD_PARAM, c.signer.SignatureMethod())
+ params.Add(TIMESTAMP_PARAM, strconv.FormatInt(c.clock.Seconds(), 10))
+ params.Add(NONCE_PARAM, strconv.FormatInt(c.nonceGenerator.Int63(), 10))
+ params.Add(CONSUMER_KEY_PARAM, consumerKey)
+ for key, value := range additionalParams {
+ params.Add(key, value)
+ }
+ return params
+}
+
+func parseAdditionalData(parts url.Values) map[string]string {
+ params := make(map[string]string)
+ for key, value := range parts {
+ if len(value) > 0 {
+ params[key] = value[0]
+ }
+ }
+ return params
+}
+
+type HMACSigner struct {
+ consumerSecret string
+ hashFunc crypto.Hash
+ debug bool
+}
+
+func (s *HMACSigner) Debug(enabled bool) {
+ s.debug = enabled
+}
+
+func (s *HMACSigner) Sign(message string, tokenSecret string) (string, error) {
+ key := escape(s.consumerSecret) + "&" + escape(tokenSecret)
+ if s.debug {
+ fmt.Println("Signing:", message)
+ fmt.Println("Key:", key)
+ }
+
+ h := hmac.New(s.HashFunc().New, []byte(key))
+ h.Write([]byte(message))
+ rawSignature := h.Sum(nil)
+
+ base64signature := base64.StdEncoding.EncodeToString(rawSignature)
+ if s.debug {
+ fmt.Println("Base64 signature:", base64signature)
+ }
+ return base64signature, nil
+}
+
+func (s *HMACSigner) Verify(message string, signature string) error {
+ if s.debug {
+ fmt.Println("Verifying Base64 signature:", signature)
+ }
+ validSignature, err := s.Sign(message, "")
+ if err != nil {
+ return err
+ }
+
+ if validSignature != signature {
+ decodedSigniture, _ := url.QueryUnescape(signature)
+ if validSignature != decodedSigniture {
+ return fmt.Errorf("signature did not match")
+ }
+ }
+
+ return nil
+}
+
+func (s *HMACSigner) SignatureMethod() string {
+ return SIGNATURE_METHOD_HMAC + HASH_METHOD_MAP[s.HashFunc()]
+}
+
+func (s *HMACSigner) HashFunc() crypto.Hash {
+ return s.hashFunc
+}
+
+type RSASigner struct {
+ debug bool
+ rand io.Reader
+ privateKey *rsa.PrivateKey
+ hashFunc crypto.Hash
+}
+
+func (s *RSASigner) Debug(enabled bool) {
+ s.debug = enabled
+}
+
+func (s *RSASigner) Sign(message string, tokenSecret string) (string, error) {
+ if s.debug {
+ fmt.Println("Signing:", message)
+ }
+
+ h := s.HashFunc().New()
+ h.Write([]byte(message))
+ digest := h.Sum(nil)
+
+ signature, err := rsa.SignPKCS1v15(s.rand, s.privateKey, s.HashFunc(), digest)
+ if err != nil {
+ return "", nil
+ }
+
+ base64signature := base64.StdEncoding.EncodeToString(signature)
+ if s.debug {
+ fmt.Println("Base64 signature:", base64signature)
+ }
+
+ return base64signature, nil
+}
+
+func (s *RSASigner) Verify(message string, base64signature string) error {
+ if s.debug {
+ fmt.Println("Verifying:", message)
+ fmt.Println("Verifying Base64 signature:", base64signature)
+ }
+
+ h := s.HashFunc().New()
+ h.Write([]byte(message))
+ digest := h.Sum(nil)
+
+ signature, err := base64.StdEncoding.DecodeString(base64signature)
+ if err != nil {
+ return err
+ }
+
+ return rsa.VerifyPKCS1v15(&s.privateKey.PublicKey, s.HashFunc(), digest, signature)
+}
+
+func (s *RSASigner) SignatureMethod() string {
+ return SIGNATURE_METHOD_RSA + HASH_METHOD_MAP[s.HashFunc()]
+}
+
+func (s *RSASigner) HashFunc() crypto.Hash {
+ return s.hashFunc
+}
+
+func escape(s string) string {
+ t := make([]byte, 0, 3*len(s))
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if isEscapable(c) {
+ t = append(t, '%')
+ t = append(t, "0123456789ABCDEF"[c>>4])
+ t = append(t, "0123456789ABCDEF"[c&15])
+ } else {
+ t = append(t, s[i])
+ }
+ }
+ return string(t)
+}
+
+func isEscapable(b byte) bool {
+ return !('A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' || b == '-' || b == '.' || b == '_' || b == '~')
+
+}
+
+func (c *Consumer) requestString(method string, url string, params *OrderedParams) string {
+ result := method + "&" + escape(url)
+ for pos, key := range params.Keys() {
+ for innerPos, value := range params.Get(key) {
+ if pos+innerPos == 0 {
+ result += "&"
+ } else {
+ result += escape("&")
+ }
+ result += escape(fmt.Sprintf("%s=%s", key, value))
+ }
+ }
+ return result
+}
+
+func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) {
+ resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams)
+ if err != nil {
+ return nil, errors.New("httpExecute: " + err.Error())
+ }
+ bodyBytes, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ return nil, errors.New("ReadAll: " + err.Error())
+ }
+ bodyStr := string(bodyBytes)
+ if c.debug {
+ fmt.Printf("STATUS: %d %s\n", resp.StatusCode, resp.Status)
+ fmt.Println("BODY RESPONSE: " + bodyStr)
+ }
+ return &bodyStr, nil
+}
+
+// HTTPExecuteError signals that a call to httpExecute failed.
+type HTTPExecuteError struct {
+ // RequestHeaders provides a stringified listing of request headers.
+ RequestHeaders string
+ // ResponseBodyBytes is the response read into a byte slice.
+ ResponseBodyBytes []byte
+ // Status is the status code string response.
+ Status string
+ // StatusCode is the parsed status code.
+ StatusCode int
+}
+
+// Error provides a printable string description of an HTTPExecuteError.
+func (e HTTPExecuteError) Error() string {
+ return "HTTP response is not 200/OK as expected. Actual response: \n" +
+ "\tResponse Status: '" + e.Status + "'\n" +
+ "\tResponse Code: " + strconv.Itoa(e.StatusCode) + "\n" +
+ "\tResponse Body: " + string(e.ResponseBodyBytes) + "\n" +
+ "\tRequest Headers: " + e.RequestHeaders
+}
+
+func (c *Consumer) httpExecute(
+ method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) {
+ // Create base request.
+ req, err := http.NewRequest(method, urlStr, body)
+ if err != nil {
+ return nil, errors.New("NewRequest failed: " + err.Error())
+ }
+
+ // Set auth header.
+ req.Header = http.Header{}
+ oauthHdr := "OAuth "
+ for pos, key := range oauthParams.Keys() {
+ for innerPos, value := range oauthParams.Get(key) {
+ if pos+innerPos > 0 {
+ oauthHdr += ","
+ }
+ oauthHdr += key + "=\"" + value + "\""
+ }
+ }
+ req.Header.Add("Authorization", oauthHdr)
+
+ // Add additional custom headers
+ for key, vals := range c.AdditionalHeaders {
+ for _, val := range vals {
+ req.Header.Add(key, val)
+ }
+ }
+
+ // Set contentType if passed.
+ if contentType != "" {
+ req.Header.Set("Content-Type", contentType)
+ }
+
+ // Set contentLength if passed.
+ if contentLength > 0 {
+ req.Header.Set("Content-Length", strconv.Itoa(contentLength))
+ }
+
+ if c.debug {
+ fmt.Printf("Request: %v\n", req)
+ }
+ resp, err := c.HttpClient.Do(req)
+ if err != nil {
+ return nil, errors.New("Do: " + err.Error())
+ }
+
+ debugHeader := ""
+ for k, vals := range req.Header {
+ for _, val := range vals {
+ debugHeader += "[key: " + k + ", val: " + val + "]"
+ }
+ }
+
+ // StatusMultipleChoices is 300, any 2xx response should be treated as success
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ defer resp.Body.Close()
+ bytes, _ := ioutil.ReadAll(resp.Body)
+
+ return resp, HTTPExecuteError{
+ RequestHeaders: debugHeader,
+ ResponseBodyBytes: bytes,
+ Status: resp.Status,
+ StatusCode: resp.StatusCode,
+ }
+ }
+ return resp, err
+}
+
+//
+// String Sorting helpers
+//
+
+type ByValue []string
+
+func (a ByValue) Len() int {
+ return len(a)
+}
+
+func (a ByValue) Swap(i, j int) {
+ a[i], a[j] = a[j], a[i]
+}
+
+func (a ByValue) Less(i, j int) bool {
+ return a[i] < a[j]
+}
+
+//
+// ORDERED PARAMS
+//
+
+type OrderedParams struct {
+ allParams map[string][]string
+ keyOrdering []string
+}
+
+func NewOrderedParams() *OrderedParams {
+ return &OrderedParams{
+ allParams: make(map[string][]string),
+ keyOrdering: make([]string, 0),
+ }
+}
+
+func (o *OrderedParams) Get(key string) []string {
+ sort.Sort(ByValue(o.allParams[key]))
+ return o.allParams[key]
+}
+
+func (o *OrderedParams) Keys() []string {
+ sort.Sort(o)
+ return o.keyOrdering
+}
+
+func (o *OrderedParams) Add(key, value string) {
+ o.AddUnescaped(key, escape(value))
+}
+
+func (o *OrderedParams) AddUnescaped(key, value string) {
+ if _, exists := o.allParams[key]; !exists {
+ o.keyOrdering = append(o.keyOrdering, key)
+ o.allParams[key] = make([]string, 1)
+ o.allParams[key][0] = value
+ } else {
+ o.allParams[key] = append(o.allParams[key], value)
+ }
+}
+
+func (o *OrderedParams) Len() int {
+ return len(o.keyOrdering)
+}
+
+func (o *OrderedParams) Less(i int, j int) bool {
+ return o.keyOrdering[i] < o.keyOrdering[j]
+}
+
+func (o *OrderedParams) Swap(i int, j int) {
+ o.keyOrdering[i], o.keyOrdering[j] = o.keyOrdering[j], o.keyOrdering[i]
+}
+
+func (o *OrderedParams) Clone() *OrderedParams {
+ clone := NewOrderedParams()
+ for _, key := range o.Keys() {
+ for _, value := range o.Get(key) {
+ clone.AddUnescaped(key, value)
+ }
+ }
+ return clone
+}