aboutsummaryrefslogtreecommitdiffstats
path: root/routers/web
diff options
context:
space:
mode:
Diffstat (limited to 'routers/web')
-rw-r--r--routers/web/auth/auth.go4
-rw-r--r--routers/web/auth/auth_test.go5
-rw-r--r--routers/web/auth/linkaccount.go15
-rw-r--r--routers/web/auth/oauth.go29
-rw-r--r--routers/web/auth/oauth_signin_sync.go9
5 files changed, 40 insertions, 22 deletions
diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go
index 13cd083771..2ccd1c71b5 100644
--- a/routers/web/auth/auth.go
+++ b/routers/web/auth/auth.go
@@ -565,7 +565,7 @@ func createUserInContext(ctx *context.Context, tpl templates.TplName, form any,
oauth2LinkAccount(ctx, user, possibleLinkAccountData, true)
return false // user is already created here, all redirects are handled
case setting.OAuth2AccountLinkingLogin:
- showLinkingLogin(ctx, &possibleLinkAccountData.AuthSource, possibleLinkAccountData.GothUser)
+ showLinkingLogin(ctx, possibleLinkAccountData.AuthSourceID, possibleLinkAccountData.GothUser)
return false // user will be created only after linking login
}
}
@@ -633,7 +633,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, possibleLinkAcc
// update external user information
if possibleLinkAccountData != nil {
- if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSource.ID, u, possibleLinkAccountData.GothUser); err != nil {
+ if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSourceID, u, possibleLinkAccountData.GothUser); err != nil {
log.Error("EnsureLinkExternalToUser failed: %v", err)
}
}
diff --git a/routers/web/auth/auth_test.go b/routers/web/auth/auth_test.go
index e238125407..a0fd5c0e50 100644
--- a/routers/web/auth/auth_test.go
+++ b/routers/web/auth/auth_test.go
@@ -64,13 +64,14 @@ func TestUserLogin(t *testing.T) {
func TestSignUpOAuth2Login(t *testing.T) {
defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)()
+ _ = oauth2.Init(t.Context())
addOAuth2Source(t, "dummy-auth-source", oauth2.Source{})
t.Run("OAuth2MissingField", func(t *testing.T) {
defer test.MockVariableValue(&gothic.CompleteUserAuth, func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
return goth.User{Provider: "dummy-auth-source", UserID: "dummy-user"}, nil
})()
- mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
+ mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback?code=dummy-code", mockOpt)
ctx.SetPathParam("provider", "dummy-auth-source")
SignInOAuthCallback(ctx)
@@ -84,7 +85,7 @@ func TestSignUpOAuth2Login(t *testing.T) {
})
t.Run("OAuth2CallbackError", func(t *testing.T) {
- mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")}
+ mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")}
ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback", mockOpt)
ctx.SetPathParam("provider", "dummy-auth-source")
SignInOAuthCallback(ctx)
diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go
index cf1aa302c4..c624d896ca 100644
--- a/routers/web/auth/linkaccount.go
+++ b/routers/web/auth/linkaccount.go
@@ -170,7 +170,7 @@ func LinkAccountPostSignIn(ctx *context.Context) {
}
func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) {
- oauth2SignInSync(ctx, &linkAccountData.AuthSource, u, linkAccountData.GothUser)
+ oauth2SignInSync(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
if ctx.Written() {
return
}
@@ -185,7 +185,7 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData
return
}
- err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, u, linkAccountData.GothUser)
+ err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
@@ -295,7 +295,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
Email: form.Email,
Passwd: form.Password,
LoginType: auth.OAuth2,
- LoginSource: linkAccountData.AuthSource.ID,
+ LoginSource: linkAccountData.AuthSourceID,
LoginName: linkAccountData.GothUser.UserID,
}
@@ -304,7 +304,12 @@ func LinkAccountPostRegister(ctx *context.Context) {
return
}
- source := linkAccountData.AuthSource.Cfg.(*oauth2.Source)
+ authSource, err := auth.GetSourceByID(ctx, linkAccountData.AuthSourceID)
+ if err != nil {
+ ctx.ServerError("GetSourceByID", err)
+ return
+ }
+ source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
@@ -318,5 +323,5 @@ func linkAccountFromContext(ctx *context.Context, user *user_model.User) error {
if linkAccountData == nil {
return errors.New("not in LinkAccount session")
}
- return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, user, linkAccountData.GothUser)
+ return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, user, linkAccountData.GothUser)
}
diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go
index 3df2734bb6..f1c155e78f 100644
--- a/routers/web/auth/oauth.go
+++ b/routers/web/auth/oauth.go
@@ -4,6 +4,7 @@
package auth
import (
+ "encoding/gob"
"errors"
"fmt"
"html"
@@ -171,7 +172,7 @@ func SignInOAuthCallback(ctx *context.Context) {
gothUser.RawData = make(map[string]any)
}
gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields
- showLinkingLogin(ctx, authSource, gothUser)
+ showLinkingLogin(ctx, authSource.ID, gothUser)
return
}
u = &user_model.User{
@@ -192,7 +193,7 @@ func SignInOAuthCallback(ctx *context.Context) {
u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue
u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted)
- linkAccountData := &LinkAccountData{*authSource, gothUser}
+ linkAccountData := &LinkAccountData{authSource.ID, gothUser}
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled {
linkAccountData = nil
}
@@ -207,7 +208,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}
} else {
// no existing user is found, request attach or new account
- showLinkingLogin(ctx, authSource, gothUser)
+ showLinkingLogin(ctx, authSource.ID, gothUser)
return
}
}
@@ -272,11 +273,12 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g
}
type LinkAccountData struct {
- AuthSource auth.Source
- GothUser goth.User
+ AuthSourceID int64
+ GothUser goth.User
}
func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
+ gob.Register(LinkAccountData{})
v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData)
if !ok {
return nil
@@ -284,11 +286,16 @@ func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData {
return &v
}
-func showLinkingLogin(ctx *context.Context, authSource *auth.Source, gothUser goth.User) {
- if err := updateSession(ctx, nil, map[string]any{
- "linkAccountData": LinkAccountData{*authSource, gothUser},
- }); err != nil {
- ctx.ServerError("updateSession", err)
+func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error {
+ gob.Register(LinkAccountData{})
+ return updateSession(ctx, nil, map[string]any{
+ "linkAccountData": linkAccountData,
+ })
+}
+
+func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) {
+ if err := Oauth2SetLinkAccountData(ctx, LinkAccountData{authSourceID, gothUser}); err != nil {
+ ctx.ServerError("Oauth2SetLinkAccountData", err)
return
}
ctx.Redirect(setting.AppSubURL + "/user/link_account")
@@ -313,7 +320,7 @@ func oauth2UpdateAvatarIfNeed(ctx *context.Context, url string, u *user_model.Us
}
func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
- oauth2SignInSync(ctx, authSource, u, gothUser)
+ oauth2SignInSync(ctx, authSource.ID, u, gothUser)
if ctx.Written() {
return
}
diff --git a/routers/web/auth/oauth_signin_sync.go b/routers/web/auth/oauth_signin_sync.go
index 787ea9223c..86d1966024 100644
--- a/routers/web/auth/oauth_signin_sync.go
+++ b/routers/web/auth/oauth_signin_sync.go
@@ -18,9 +18,14 @@ import (
"github.com/markbates/goth"
)
-func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) {
+func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.User, gothUser goth.User) {
oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u)
+ authSource, err := auth.GetSourceByID(ctx, authSourceID)
+ if err != nil {
+ ctx.ServerError("GetSourceByID", err)
+ return
+ }
oauth2Source, _ := authSource.Cfg.(*oauth2.Source)
if !authSource.IsOAuth2() || oauth2Source == nil {
ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider))
@@ -45,7 +50,7 @@ func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_mod
}
}
- err := oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
+ err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u)
if err != nil {
log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err)
}