]> source.dussan.org Git - gitea.git/commitdiff
Decoupled code from DefaultSigningKey (#16743)
authorKN4CK3R <admin@oldschoolhack.me>
Fri, 27 Aug 2021 19:28:00 +0000 (21:28 +0200)
committerGitHub <noreply@github.com>
Fri, 27 Aug 2021 19:28:00 +0000 (20:28 +0100)
Decoupled code from `DefaultSigningKey`. Makes testing a little bit easier and is cleaner.

routers/web/user/oauth.go
routers/web/user/oauth_test.go
services/auth/oauth2.go
services/auth/source/oauth2/token.go

index e29826630a37bc391d5efb810f866f62c312c283..cec6a92bbea45290a3227b5ea32930fcdcfec3e6 100644 (file)
@@ -115,7 +115,7 @@ type AccessTokenResponse struct {
        IDToken      string    `json:"id_token,omitempty"`
 }
 
-func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
+func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
        if setting.OAuth2.InvalidateRefreshTokens {
                if err := grant.IncreaseCounter(); err != nil {
                        return nil, &AccessTokenError{
@@ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
                        ExpiresAt: expirationDate.AsTime().Unix(),
                },
        }
-       signedAccessToken, err := accessToken.SignToken()
+       signedAccessToken, err := accessToken.SignToken(serverKey)
        if err != nil {
                return nil, &AccessTokenError{
                        ErrorCode:        AccessTokenErrorCodeInvalidRequest,
@@ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
                        ExpiresAt: refreshExpirationDate,
                },
        }
-       signedRefreshToken, err := refreshToken.SignToken()
+       signedRefreshToken, err := refreshToken.SignToken(serverKey)
        if err != nil {
                return nil, &AccessTokenError{
                        ErrorCode:        AccessTokenErrorCodeInvalidRequest,
@@ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
                        idToken.EmailVerified = user.IsActive
                }
 
-               signedIDToken, err = idToken.SignToken(signingKey)
+               signedIDToken, err = idToken.SignToken(clientKey)
                if err != nil {
                        return nil, &AccessTokenError{
                                ErrorCode:        AccessTokenErrorCodeInvalidRequest,
@@ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) {
        }
 
        form := web.GetForm(ctx).(*forms.IntrospectTokenForm)
-       token, err := oauth2.ParseToken(form.Token)
+       token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey)
        if err == nil {
                if token.Valid() == nil {
                        grant, err := models.GetOAuth2GrantByID(token.GrantID)
@@ -544,9 +544,11 @@ func AccessTokenOAuth(ctx *context.Context) {
                }
        }
 
-       signingKey := oauth2.DefaultSigningKey
-       if signingKey.IsSymmetric() {
-               clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret))
+       serverKey := oauth2.DefaultSigningKey
+       clientKey := serverKey
+       if serverKey.IsSymmetric() {
+               var err error
+               clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret))
                if err != nil {
                        handleAccessTokenError(ctx, AccessTokenError{
                                ErrorCode:        AccessTokenErrorCodeInvalidRequest,
@@ -554,14 +556,13 @@ func AccessTokenOAuth(ctx *context.Context) {
                        })
                        return
                }
-               signingKey = clientKey
        }
 
        switch form.GrantType {
        case "refresh_token":
-               handleRefreshToken(ctx, form, signingKey)
+               handleRefreshToken(ctx, form, serverKey, clientKey)
        case "authorization_code":
-               handleAuthorizationCode(ctx, form, signingKey)
+               handleAuthorizationCode(ctx, form, serverKey, clientKey)
        default:
                handleAccessTokenError(ctx, AccessTokenError{
                        ErrorCode:        AccessTokenErrorCodeUnsupportedGrantType,
@@ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) {
        }
 }
 
-func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
-       token, err := oauth2.ParseToken(form.RefreshToken)
+func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
+       token, err := oauth2.ParseToken(form.RefreshToken, serverKey)
        if err != nil {
                handleAccessTokenError(ctx, AccessTokenError{
                        ErrorCode:        AccessTokenErrorCodeUnauthorizedClient,
@@ -598,7 +599,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
                log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
                return
        }
-       accessToken, tokenErr := newAccessTokenResponse(grant, signingKey)
+       accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey)
        if tokenErr != nil {
                handleAccessTokenError(ctx, *tokenErr)
                return
@@ -606,7 +607,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
        ctx.JSON(http.StatusOK, accessToken)
 }
 
-func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) {
+func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
        app, err := models.GetOAuth2ApplicationByClientID(form.ClientID)
        if err != nil {
                handleAccessTokenError(ctx, AccessTokenError{
@@ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
                        ErrorDescription: "cannot proceed your request",
                })
        }
-       resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey)
+       resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey)
        if tokenErr != nil {
                handleAccessTokenError(ctx, *tokenErr)
                return
index c2f9ec87b5694afc1ccaf439da1cecc0db4eddd6..40116d3c1297984fafbbb23159d40f041af858dd 100644 (file)
@@ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo
        signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32))
        assert.NoError(t, err)
        assert.NotNil(t, signingKey)
-       oauth2.DefaultSigningKey = signingKey
 
-       response, terr := newAccessTokenResponse(grant, signingKey)
+       response, terr := newAccessTokenResponse(grant, signingKey, signingKey)
        assert.Nil(t, terr)
        assert.NotNil(t, response)
 
index f7f870dade14df9186bfbcde13f6f33deeb703b5..665e5232ccbed8240b141c8ac473dd4041c01e8a 100644 (file)
@@ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 {
        if !strings.Contains(accessToken, ".") {
                return 0
        }
-       token, err := oauth2.ParseToken(accessToken)
+       token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey)
        if err != nil {
-               log.Trace("ParseOAuth2Token: %v", err)
+               log.Trace("oauth2.ParseToken: %v", err)
                return 0
        }
        var grant *models.OAuth2Grant
index 529e04577d43b6585ddf27b0e916d5df9b9dd408..16d1220842d3ab49640cb0bd4c22232a1cb9bdeb 100644 (file)
@@ -40,12 +40,12 @@ type Token struct {
 }
 
 // ParseToken parses a signed jwt string
-func ParseToken(jwtToken string) (*Token, error) {
+func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) {
        parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) {
-               if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() {
+               if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() {
                        return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"])
                }
-               return DefaultSigningKey.VerifyKey(), nil
+               return signingKey.VerifyKey(), nil
        })
        if err != nil {
                return nil, err
@@ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) {
 }
 
 // SignToken signs the token with the JWT secret
-func (token *Token) SignToken() (string, error) {
+func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) {
        token.IssuedAt = time.Now().Unix()
-       jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token)
-       DefaultSigningKey.PreProcessToken(jwtToken)
-       return jwtToken.SignedString(DefaultSigningKey.SignKey())
+       jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token)
+       signingKey.PreProcessToken(jwtToken)
+       return jwtToken.SignedString(signingKey.SignKey())
 }
 
 // OIDCToken represents an OpenID Connect id_token