aboutsummaryrefslogtreecommitdiffstats
path: root/models/auth
diff options
context:
space:
mode:
Diffstat (limited to 'models/auth')
-rw-r--r--models/auth/oauth2.go76
-rw-r--r--models/auth/session.go84
2 files changed, 71 insertions, 89 deletions
diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go
index 55af4e9036..d664841306 100644
--- a/models/auth/oauth2.go
+++ b/models/auth/oauth2.go
@@ -289,35 +289,31 @@ type UpdateOAuth2ApplicationOptions struct {
// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
- ctx, committer, err := db.TxContext(ctx)
- if err != nil {
- return nil, err
- }
- defer committer.Close()
-
- app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
- if err != nil {
- return nil, err
- }
- if app.UID != opts.UserID {
- return nil, errors.New("UID mismatch")
- }
- builtinApps := BuiltinApplications()
- if _, builtin := builtinApps[app.ClientID]; builtin {
- return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
- }
+ return db.WithTx2(ctx, func(ctx context.Context) (*OAuth2Application, error) {
+ app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
+ if err != nil {
+ return nil, err
+ }
+ if app.UID != opts.UserID {
+ return nil, errors.New("UID mismatch")
+ }
+ builtinApps := BuiltinApplications()
+ if _, builtin := builtinApps[app.ClientID]; builtin {
+ return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
+ }
- app.Name = opts.Name
- app.RedirectURIs = opts.RedirectURIs
- app.ConfidentialClient = opts.ConfidentialClient
- app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
+ app.Name = opts.Name
+ app.RedirectURIs = opts.RedirectURIs
+ app.ConfidentialClient = opts.ConfidentialClient
+ app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
- if err = updateOAuth2Application(ctx, app); err != nil {
- return nil, err
- }
- app.ClientSecret = ""
+ if err = updateOAuth2Application(ctx, app); err != nil {
+ return nil, err
+ }
+ app.ClientSecret = ""
- return app, committer.Commit()
+ return app, nil
+ })
}
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
@@ -358,23 +354,17 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
- ctx, committer, err := db.TxContext(ctx)
- if err != nil {
- return err
- }
- defer committer.Close()
- app, err := GetOAuth2ApplicationByID(ctx, id)
- if err != nil {
- return err
- }
- builtinApps := BuiltinApplications()
- if _, builtin := builtinApps[app.ClientID]; builtin {
- return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
- }
- if err := deleteOAuth2Application(ctx, id, userid); err != nil {
- return err
- }
- return committer.Commit()
+ return db.WithTx(ctx, func(ctx context.Context) error {
+ app, err := GetOAuth2ApplicationByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ builtinApps := BuiltinApplications()
+ if _, builtin := builtinApps[app.ClientID]; builtin {
+ return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
+ }
+ return deleteOAuth2Application(ctx, id, userid)
+ })
}
//////////////////////////////////////////////////////
diff --git a/models/auth/session.go b/models/auth/session.go
index 75a205f702..0378d0ec6f 100644
--- a/models/auth/session.go
+++ b/models/auth/session.go
@@ -35,26 +35,22 @@ func UpdateSession(ctx context.Context, key string, data []byte) error {
// ReadSession reads the data for the provided session
func ReadSession(ctx context.Context, key string) (*Session, error) {
- ctx, committer, err := db.TxContext(ctx)
- if err != nil {
- return nil, err
- }
- defer committer.Close()
-
- session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
- if err != nil {
- return nil, err
- } else if !exist {
- session = &Session{
- Key: key,
- Expiry: timeutil.TimeStampNow(),
- }
- if err := db.Insert(ctx, session); err != nil {
+ return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
+ session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
+ if err != nil {
return nil, err
+ } else if !exist {
+ session = &Session{
+ Key: key,
+ Expiry: timeutil.TimeStampNow(),
+ }
+ if err := db.Insert(ctx, session); err != nil {
+ return nil, err
+ }
}
- }
- return session, committer.Commit()
+ return session, nil
+ })
}
// ExistSession checks if a session exists
@@ -72,40 +68,36 @@ func DestroySession(ctx context.Context, key string) error {
// RegenerateSession regenerates a session from the old id
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
- ctx, committer, err := db.TxContext(ctx)
- if err != nil {
- return nil, err
- }
- defer committer.Close()
-
- if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
- return nil, err
- } else if has {
- return nil, fmt.Errorf("session Key: %s already exists", newKey)
- }
-
- if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
- return nil, err
- } else if !has {
- if err := db.Insert(ctx, &Session{
- Key: oldKey,
- Expiry: timeutil.TimeStampNow(),
- }); err != nil {
+ return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
+ if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
+ return nil, err
+ } else if has {
+ return nil, fmt.Errorf("session Key: %s already exists", newKey)
+ }
+
+ if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
return nil, err
+ } else if !has {
+ if err := db.Insert(ctx, &Session{
+ Key: oldKey,
+ Expiry: timeutil.TimeStampNow(),
+ }); err != nil {
+ return nil, err
+ }
}
- }
- if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
- return nil, err
- }
+ if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
+ return nil, err
+ }
- s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
- if err != nil {
- // is not exist, it should be impossible
- return nil, err
- }
+ s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
+ if err != nil {
+ // is not exist, it should be impossible
+ return nil, err
+ }
- return s, committer.Commit()
+ return s, nil
+ })
}
// CountSessions returns the number of sessions