aboutsummaryrefslogtreecommitdiffstats
path: root/routers/web/auth
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2022-05-20 22:08:52 +0800
committerGitHub <noreply@github.com>2022-05-20 22:08:52 +0800
commitfd7d83ace60258acf7139c4c787aa8af75b7ba8c (patch)
tree50038348ec10485f72344f3ac80324e04abc1283 /routers/web/auth
parentd81e31ad7826a81fc7139f329f250594610a274b (diff)
downloadgitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.tar.gz
gitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.zip
Move almost all functions' parameter db.Engine to context.Context (#19748)
* Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions
Diffstat (limited to 'routers/web/auth')
-rw-r--r--routers/web/auth/auth.go2
-rw-r--r--routers/web/auth/linkaccount.go2
-rw-r--r--routers/web/auth/oauth.go39
-rw-r--r--routers/web/auth/oauth_test.go7
-rw-r--r--routers/web/auth/openid.go6
5 files changed, 29 insertions, 27 deletions
diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go
index be936d2230..4d5a2c9335 100644
--- a/routers/web/auth/auth.go
+++ b/routers/web/auth/auth.go
@@ -64,7 +64,7 @@ func AutoSignIn(ctx *context.Context) (bool, error) {
}
}()
- u, err := user_model.GetUserByName(uname)
+ u, err := user_model.GetUserByName(ctx, uname)
if err != nil {
if !user_model.IsErrUserNotExist(err) {
return false, fmt.Errorf("GetUserByName: %v", err)
diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go
index c3e96f077a..a2d76e9c5a 100644
--- a/routers/web/auth/linkaccount.go
+++ b/routers/web/auth/linkaccount.go
@@ -70,7 +70,7 @@ func LinkAccount(ctx *context.Context) {
ctx.Data["user_exists"] = true
}
} else if len(uname) != 0 {
- u, err := user_model.GetUserByName(uname)
+ u, err := user_model.GetUserByName(ctx, uname)
if err != nil && !user_model.IsErrUserNotExist(err) {
ctx.ServerError("UserSignIn", err)
return
diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go
index 4c3e3c3ace..9aa31c1c02 100644
--- a/routers/web/auth/oauth.go
+++ b/routers/web/auth/oauth.go
@@ -5,6 +5,7 @@
package auth
import (
+ stdContext "context"
"encoding/base64"
"errors"
"fmt"
@@ -135,9 +136,9 @@ type AccessTokenResponse struct {
IDToken string `json:"id_token,omitempty"`
}
-func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
+func newAccessTokenResponse(ctx stdContext.Context, grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
if setting.OAuth2.InvalidateRefreshTokens {
- if err := grant.IncreaseCounter(); err != nil {
+ if err := grant.IncreaseCounter(ctx); err != nil {
return nil, &AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidGrant,
ErrorDescription: "cannot increase the grant counter",
@@ -182,7 +183,7 @@ func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2
// generate OpenID Connect id_token
signedIDToken := ""
if grant.ScopeContains("openid") {
- app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID)
+ app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID)
if err != nil {
return nil, &AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest,
@@ -333,9 +334,9 @@ func IntrospectOAuth(ctx *context.Context) {
token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey)
if err == nil {
if token.Valid() == nil {
- grant, err := auth.GetOAuth2GrantByID(token.GrantID)
+ grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID)
if err == nil && grant != nil {
- app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID)
+ app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID)
if err == nil && app != nil {
response.Active = true
response.Scope = grant.Scope
@@ -364,7 +365,7 @@ func AuthorizeOAuth(ctx *context.Context) {
return
}
- app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID)
+ app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID)
if err != nil {
if auth.IsErrOauthClientIDInvalid(err) {
handleAuthorizeError(ctx, AuthorizeError{
@@ -438,7 +439,7 @@ func AuthorizeOAuth(ctx *context.Context) {
return
}
- grant, err := app.GetGrantByUserID(ctx.Doer.ID)
+ grant, err := app.GetGrantByUserID(ctx, ctx.Doer.ID)
if err != nil {
handleServerError(ctx, form.State, form.RedirectURI)
return
@@ -446,7 +447,7 @@ func AuthorizeOAuth(ctx *context.Context) {
// Redirect if user already granted access
if grant != nil {
- code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod)
+ code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod)
if err != nil {
handleServerError(ctx, form.State, form.RedirectURI)
return
@@ -458,7 +459,7 @@ func AuthorizeOAuth(ctx *context.Context) {
}
// Update nonce to reflect the new session
if len(form.Nonce) > 0 {
- err := grant.SetNonce(form.Nonce)
+ err := grant.SetNonce(ctx, form.Nonce)
if err != nil {
log.Error("Unable to update nonce: %v", err)
}
@@ -510,12 +511,12 @@ func GrantApplicationOAuth(ctx *context.Context) {
ctx.Error(http.StatusBadRequest)
return
}
- app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID)
+ app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID)
if err != nil {
ctx.ServerError("GetOAuth2ApplicationByClientID", err)
return
}
- grant, err := app.CreateGrant(ctx.Doer.ID, form.Scope)
+ grant, err := app.CreateGrant(ctx, ctx.Doer.ID, form.Scope)
if err != nil {
handleAuthorizeError(ctx, AuthorizeError{
State: form.State,
@@ -525,7 +526,7 @@ func GrantApplicationOAuth(ctx *context.Context) {
return
}
if len(form.Nonce) > 0 {
- err := grant.SetNonce(form.Nonce)
+ err := grant.SetNonce(ctx, form.Nonce)
if err != nil {
log.Error("Unable to update nonce: %v", err)
}
@@ -535,7 +536,7 @@ func GrantApplicationOAuth(ctx *context.Context) {
codeChallenge, _ = ctx.Session.Get("CodeChallenge").(string)
codeChallengeMethod, _ = ctx.Session.Get("CodeChallengeMethod").(string)
- code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, codeChallenge, codeChallengeMethod)
+ code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, codeChallenge, codeChallengeMethod)
if err != nil {
handleServerError(ctx, form.State, form.RedirectURI)
return
@@ -648,7 +649,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server
return
}
// get grant before increasing counter
- grant, err := auth.GetOAuth2GrantByID(token.GrantID)
+ grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID)
if err != nil || grant == nil {
handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidGrant,
@@ -666,7 +667,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server
log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
return
}
- accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey)
+ accessToken, tokenErr := newAccessTokenResponse(ctx, grant, serverKey, clientKey)
if tokenErr != nil {
handleAccessTokenError(ctx, *tokenErr)
return
@@ -675,7 +676,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server
}
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
- app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID)
+ app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID)
if err != nil {
handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidClient,
@@ -697,7 +698,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
})
return
}
- authorizationCode, err := auth.GetOAuth2AuthorizationByCode(form.Code)
+ authorizationCode, err := auth.GetOAuth2AuthorizationByCode(ctx, form.Code)
if err != nil || authorizationCode == nil {
handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeUnauthorizedClient,
@@ -722,13 +723,13 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
return
}
// remove token from database to deny duplicate usage
- if err := authorizationCode.Invalidate(); err != nil {
+ if err := authorizationCode.Invalidate(ctx); err != nil {
handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest,
ErrorDescription: "cannot proceed your request",
})
}
- resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey)
+ resp, tokenErr := newAccessTokenResponse(ctx, authorizationCode.Grant, serverKey, clientKey)
if tokenErr != nil {
handleAccessTokenError(ctx, *tokenErr)
return
diff --git a/routers/web/auth/oauth_test.go b/routers/web/auth/oauth_test.go
index 669d7431fc..5a09a95105 100644
--- a/routers/web/auth/oauth_test.go
+++ b/routers/web/auth/oauth_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/services/auth/source/oauth2"
@@ -21,7 +22,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke
assert.NoError(t, err)
assert.NotNil(t, signingKey)
- response, terr := newAccessTokenResponse(grant, signingKey, signingKey)
+ response, terr := newAccessTokenResponse(db.DefaultContext, grant, signingKey, signingKey)
assert.Nil(t, terr)
assert.NotNil(t, response)
@@ -43,7 +44,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke
func TestNewAccessTokenResponse_OIDCToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- grants, err := auth.GetOAuth2GrantsByUserID(3)
+ grants, err := auth.GetOAuth2GrantsByUserID(db.DefaultContext, 3)
assert.NoError(t, err)
assert.Len(t, grants, 1)
@@ -59,7 +60,7 @@ func TestNewAccessTokenResponse_OIDCToken(t *testing.T) {
assert.False(t, oidcToken.EmailVerified)
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 5}).(*user_model.User)
- grants, err = auth.GetOAuth2GrantsByUserID(user.ID)
+ grants, err = auth.GetOAuth2GrantsByUserID(db.DefaultContext, user.ID)
assert.NoError(t, err)
assert.Len(t, grants, 1)
diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go
index 3012d8c5a5..32ae91da47 100644
--- a/routers/web/auth/openid.go
+++ b/routers/web/auth/openid.go
@@ -217,7 +217,7 @@ func signInOpenIDVerify(ctx *context.Context) {
}
if u == nil && nickname != "" {
- u, _ = user_model.GetUserByName(nickname)
+ u, _ = user_model.GetUserByName(ctx, nickname)
if err != nil {
if !user_model.IsErrUserNotExist(err) {
ctx.RenderWithErr(err.Error(), tplSignInOpenID, &forms.SignInOpenIDForm{
@@ -307,7 +307,7 @@ func ConnectOpenIDPost(ctx *context.Context) {
// add OpenID for the user
userOID := &user_model.UserOpenID{UID: u.ID, URI: oid}
- if err = user_model.AddUserOpenID(userOID); err != nil {
+ if err = user_model.AddUserOpenID(ctx, userOID); err != nil {
if user_model.IsErrOpenIDAlreadyUsed(err) {
ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplConnectOID, &form)
return
@@ -434,7 +434,7 @@ func RegisterOpenIDPost(ctx *context.Context) {
// add OpenID for the user
userOID := &user_model.UserOpenID{UID: u.ID, URI: oid}
- if err = user_model.AddUserOpenID(userOID); err != nil {
+ if err = user_model.AddUserOpenID(ctx, userOID); err != nil {
if user_model.IsErrOpenIDAlreadyUsed(err) {
ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplSignUpOID, &form)
return