diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/session/mem.go | 68 | ||||
-rw-r--r-- | modules/session/mock.go | 26 | ||||
-rw-r--r-- | modules/session/store.go | 22 | ||||
-rw-r--r-- | modules/session/virtual.go | 6 |
4 files changed, 82 insertions, 40 deletions
diff --git a/modules/session/mem.go b/modules/session/mem.go new file mode 100644 index 0000000000..bb807bc91a --- /dev/null +++ b/modules/session/mem.go @@ -0,0 +1,68 @@ +// Copyright 2025 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package session + +import ( + "bytes" + "encoding/gob" + "net/http" + + "gitea.com/go-chi/session" +) + +type mockMemRawStore struct { + s *session.MemStore +} + +var _ session.RawStore = (*mockMemRawStore)(nil) + +func (m *mockMemRawStore) Set(k, v any) error { + // We need to use gob to encode the value, to make it have the same behavior as other stores and catch abuses. + // Because gob needs to "Register" the type before it can encode it, and it's unable to decode a struct to "any" so use a map to help to decode the value. + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(map[string]any{"v": v}); err != nil { + return err + } + return m.s.Set(k, buf.Bytes()) +} + +func (m *mockMemRawStore) Get(k any) (ret any) { + v, ok := m.s.Get(k).([]byte) + if !ok { + return nil + } + var w map[string]any + _ = gob.NewDecoder(bytes.NewBuffer(v)).Decode(&w) + return w["v"] +} + +func (m *mockMemRawStore) Delete(k any) error { + return m.s.Delete(k) +} + +func (m *mockMemRawStore) ID() string { + return m.s.ID() +} + +func (m *mockMemRawStore) Release() error { + return m.s.Release() +} + +func (m *mockMemRawStore) Flush() error { + return m.s.Flush() +} + +type mockMemStore struct { + *mockMemRawStore +} + +var _ Store = (*mockMemStore)(nil) + +func (m mockMemStore) Destroy(writer http.ResponseWriter, request *http.Request) error { + return nil +} + +func NewMockMemStore(sid string) Store { + return &mockMemStore{&mockMemRawStore{session.NewMemStore(sid)}} +} diff --git a/modules/session/mock.go b/modules/session/mock.go deleted file mode 100644 index 95231a3655..0000000000 --- a/modules/session/mock.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2024 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package session - -import ( - "net/http" - - "gitea.com/go-chi/session" -) - -type MockStore struct { - *session.MemStore -} - -func (m *MockStore) Destroy(writer http.ResponseWriter, request *http.Request) error { - return nil -} - -type mockStoreContextKeyStruct struct{} - -var MockStoreContextKey = mockStoreContextKeyStruct{} - -func NewMockStore(sid string) *MockStore { - return &MockStore{session.NewMemStore(sid)} -} diff --git a/modules/session/store.go b/modules/session/store.go index 09d1ef44dd..0217ed97ac 100644 --- a/modules/session/store.go +++ b/modules/session/store.go @@ -11,25 +11,25 @@ import ( "gitea.com/go-chi/session" ) -// Store represents a session store +type RawStore = session.RawStore + type Store interface { - Get(any) any - Set(any, any) error - Delete(any) error - ID() string - Release() error - Flush() error + RawStore Destroy(http.ResponseWriter, *http.Request) error } +type mockStoreContextKeyStruct struct{} + +var MockStoreContextKey = mockStoreContextKeyStruct{} + // RegenerateSession regenerates the underlying session and returns the new store func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) { for _, f := range BeforeRegenerateSession { f(resp, req) } if setting.IsInTesting { - if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok { - return store, nil + if store := req.Context().Value(MockStoreContextKey); store != nil { + return store.(Store), nil } } return session.RegenerateSession(resp, req) @@ -37,8 +37,8 @@ func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, erro func GetContextSession(req *http.Request) Store { if setting.IsInTesting { - if store, ok := req.Context().Value(MockStoreContextKey).(*MockStore); ok { - return store + if store := req.Context().Value(MockStoreContextKey); store != nil { + return store.(Store) } } return session.GetSession(req) diff --git a/modules/session/virtual.go b/modules/session/virtual.go index 80352b6e72..2e29b5fc6f 100644 --- a/modules/session/virtual.go +++ b/modules/session/virtual.go @@ -22,8 +22,8 @@ type VirtualSessionProvider struct { provider session.Provider } -// Init initializes the cookie session provider with given root path. -func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error { +// Init initializes the cookie session provider with the given config. +func (o *VirtualSessionProvider) Init(gcLifetime int64, config string) error { var opts session.Options if err := json.Unmarshal([]byte(config), &opts); err != nil { return err @@ -52,7 +52,7 @@ func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error { default: return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider) } - return o.provider.Init(gclifetime, opts.ProviderConfig) + return o.provider.Init(gcLifetime, opts.ProviderConfig) } // Read returns raw session store by session ID. |