diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2022-05-20 22:08:52 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-20 22:08:52 +0800 |
commit | fd7d83ace60258acf7139c4c787aa8af75b7ba8c (patch) | |
tree | 50038348ec10485f72344f3ac80324e04abc1283 /models/auth/oauth2.go | |
parent | d81e31ad7826a81fc7139f329f250594610a274b (diff) | |
download | gitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.tar.gz gitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.zip |
Move almost all functions' parameter db.Engine to context.Context (#19748)
* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
Diffstat (limited to 'models/auth/oauth2.go')
-rw-r--r-- | models/auth/oauth2.go | 134 |
1 files changed, 37 insertions, 97 deletions
diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index ca77fcdb78..c5c6e91120 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool { } // GetGrantByUserID returns a OAuth2Grant by its user and application ID -func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) { - return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID) -} - -func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) { +func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) { grant = new(OAuth2Grant) - if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { + if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { return nil, err } else if !has { return nil, nil @@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant } // CreateGrant generates a grant for an user -func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) { - return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope) -} - -func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) { +func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) { grant := &OAuth2Grant{ ApplicationID: app.ID, UserID: userID, Scope: scope, } - _, err := e.Insert(grant) + err := db.Insert(ctx, grant) if err != nil { return nil, err } @@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin } // GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found. -func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID) -} - -func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) { +func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) { app = new(OAuth2Application) - has, err := e.Where("client_id = ?", clientID).Get(app) + has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app) if !has { return nil, ErrOAuthClientIDInvalid{ClientID: clientID} } @@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap } // GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found. -func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id) -} - -func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) { +func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) { app = new(OAuth2Application) - has, err := e.ID(id).Get(app) + has, err := db.GetEngine(ctx).ID(id).Get(app) if err != nil { return nil, err } @@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er } // GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user -func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) { - return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID) -} - -func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) { +func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) { apps = make([]*OAuth2Application, 0) - err = e.Where("uid = ?", userID).Find(&apps) + err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps) return } @@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct { } // CreateOAuth2Application inserts a new oauth2 application -func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { - return createOAuth2Application(db.GetEngine(db.DefaultContext), opts) -} - -func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { +func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { clientID := uuid.New().String() app := &OAuth2Application{ UID: opts.UserID, @@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) ( ClientID: clientID, RedirectURIs: opts.RedirectURIs, } - if _, err := e.Insert(app); err != nil { + if err := db.Insert(ctx, app); err != nil { return nil, err } return app, nil @@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - app, err := getOAuth2ApplicationByID(sess, opts.ID) + app, err := GetOAuth2ApplicationByID(ctx, opts.ID) if err != nil { return nil, err } @@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic app.Name = opts.Name app.RedirectURIs = opts.RedirectURIs - if err = updateOAuth2Application(sess, app); err != nil { + if err = updateOAuth2Application(ctx, app); err != nil { return nil, err } app.ClientSecret = "" @@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic return app, committer.Commit() } -func updateOAuth2Application(e db.Engine, app *OAuth2Application) error { - if _, err := e.ID(app.ID).Update(app); err != nil { +func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error { + if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil { return err } return nil } -func deleteOAuth2Application(sess db.Engine, id, userid int64) error { +func deleteOAuth2Application(ctx context.Context, id, userid int64) error { + sess := db.GetEngine(ctx) if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil { return err } else if deleted == 0 { @@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error { return err } defer committer.Close() - if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil { + if err := deleteOAuth2Application(ctx, id, userid); err != nil { return err } return committer.Commit() @@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect } // Invalidate deletes the auth code from the database to invalidate this code -func (code *OAuth2AuthorizationCode) Invalidate() error { - return code.invalidate(db.GetEngine(db.DefaultContext)) -} - -func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error { - _, err := e.Delete(code) +func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code) return err } // ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation. func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool { - return code.validateCodeChallenge(verifier) -} - -func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool { switch code.CodeChallengeMethod { case "S256": // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6 @@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool } // GetOAuth2AuthorizationByCode returns an authorization by its code -func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) { - return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code) -} - -func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) { +func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) { auth = new(OAuth2AuthorizationCode) - if has, err := e.Where("code = ?", code).Get(auth); err != nil { + if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil { return nil, err } else if !has { return nil, nil } auth.Grant = new(OAuth2Grant) - if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil { + if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil { return nil, err } else if !has { return nil, nil @@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string { } // GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database -func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) { - return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod) -} - -func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { +func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { rBytes, err := util.CryptoRandomBytes(32) if err != nil { return &OAuth2AuthorizationCode{}, err @@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, } - if _, err := e.Insert(code); err != nil { + if err := db.Insert(ctx, code); err != nil { return nil, err } return code, nil } // IncreaseCounter increases the counter and updates the grant -func (grant *OAuth2Grant) IncreaseCounter() error { - return grant.increaseCount(db.GetEngine(db.DefaultContext)) -} - -func (grant *OAuth2Grant) increaseCount(e db.Engine) error { - _, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) +func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) if err != nil { return err } - updatedGrant, err := getOAuth2GrantByID(e, grant.ID) + updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID) if err != nil { return err } @@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool { } // SetNonce updates the current nonce value of a grant -func (grant *OAuth2Grant) SetNonce(nonce string) error { - return grant.setNonce(db.GetEngine(db.DefaultContext), nonce) -} - -func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { +func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error { grant.Nonce = nonce - _, err := e.ID(grant.ID).Cols("nonce").Update(grant) + _, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant) if err != nil { return err } @@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { } // GetOAuth2GrantByID returns the grant with the given ID -func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) { - return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id) -} - -func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { +func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) { grant = new(OAuth2Grant) - if has, err := e.ID(id).Get(grant); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil { return nil, err } else if !has { return nil, nil @@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { } // GetOAuth2GrantsByUserID lists all grants of a certain user -func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) { - return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid) -} - -func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { +func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) { type joinedOAuth2Grant struct { Grant *OAuth2Grant `xorm:"extends"` Application *OAuth2Application `xorm:"extends"` } var results *xorm.Rows var err error - if results, err = e. + if results, err = db.GetEngine(ctx). Table("oauth2_grant"). Where("user_id = ?", uid). Join("INNER", "oauth2_application", "application_id = oauth2_application.id"). @@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { } // RevokeOAuth2Grant deletes the grant with grantID and userID -func RevokeOAuth2Grant(grantID, userID int64) error { - return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID) -} - -func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error { - _, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID}) +func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error { + _, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID}) return err } |