diff options
Diffstat (limited to 'models/auth')
-rw-r--r-- | models/auth/access_token_scope.go | 7 | ||||
-rw-r--r-- | models/auth/auth_token.go | 2 | ||||
-rw-r--r-- | models/auth/oauth2.go | 88 | ||||
-rw-r--r-- | models/auth/session.go | 84 | ||||
-rw-r--r-- | models/auth/source.go | 45 | ||||
-rw-r--r-- | models/auth/source_test.go | 2 | ||||
-rw-r--r-- | models/auth/twofactor.go | 10 |
7 files changed, 113 insertions, 125 deletions
diff --git a/models/auth/access_token_scope.go b/models/auth/access_token_scope.go index 2293fd89a0..3eae19b2a5 100644 --- a/models/auth/access_token_scope.go +++ b/models/auth/access_token_scope.go @@ -213,12 +213,7 @@ func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTok // ContainsCategory checks if a list of categories contains a specific category func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool { - for _, c := range categories { - if c == category { - return true - } - } - return false + return slices.Contains(categories, category) } // GetScopeLevelFromAccessMode converts permission access mode to scope level diff --git a/models/auth/auth_token.go b/models/auth/auth_token.go index 81f07d1a83..54ff5a0d75 100644 --- a/models/auth/auth_token.go +++ b/models/auth/auth_token.go @@ -15,7 +15,7 @@ import ( var ErrAuthTokenNotExist = util.NewNotExistErrorf("auth token does not exist") -type AuthToken struct { //nolint:revive +type AuthToken struct { //nolint:revive // export stutter ID string `xorm:"pk"` TokenHash string UserID int64 `xorm:"INDEX"` diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index c270e4856e..d664841306 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -12,6 +12,7 @@ import ( "fmt" "net" "net/url" + "slices" "strings" "code.gitea.io/gitea/models/db" @@ -288,35 +289,31 @@ type UpdateOAuth2ApplicationOptions struct { // UpdateOAuth2Application updates an oauth2 application func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return nil, err - } - defer committer.Close() - - app, err := GetOAuth2ApplicationByID(ctx, opts.ID) - if err != nil { - return nil, err - } - if app.UID != opts.UserID { - return nil, errors.New("UID mismatch") - } - builtinApps := BuiltinApplications() - if _, builtin := builtinApps[app.ClientID]; builtin { - return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID) - } + return db.WithTx2(ctx, func(ctx context.Context) (*OAuth2Application, error) { + app, err := GetOAuth2ApplicationByID(ctx, opts.ID) + if err != nil { + return nil, err + } + if app.UID != opts.UserID { + return nil, errors.New("UID mismatch") + } + builtinApps := BuiltinApplications() + if _, builtin := builtinApps[app.ClientID]; builtin { + return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID) + } - app.Name = opts.Name - app.RedirectURIs = opts.RedirectURIs - app.ConfidentialClient = opts.ConfidentialClient - app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization + app.Name = opts.Name + app.RedirectURIs = opts.RedirectURIs + app.ConfidentialClient = opts.ConfidentialClient + app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization - if err = updateOAuth2Application(ctx, app); err != nil { - return nil, err - } - app.ClientSecret = "" + if err = updateOAuth2Application(ctx, app); err != nil { + return nil, err + } + app.ClientSecret = "" - return app, committer.Commit() + return app, nil + }) } func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error { @@ -357,23 +354,17 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error { // DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app. func DeleteOAuth2Application(ctx context.Context, id, userid int64) error { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return err - } - defer committer.Close() - app, err := GetOAuth2ApplicationByID(ctx, id) - if err != nil { - return err - } - builtinApps := BuiltinApplications() - if _, builtin := builtinApps[app.ClientID]; builtin { - return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID) - } - if err := deleteOAuth2Application(ctx, id, userid); err != nil { - return err - } - return committer.Commit() + return db.WithTx(ctx, func(ctx context.Context) error { + app, err := GetOAuth2ApplicationByID(ctx, id) + if err != nil { + return err + } + builtinApps := BuiltinApplications() + if _, builtin := builtinApps[app.ClientID]; builtin { + return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID) + } + return deleteOAuth2Application(ctx, id, userid) + }) } ////////////////////////////////////////////////////// @@ -511,12 +502,7 @@ func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error { // ScopeContains returns true if the grant scope contains the specified scope func (grant *OAuth2Grant) ScopeContains(scope string) bool { - for _, currentScope := range strings.Split(grant.Scope, " ") { - if scope == currentScope { - return true - } - } - return false + return slices.Contains(strings.Split(grant.Scope, " "), scope) } // SetNonce updates the current nonce value of a grant @@ -616,8 +602,8 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error { return util.ErrNotExist } -// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name -func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) { +// GetActiveOAuth2SourceByAuthName returns a OAuth2 AuthSource based on the given name +func GetActiveOAuth2SourceByAuthName(ctx context.Context, name string) (*Source, error) { authSource := new(Source) has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource) if err != nil { diff --git a/models/auth/session.go b/models/auth/session.go index 75a205f702..0378d0ec6f 100644 --- a/models/auth/session.go +++ b/models/auth/session.go @@ -35,26 +35,22 @@ func UpdateSession(ctx context.Context, key string, data []byte) error { // ReadSession reads the data for the provided session func ReadSession(ctx context.Context, key string) (*Session, error) { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return nil, err - } - defer committer.Close() - - session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key}) - if err != nil { - return nil, err - } else if !exist { - session = &Session{ - Key: key, - Expiry: timeutil.TimeStampNow(), - } - if err := db.Insert(ctx, session); err != nil { + return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) { + session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key}) + if err != nil { return nil, err + } else if !exist { + session = &Session{ + Key: key, + Expiry: timeutil.TimeStampNow(), + } + if err := db.Insert(ctx, session); err != nil { + return nil, err + } } - } - return session, committer.Commit() + return session, nil + }) } // ExistSession checks if a session exists @@ -72,40 +68,36 @@ func DestroySession(ctx context.Context, key string) error { // RegenerateSession regenerates a session from the old id func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return nil, err - } - defer committer.Close() - - if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil { - return nil, err - } else if has { - return nil, fmt.Errorf("session Key: %s already exists", newKey) - } - - if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil { - return nil, err - } else if !has { - if err := db.Insert(ctx, &Session{ - Key: oldKey, - Expiry: timeutil.TimeStampNow(), - }); err != nil { + return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) { + if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil { + return nil, err + } else if has { + return nil, fmt.Errorf("session Key: %s already exists", newKey) + } + + if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil { return nil, err + } else if !has { + if err := db.Insert(ctx, &Session{ + Key: oldKey, + Expiry: timeutil.TimeStampNow(), + }); err != nil { + return nil, err + } } - } - if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { - return nil, err - } + if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { + return nil, err + } - s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey}) - if err != nil { - // is not exist, it should be impossible - return nil, err - } + s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey}) + if err != nil { + // is not exist, it should be impossible + return nil, err + } - return s, committer.Commit() + return s, nil + }) } // CountSessions returns the number of sessions diff --git a/models/auth/source.go b/models/auth/source.go index a3a250cd91..08cfc9615b 100644 --- a/models/auth/source.go +++ b/models/auth/source.go @@ -58,6 +58,15 @@ var Names = map[Type]string{ // Config represents login config as far as the db is concerned type Config interface { convert.Conversion + SetAuthSource(*Source) +} + +type ConfigBase struct { + AuthSource *Source +} + +func (p *ConfigBase) SetAuthSource(s *Source) { + p.AuthSource = s } // SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set @@ -104,19 +113,15 @@ func RegisterTypeConfig(typ Type, exemplar Config) { } } -// SourceSettable configurations can have their authSource set on them -type SourceSettable interface { - SetAuthSource(*Source) -} - // Source represents an external way for authorizing users. type Source struct { - ID int64 `xorm:"pk autoincr"` - Type Type - Name string `xorm:"UNIQUE"` - IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` - IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` - Cfg convert.Conversion `xorm:"TEXT"` + ID int64 `xorm:"pk autoincr"` + Type Type + Name string `xorm:"UNIQUE"` + IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` + IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` + TwoFactorPolicy string `xorm:"two_factor_policy NOT NULL DEFAULT ''"` + Cfg Config `xorm:"TEXT"` CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` @@ -140,9 +145,7 @@ func (source *Source) BeforeSet(colName string, val xorm.Cell) { return } source.Cfg = constructor() - if settable, ok := source.Cfg.(SourceSettable); ok { - settable.SetAuthSource(source) - } + source.Cfg.SetAuthSource(source) } } @@ -200,6 +203,10 @@ func (source *Source) SkipVerify() bool { return ok && skipVerifiable.IsSkipVerify() } +func (source *Source) TwoFactorShouldSkip() bool { + return source.TwoFactorPolicy == "skip" +} + // CreateSource inserts a AuthSource in the DB if not already // existing with the given name. func CreateSource(ctx context.Context, source *Source) error { @@ -223,9 +230,7 @@ func CreateSource(ctx context.Context, source *Source) error { return nil } - if settable, ok := source.Cfg.(SourceSettable); ok { - settable.SetAuthSource(source) - } + source.Cfg.SetAuthSource(source) registerableSource, ok := source.Cfg.(RegisterableSource) if !ok { @@ -320,9 +325,7 @@ func UpdateSource(ctx context.Context, source *Source) error { return nil } - if settable, ok := source.Cfg.(SourceSettable); ok { - settable.SetAuthSource(source) - } + source.Cfg.SetAuthSource(source) registerableSource, ok := source.Cfg.(RegisterableSource) if !ok { @@ -331,7 +334,7 @@ func UpdateSource(ctx context.Context, source *Source) error { err = registerableSource.RegisterSource() if err != nil { - // restore original values since we cannot update the provider it self + // restore original values since we cannot update the provider itself if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil { log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) } diff --git a/models/auth/source_test.go b/models/auth/source_test.go index 84aede0a6b..64c7460b64 100644 --- a/models/auth/source_test.go +++ b/models/auth/source_test.go @@ -19,6 +19,8 @@ import ( ) type TestSource struct { + auth_model.ConfigBase + Provider string ClientID string ClientSecret string diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go index d0c341a192..200ce7c7c0 100644 --- a/models/auth/twofactor.go +++ b/models/auth/twofactor.go @@ -164,3 +164,13 @@ func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error { } return nil } + +func HasTwoFactorOrWebAuthn(ctx context.Context, id int64) (bool, error) { + has, err := HasTwoFactorByUID(ctx, id) + if err != nil { + return false, err + } else if has { + return true, nil + } + return HasWebAuthnRegistrationsByUID(ctx, id) +} |