From fd7d83ace60258acf7139c4c787aa8af75b7ba8c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 20 May 2022 22:08:52 +0800 Subject: Move almost all functions' parameter db.Engine to context.Context (#19748) * Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions --- routers/web/auth/auth.go | 2 +- routers/web/auth/linkaccount.go | 2 +- routers/web/auth/oauth.go | 39 ++++++++++++++++++++------------------- routers/web/auth/oauth_test.go | 7 ++++--- routers/web/auth/openid.go | 6 +++--- 5 files changed, 29 insertions(+), 27 deletions(-) (limited to 'routers/web/auth') diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index be936d2230..4d5a2c9335 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -64,7 +64,7 @@ func AutoSignIn(ctx *context.Context) (bool, error) { } }() - u, err := user_model.GetUserByName(uname) + u, err := user_model.GetUserByName(ctx, uname) if err != nil { if !user_model.IsErrUserNotExist(err) { return false, fmt.Errorf("GetUserByName: %v", err) diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index c3e96f077a..a2d76e9c5a 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -70,7 +70,7 @@ func LinkAccount(ctx *context.Context) { ctx.Data["user_exists"] = true } } else if len(uname) != 0 { - u, err := user_model.GetUserByName(uname) + u, err := user_model.GetUserByName(ctx, uname) if err != nil && !user_model.IsErrUserNotExist(err) { ctx.ServerError("UserSignIn", err) return diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 4c3e3c3ace..9aa31c1c02 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -5,6 +5,7 @@ package auth import ( + stdContext "context" "encoding/base64" "errors" "fmt" @@ -135,9 +136,9 @@ type AccessTokenResponse struct { IDToken string `json:"id_token,omitempty"` } -func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { +func newAccessTokenResponse(ctx stdContext.Context, grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { if setting.OAuth2.InvalidateRefreshTokens { - if err := grant.IncreaseCounter(); err != nil { + if err := grant.IncreaseCounter(ctx); err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidGrant, ErrorDescription: "cannot increase the grant counter", @@ -182,7 +183,7 @@ func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2 // generate OpenID Connect id_token signedIDToken := "" if grant.ScopeContains("openid") { - app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -333,9 +334,9 @@ func IntrospectOAuth(ctx *context.Context) { token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey) if err == nil { if token.Valid() == nil { - grant, err := auth.GetOAuth2GrantByID(token.GrantID) + grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID) if err == nil && grant != nil { - app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID) if err == nil && app != nil { response.Active = true response.Scope = grant.Scope @@ -364,7 +365,7 @@ func AuthorizeOAuth(ctx *context.Context) { return } - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { if auth.IsErrOauthClientIDInvalid(err) { handleAuthorizeError(ctx, AuthorizeError{ @@ -438,7 +439,7 @@ func AuthorizeOAuth(ctx *context.Context) { return } - grant, err := app.GetGrantByUserID(ctx.Doer.ID) + grant, err := app.GetGrantByUserID(ctx, ctx.Doer.ID) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -446,7 +447,7 @@ func AuthorizeOAuth(ctx *context.Context) { // Redirect if user already granted access if grant != nil { - code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod) + code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -458,7 +459,7 @@ func AuthorizeOAuth(ctx *context.Context) { } // Update nonce to reflect the new session if len(form.Nonce) > 0 { - err := grant.SetNonce(form.Nonce) + err := grant.SetNonce(ctx, form.Nonce) if err != nil { log.Error("Unable to update nonce: %v", err) } @@ -510,12 +511,12 @@ func GrantApplicationOAuth(ctx *context.Context) { ctx.Error(http.StatusBadRequest) return } - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { ctx.ServerError("GetOAuth2ApplicationByClientID", err) return } - grant, err := app.CreateGrant(ctx.Doer.ID, form.Scope) + grant, err := app.CreateGrant(ctx, ctx.Doer.ID, form.Scope) if err != nil { handleAuthorizeError(ctx, AuthorizeError{ State: form.State, @@ -525,7 +526,7 @@ func GrantApplicationOAuth(ctx *context.Context) { return } if len(form.Nonce) > 0 { - err := grant.SetNonce(form.Nonce) + err := grant.SetNonce(ctx, form.Nonce) if err != nil { log.Error("Unable to update nonce: %v", err) } @@ -535,7 +536,7 @@ func GrantApplicationOAuth(ctx *context.Context) { codeChallenge, _ = ctx.Session.Get("CodeChallenge").(string) codeChallengeMethod, _ = ctx.Session.Get("CodeChallengeMethod").(string) - code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, codeChallenge, codeChallengeMethod) + code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, codeChallenge, codeChallengeMethod) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -648,7 +649,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server return } // get grant before increasing counter - grant, err := auth.GetOAuth2GrantByID(token.GrantID) + grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID) if err != nil || grant == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidGrant, @@ -666,7 +667,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID) return } - accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey) + accessToken, tokenErr := newAccessTokenResponse(ctx, grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return @@ -675,7 +676,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server } func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidClient, @@ -697,7 +698,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s }) return } - authorizationCode, err := auth.GetOAuth2AuthorizationByCode(form.Code) + authorizationCode, err := auth.GetOAuth2AuthorizationByCode(ctx, form.Code) if err != nil || authorizationCode == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeUnauthorizedClient, @@ -722,13 +723,13 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s return } // remove token from database to deny duplicate usage - if err := authorizationCode.Invalidate(); err != nil { + if err := authorizationCode.Invalidate(ctx); err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorDescription: "cannot proceed your request", }) } - resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey) + resp, tokenErr := newAccessTokenResponse(ctx, authorizationCode.Grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return diff --git a/routers/web/auth/oauth_test.go b/routers/web/auth/oauth_test.go index 669d7431fc..5a09a95105 100644 --- a/routers/web/auth/oauth_test.go +++ b/routers/web/auth/oauth_test.go @@ -8,6 +8,7 @@ import ( "testing" "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/services/auth/source/oauth2" @@ -21,7 +22,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke assert.NoError(t, err) assert.NotNil(t, signingKey) - response, terr := newAccessTokenResponse(grant, signingKey, signingKey) + response, terr := newAccessTokenResponse(db.DefaultContext, grant, signingKey, signingKey) assert.Nil(t, terr) assert.NotNil(t, response) @@ -43,7 +44,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - grants, err := auth.GetOAuth2GrantsByUserID(3) + grants, err := auth.GetOAuth2GrantsByUserID(db.DefaultContext, 3) assert.NoError(t, err) assert.Len(t, grants, 1) @@ -59,7 +60,7 @@ func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.False(t, oidcToken.EmailVerified) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 5}).(*user_model.User) - grants, err = auth.GetOAuth2GrantsByUserID(user.ID) + grants, err = auth.GetOAuth2GrantsByUserID(db.DefaultContext, user.ID) assert.NoError(t, err) assert.Len(t, grants, 1) diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go index 3012d8c5a5..32ae91da47 100644 --- a/routers/web/auth/openid.go +++ b/routers/web/auth/openid.go @@ -217,7 +217,7 @@ func signInOpenIDVerify(ctx *context.Context) { } if u == nil && nickname != "" { - u, _ = user_model.GetUserByName(nickname) + u, _ = user_model.GetUserByName(ctx, nickname) if err != nil { if !user_model.IsErrUserNotExist(err) { ctx.RenderWithErr(err.Error(), tplSignInOpenID, &forms.SignInOpenIDForm{ @@ -307,7 +307,7 @@ func ConnectOpenIDPost(ctx *context.Context) { // add OpenID for the user userOID := &user_model.UserOpenID{UID: u.ID, URI: oid} - if err = user_model.AddUserOpenID(userOID); err != nil { + if err = user_model.AddUserOpenID(ctx, userOID); err != nil { if user_model.IsErrOpenIDAlreadyUsed(err) { ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplConnectOID, &form) return @@ -434,7 +434,7 @@ func RegisterOpenIDPost(ctx *context.Context) { // add OpenID for the user userOID := &user_model.UserOpenID{UID: u.ID, URI: oid} - if err = user_model.AddUserOpenID(userOID); err != nil { + if err = user_model.AddUserOpenID(ctx, userOID); err != nil { if user_model.IsErrOpenIDAlreadyUsed(err) { ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplSignUpOID, &form) return -- cgit v1.2.3