From 86aafea3fbaa69df05a104df697b0bbfc4ce6d1b Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Sun, 20 Jul 2025 09:49:36 +0800 Subject: Fix session gob (#35128) Fix #35126 --- modules/session/mem.go | 68 +++++++++++++++++++++++++++++++++++ modules/session/mock.go | 26 -------------- modules/session/store.go | 22 ++++++------ modules/session/virtual.go | 6 ++-- routers/web/auth/auth.go | 4 +-- routers/web/auth/auth_test.go | 5 +-- routers/web/auth/linkaccount.go | 15 +++++--- routers/web/auth/oauth.go | 29 +++++++++------ routers/web/auth/oauth_signin_sync.go | 9 +++-- services/auth/source/oauth2/store.go | 15 ++++---- services/contexttest/context_tests.go | 2 +- tests/integration/signin_test.go | 2 +- 12 files changed, 131 insertions(+), 72 deletions(-) create mode 100644 modules/session/mem.go delete mode 100644 modules/session/mock.go diff --git a/modules/session/mem.go b/modules/session/mem.go new file mode 100644 index 0000000000..bb807bc91a --- /dev/null +++ b/modules/session/mem.go @@ -0,0 +1,68 @@ +// Copyright 2025 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package session + +import ( + "bytes" + "encoding/gob" + "net/http" + + "gitea.com/go-chi/session" +) + +type mockMemRawStore struct { + s *session.MemStore +} + +var _ session.RawStore = (*mockMemRawStore)(nil) + +func (m *mockMemRawStore) Set(k, v any) error { + // We need to use gob to encode the value, to make it have the same behavior as other stores and catch abuses. + // Because gob needs to "Register" the type before it can encode it, and it's unable to decode a struct to "any" so use a map to help to decode the value. + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(map[string]any{"v": v}); err != nil { + return err + } + return m.s.Set(k, buf.Bytes()) +} + +func (m *mockMemRawStore) Get(k any) (ret any) { + v, ok := m.s.Get(k).([]byte) + if !ok { + return nil + } + var w map[string]any + _ = gob.NewDecoder(bytes.NewBuffer(v)).Decode(&w) + return w["v"] +} + +func (m *mockMemRawStore) Delete(k any) error { + return m.s.Delete(k) +} + +func (m *mockMemRawStore) ID() string { + return m.s.ID() +} + +func (m *mockMemRawStore) Release() error { + return m.s.Release() +} + +func (m *mockMemRawStore) Flush() error { + return m.s.Flush() +} + +type mockMemStore struct { + *mockMemRawStore +} + +var _ Store = (*mockMemStore)(nil) + +func (m mockMemStore) Destroy(writer http.ResponseWriter, request *http.Request) error { + return nil +} + +func NewMockMemStore(sid string) Store { + return &mockMemStore{&mockMemRawStore{session.NewMemStore(sid)}} +} diff --git a/modules/session/mock.go b/modules/session/mock.go deleted file mode 100644 index 95231a3655..0000000000 --- a/modules/session/mock.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2024 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package session - -import ( - "net/http" - - "gitea.com/go-chi/session" -) - -type MockStore struct { - *session.MemStore -} - -func (m *MockStore) Destroy(writer http.ResponseWriter, request *http.Request) error { - return nil -} - -type mockStoreContextKeyStruct struct{} - -var MockStoreContextKey = mockStoreContextKeyStruct{} - -func NewMockStore(sid string) *MockStore { - return &MockStore{session.NewMemStore(sid)} -} diff --git a/modules/session/store.go b/modules/session/store.go index 09d1ef44dd..0217ed97ac 100644 --- a/modules/session/store.go +++ b/modules/session/store.go @@ -11,25 +11,25 @@ import ( "gitea.com/go-chi/session" ) -// Store represents a session store +type RawStore = session.RawStore + type Store interface { - Get(any) any - Set(any, any) error - Delete(any) error - ID() string - Release() error - Flush() error + RawStore Destroy(http.ResponseWriter, *http.Request) error } +type mockStoreContextKeyStruct struct{} + +var MockStoreContextKey = mockStoreContextKeyStruct{} + // RegenerateSession regenerates the underlying session and returns the new store func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) { for _, f := range BeforeRegenerateSession { f(resp, req) } if setting.IsInTesting { - if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok { - return store, nil + if store := req.Context().Value(MockStoreContextKey); store != nil { + return store.(Store), nil } } return session.RegenerateSession(resp, req) @@ -37,8 +37,8 @@ func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, erro func GetContextSession(req *http.Request) Store { if setting.IsInTesting { - if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok { - return store + if store := req.Context().Value(MockStoreContextKey); store != nil { + return store.(Store) } } return session.GetSession(req) diff --git a/modules/session/virtual.go b/modules/session/virtual.go index 80352b6e72..2e29b5fc6f 100644 --- a/modules/session/virtual.go +++ b/modules/session/virtual.go @@ -22,8 +22,8 @@ type VirtualSessionProvider struct { provider session.Provider } -// Init initializes the cookie session provider with given root path. -func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error { +// Init initializes the cookie session provider with the given config. +func (o *VirtualSessionProvider) Init(gcLifetime int64, config string) error { var opts session.Options if err := json.Unmarshal([]byte(config), &opts); err != nil { return err @@ -52,7 +52,7 @@ func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error { default: return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider) } - return o.provider.Init(gclifetime, opts.ProviderConfig) + return o.provider.Init(gcLifetime, opts.ProviderConfig) } // Read returns raw session store by session ID. diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 13cd083771..2ccd1c71b5 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -565,7 +565,7 @@ func createUserInContext(ctx *context.Context, tpl templates.TplName, form any, oauth2LinkAccount(ctx, user, possibleLinkAccountData, true) return false // user is already created here, all redirects are handled case setting.OAuth2AccountLinkingLogin: - showLinkingLogin(ctx, &possibleLinkAccountData.AuthSource, possibleLinkAccountData.GothUser) + showLinkingLogin(ctx, possibleLinkAccountData.AuthSourceID, possibleLinkAccountData.GothUser) return false // user will be created only after linking login } } @@ -633,7 +633,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, possibleLinkAcc // update external user information if possibleLinkAccountData != nil { - if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSource.ID, u, possibleLinkAccountData.GothUser); err != nil { + if err := externalaccount.EnsureLinkExternalToUser(ctx, possibleLinkAccountData.AuthSourceID, u, possibleLinkAccountData.GothUser); err != nil { log.Error("EnsureLinkExternalToUser failed: %v", err) } } diff --git a/routers/web/auth/auth_test.go b/routers/web/auth/auth_test.go index e238125407..a0fd5c0e50 100644 --- a/routers/web/auth/auth_test.go +++ b/routers/web/auth/auth_test.go @@ -64,13 +64,14 @@ func TestUserLogin(t *testing.T) { func TestSignUpOAuth2Login(t *testing.T) { defer test.MockVariableValue(&setting.OAuth2Client.EnableAutoRegistration, true)() + _ = oauth2.Init(t.Context()) addOAuth2Source(t, "dummy-auth-source", oauth2.Source{}) t.Run("OAuth2MissingField", func(t *testing.T) { defer test.MockVariableValue(&gothic.CompleteUserAuth, func(res http.ResponseWriter, req *http.Request) (goth.User, error) { return goth.User{Provider: "dummy-auth-source", UserID: "dummy-user"}, nil })() - mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")} + mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")} ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback?code=dummy-code", mockOpt) ctx.SetPathParam("provider", "dummy-auth-source") SignInOAuthCallback(ctx) @@ -84,7 +85,7 @@ func TestSignUpOAuth2Login(t *testing.T) { }) t.Run("OAuth2CallbackError", func(t *testing.T) { - mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockStore("dummy-sid")} + mockOpt := contexttest.MockContextOption{SessionStore: session.NewMockMemStore("dummy-sid")} ctx, resp := contexttest.MockContext(t, "/user/oauth2/dummy-auth-source/callback", mockOpt) ctx.SetPathParam("provider", "dummy-auth-source") SignInOAuthCallback(ctx) diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index cf1aa302c4..c624d896ca 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -170,7 +170,7 @@ func LinkAccountPostSignIn(ctx *context.Context) { } func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData *LinkAccountData, remember bool) { - oauth2SignInSync(ctx, &linkAccountData.AuthSource, u, linkAccountData.GothUser) + oauth2SignInSync(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser) if ctx.Written() { return } @@ -185,7 +185,7 @@ func oauth2LinkAccount(ctx *context.Context, u *user_model.User, linkAccountData return } - err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, u, linkAccountData.GothUser) + err = externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, u, linkAccountData.GothUser) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -295,7 +295,7 @@ func LinkAccountPostRegister(ctx *context.Context) { Email: form.Email, Passwd: form.Password, LoginType: auth.OAuth2, - LoginSource: linkAccountData.AuthSource.ID, + LoginSource: linkAccountData.AuthSourceID, LoginName: linkAccountData.GothUser.UserID, } @@ -304,7 +304,12 @@ func LinkAccountPostRegister(ctx *context.Context) { return } - source := linkAccountData.AuthSource.Cfg.(*oauth2.Source) + authSource, err := auth.GetSourceByID(ctx, linkAccountData.AuthSourceID) + if err != nil { + ctx.ServerError("GetSourceByID", err) + return + } + source := authSource.Cfg.(*oauth2.Source) if err := syncGroupsToTeams(ctx, source, &linkAccountData.GothUser, u); err != nil { ctx.ServerError("SyncGroupsToTeams", err) return @@ -318,5 +323,5 @@ func linkAccountFromContext(ctx *context.Context, user *user_model.User) error { if linkAccountData == nil { return errors.New("not in LinkAccount session") } - return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSource.ID, user, linkAccountData.GothUser) + return externalaccount.LinkAccountToUser(ctx, linkAccountData.AuthSourceID, user, linkAccountData.GothUser) } diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 3df2734bb6..f1c155e78f 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -4,6 +4,7 @@ package auth import ( + "encoding/gob" "errors" "fmt" "html" @@ -171,7 +172,7 @@ func SignInOAuthCallback(ctx *context.Context) { gothUser.RawData = make(map[string]any) } gothUser.RawData["__giteaAutoRegMissingFields"] = missingFields - showLinkingLogin(ctx, authSource, gothUser) + showLinkingLogin(ctx, authSource.ID, gothUser) return } u = &user_model.User{ @@ -192,7 +193,7 @@ func SignInOAuthCallback(ctx *context.Context) { u.IsAdmin = isAdmin.ValueOrDefault(user_service.UpdateOptionField[bool]{FieldValue: false}).FieldValue u.IsRestricted = isRestricted.ValueOrDefault(setting.Service.DefaultUserIsRestricted) - linkAccountData := &LinkAccountData{*authSource, gothUser} + linkAccountData := &LinkAccountData{authSource.ID, gothUser} if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingDisabled { linkAccountData = nil } @@ -207,7 +208,7 @@ func SignInOAuthCallback(ctx *context.Context) { } } else { // no existing user is found, request attach or new account - showLinkingLogin(ctx, authSource, gothUser) + showLinkingLogin(ctx, authSource.ID, gothUser) return } } @@ -272,11 +273,12 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g } type LinkAccountData struct { - AuthSource auth.Source - GothUser goth.User + AuthSourceID int64 + GothUser goth.User } func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData { + gob.Register(LinkAccountData{}) v, ok := ctx.Session.Get("linkAccountData").(LinkAccountData) if !ok { return nil @@ -284,11 +286,16 @@ func oauth2GetLinkAccountData(ctx *context.Context) *LinkAccountData { return &v } -func showLinkingLogin(ctx *context.Context, authSource *auth.Source, gothUser goth.User) { - if err := updateSession(ctx, nil, map[string]any{ - "linkAccountData": LinkAccountData{*authSource, gothUser}, - }); err != nil { - ctx.ServerError("updateSession", err) +func Oauth2SetLinkAccountData(ctx *context.Context, linkAccountData LinkAccountData) error { + gob.Register(LinkAccountData{}) + return updateSession(ctx, nil, map[string]any{ + "linkAccountData": linkAccountData, + }) +} + +func showLinkingLogin(ctx *context.Context, authSourceID int64, gothUser goth.User) { + if err := Oauth2SetLinkAccountData(ctx, LinkAccountData{authSourceID, gothUser}); err != nil { + ctx.ServerError("Oauth2SetLinkAccountData", err) return } ctx.Redirect(setting.AppSubURL + "/user/link_account") @@ -313,7 +320,7 @@ func oauth2UpdateAvatarIfNeed(ctx *context.Context, url string, u *user_model.Us } func handleOAuth2SignIn(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) { - oauth2SignInSync(ctx, authSource, u, gothUser) + oauth2SignInSync(ctx, authSource.ID, u, gothUser) if ctx.Written() { return } diff --git a/routers/web/auth/oauth_signin_sync.go b/routers/web/auth/oauth_signin_sync.go index 787ea9223c..86d1966024 100644 --- a/routers/web/auth/oauth_signin_sync.go +++ b/routers/web/auth/oauth_signin_sync.go @@ -18,9 +18,14 @@ import ( "github.com/markbates/goth" ) -func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_model.User, gothUser goth.User) { +func oauth2SignInSync(ctx *context.Context, authSourceID int64, u *user_model.User, gothUser goth.User) { oauth2UpdateAvatarIfNeed(ctx, gothUser.AvatarURL, u) + authSource, err := auth.GetSourceByID(ctx, authSourceID) + if err != nil { + ctx.ServerError("GetSourceByID", err) + return + } oauth2Source, _ := authSource.Cfg.(*oauth2.Source) if !authSource.IsOAuth2() || oauth2Source == nil { ctx.ServerError("oauth2SignInSync", fmt.Errorf("source %s is not an OAuth2 source", gothUser.Provider)) @@ -45,7 +50,7 @@ func oauth2SignInSync(ctx *context.Context, authSource *auth.Source, u *user_mod } } - err := oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u) + err = oauth2UpdateSSHPubIfNeed(ctx, authSource, &gothUser, u) if err != nil { log.Error("Unable to sync OAuth2 SSH public key %s: %v", gothUser.Provider, err) } diff --git a/services/auth/source/oauth2/store.go b/services/auth/source/oauth2/store.go index 90fa965602..7b6b26edc8 100644 --- a/services/auth/source/oauth2/store.go +++ b/services/auth/source/oauth2/store.go @@ -11,7 +11,6 @@ import ( "code.gitea.io/gitea/modules/log" session_module "code.gitea.io/gitea/modules/session" - chiSession "gitea.com/go-chi/session" "github.com/gorilla/sessions" ) @@ -35,11 +34,11 @@ func (st *SessionsStore) New(r *http.Request, name string) (*sessions.Session, e // getOrNew gets the session from the chi-session if it exists. Override permits the overriding of an unexpected object. func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (*sessions.Session, error) { - chiStore := chiSession.GetSession(r) + store := session_module.GetContextSession(r) session := sessions.NewSession(st, name) - rawData := chiStore.Get(name) + rawData := store.Get(name) if rawData != nil { oldSession, ok := rawData.(*sessions.Session) if ok { @@ -56,21 +55,21 @@ func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) ( } session.IsNew = override - session.ID = chiStore.ID() // Simply copy the session id from the chi store + session.ID = store.ID() // Simply copy the session id from the chi store - return session, chiStore.Set(name, session) + return session, store.Set(name, session) } // Save should persist session to the underlying store implementation. func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { - chiStore := chiSession.GetSession(r) + store := session_module.GetContextSession(r) if session.IsNew { _, _ = session_module.RegenerateSession(w, r) session.IsNew = false } - if err := chiStore.Set(session.Name(), session); err != nil { + if err := store.Set(session.Name(), session); err != nil { return err } @@ -83,7 +82,7 @@ func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *s } } - return chiStore.Release() + return store.Release() } type sizeWriter struct { diff --git a/services/contexttest/context_tests.go b/services/contexttest/context_tests.go index b54023897b..44d9f4a70f 100644 --- a/services/contexttest/context_tests.go +++ b/services/contexttest/context_tests.go @@ -49,7 +49,7 @@ func mockRequest(t *testing.T, reqPath string) *http.Request { type MockContextOption struct { Render context.Render - SessionStore *session.MockStore + SessionStore session.Store } // MockContext mock context for unit tests diff --git a/tests/integration/signin_test.go b/tests/integration/signin_test.go index aa1571c163..fa37145d98 100644 --- a/tests/integration/signin_test.go +++ b/tests/integration/signin_test.go @@ -107,7 +107,7 @@ func TestEnablePasswordSignInFormAndEnablePasskeyAuth(t *testing.T) { mockLinkAccount := func(ctx *context.Context) { authSource := auth_model.Source{ID: 1} gothUser := goth.User{Email: "invalid-email", Name: "."} - _ = ctx.Session.Set("linkAccountData", auth.LinkAccountData{AuthSource: authSource, GothUser: gothUser}) + _ = auth.Oauth2SetLinkAccountData(ctx, auth.LinkAccountData{AuthSourceID: authSource.ID, GothUser: gothUser}) } t.Run("EnablePasswordSignInForm=false", func(t *testing.T) { -- cgit v1.2.3