aboutsummaryrefslogtreecommitdiffstats
path: root/models/auth/oauth2.go
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2022-05-20 22:08:52 +0800
committerGitHub <noreply@github.com>2022-05-20 22:08:52 +0800
commitfd7d83ace60258acf7139c4c787aa8af75b7ba8c (patch)
tree50038348ec10485f72344f3ac80324e04abc1283 /models/auth/oauth2.go
parentd81e31ad7826a81fc7139f329f250594610a274b (diff)
downloadgitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.tar.gz
gitea-fd7d83ace60258acf7139c4c787aa8af75b7ba8c.zip
Move almost all functions' parameter db.Engine to context.Context (#19748)
* Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions
Diffstat (limited to 'models/auth/oauth2.go')
-rw-r--r--models/auth/oauth2.go134
1 files changed, 37 insertions, 97 deletions
diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go
index ca77fcdb78..c5c6e91120 100644
--- a/models/auth/oauth2.go
+++ b/models/auth/oauth2.go
@@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
}
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
-func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
- return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
-}
-
-func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
+func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
- if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
+ if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
}
// CreateGrant generates a grant for an user
-func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
- return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
-}
-
-func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
+func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
- _, err := e.Insert(grant)
+ err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
@@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
-func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
- return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
-}
-
-func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
+func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
- has, err := e.Where("client_id = ?", clientID).Get(app)
+ has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
@@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
-func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
- return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
-}
-
-func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
+func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
- has, err := e.ID(id).Get(app)
+ has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
@@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
}
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
-func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
- return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
-}
-
-func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
+func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
apps = make([]*OAuth2Application, 0)
- err = e.Where("uid = ?", userID).Find(&apps)
+ err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
return
}
@@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
}
// CreateOAuth2Application inserts a new oauth2 application
-func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
- return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
-}
-
-func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
@@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
}
- if _, err := e.Insert(app); err != nil {
+ if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
@@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return nil, err
}
defer committer.Close()
- sess := db.GetEngine(ctx)
- app, err := getOAuth2ApplicationByID(sess, opts.ID)
+ app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
@@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
- if err = updateOAuth2Application(sess, app); err != nil {
+ if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
@@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return app, committer.Commit()
}
-func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
- if _, err := e.ID(app.ID).Update(app); err != nil {
+func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
+ if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
return err
}
return nil
}
-func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
+func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
+ sess := db.GetEngine(ctx)
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
return err
} else if deleted == 0 {
@@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
return err
}
defer committer.Close()
- if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
+ if err := deleteOAuth2Application(ctx, id, userid); err != nil {
return err
}
return committer.Commit()
@@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
}
// Invalidate deletes the auth code from the database to invalidate this code
-func (code *OAuth2AuthorizationCode) Invalidate() error {
- return code.invalidate(db.GetEngine(db.DefaultContext))
-}
-
-func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
- _, err := e.Delete(code)
+func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
return err
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
- return code.validateCodeChallenge(verifier)
-}
-
-func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
switch code.CodeChallengeMethod {
case "S256":
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
@@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
}
// GetOAuth2AuthorizationByCode returns an authorization by its code
-func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
- return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
-}
-
-func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
+func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
- if has, err := e.Where("code = ?", code).Get(auth); err != nil {
+ if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil
}
auth.Grant = new(OAuth2Grant)
- if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
+ if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
-func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
- return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
-}
-
-func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
+func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return &OAuth2AuthorizationCode{}, err
@@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
- if _, err := e.Insert(code); err != nil {
+ if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
-func (grant *OAuth2Grant) IncreaseCounter() error {
- return grant.increaseCount(db.GetEngine(db.DefaultContext))
-}
-
-func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
- _, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
+func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
}
- updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
+ updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
if err != nil {
return err
}
@@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
}
// SetNonce updates the current nonce value of a grant
-func (grant *OAuth2Grant) SetNonce(nonce string) error {
- return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
-}
-
-func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
+func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
- _, err := e.ID(grant.ID).Cols("nonce").Update(grant)
+ _, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
@@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
}
// GetOAuth2GrantByID returns the grant with the given ID
-func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
- return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
-}
-
-func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
+func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
- if has, err := e.ID(id).Get(grant); err != nil {
+ if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
-func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
- return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
-}
-
-func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
+func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
- if results, err = e.
+ if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
@@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
-func RevokeOAuth2Grant(grantID, userID int64) error {
- return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
-}
-
-func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
- _, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
+func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
+ _, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
return err
}