To mock a handler: ```go web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) { // ... }) defer web.RouteMockReset() ``` It helps: * Test the middleware's behavior (assert the ctx.Data, etc) * Mock the middleware's behavior (prepare some context data for handler) * Mock the handler's response for some test cases, especially for some integration tests and e2e tests.tags/v1.21.0-rc0
@@ -50,7 +50,9 @@ func NewRoute() *Route { | |||
// Use supports two middlewares | |||
func (r *Route) Use(middlewares ...any) { | |||
for _, m := range middlewares { | |||
r.R.Use(toHandlerProvider(m)) | |||
if m != nil { | |||
r.R.Use(toHandlerProvider(m)) | |||
} | |||
} | |||
} | |||
@@ -79,15 +81,23 @@ func (r *Route) getPattern(pattern string) string { | |||
} | |||
func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { | |||
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)) | |||
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)+1) | |||
for _, m := range r.curMiddlewares { | |||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||
if m != nil { | |||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||
} | |||
} | |||
for _, m := range h { | |||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||
if h != nil { | |||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||
} | |||
} | |||
middlewares := handlerProviders[:len(handlerProviders)-1] | |||
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP | |||
mockPoint := RouteMockPoint(MockAfterMiddlewares) | |||
if mockPoint != nil { | |||
middlewares = append(middlewares, mockPoint) | |||
} | |||
return middlewares, handlerFunc | |||
} | |||
@@ -0,0 +1,61 @@ | |||
// Copyright 2023 The Gitea Authors. All rights reserved. | |||
// SPDX-License-Identifier: MIT | |||
package web | |||
import ( | |||
"net/http" | |||
"code.gitea.io/gitea/modules/setting" | |||
) | |||
// MockAfterMiddlewares is a general mock point, it's between middlewares and the handler | |||
const MockAfterMiddlewares = "MockAfterMiddlewares" | |||
var routeMockPoints = map[string]func(next http.Handler) http.Handler{} | |||
// RouteMockPoint registers a mock point as a middleware for testing, example: | |||
// | |||
// r.Use(web.RouteMockPoint("my-mock-point-1")) | |||
// r.Get("/foo", middleware2, web.RouteMockPoint("my-mock-point-2"), middleware2, handler) | |||
// | |||
// Then use web.RouteMock to mock the route execution. | |||
// It only takes effect in testing mode (setting.IsInTesting == true). | |||
func RouteMockPoint(pointName string) func(next http.Handler) http.Handler { | |||
if !setting.IsInTesting { | |||
return nil | |||
} | |||
routeMockPoints[pointName] = nil | |||
return func(next http.Handler) http.Handler { | |||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
if h := routeMockPoints[pointName]; h != nil { | |||
h(next).ServeHTTP(w, r) | |||
} else { | |||
next.ServeHTTP(w, r) | |||
} | |||
}) | |||
} | |||
} | |||
// RouteMock uses the registered mock point to mock the route execution, example: | |||
// | |||
// defer web.RouteMockReset() | |||
// web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) { | |||
// ctx.WriteResponse(...) | |||
// } | |||
// | |||
// Then the mock function will be executed as a middleware at the mock point. | |||
// It only takes effect in testing mode (setting.IsInTesting == true). | |||
func RouteMock(pointName string, h any) { | |||
if _, ok := routeMockPoints[pointName]; !ok { | |||
panic("route mock point not found: " + pointName) | |||
} | |||
routeMockPoints[pointName] = toHandlerProvider(h) | |||
} | |||
// RouteMockReset resets all mock points (no mock anymore) | |||
func RouteMockReset() { | |||
for k := range routeMockPoints { | |||
routeMockPoints[k] = nil // keep the keys because RouteMock will check the keys to make sure no misspelling | |||
} | |||
} |
@@ -0,0 +1,70 @@ | |||
// Copyright 2023 The Gitea Authors. All rights reserved. | |||
// SPDX-License-Identifier: MIT | |||
package web | |||
import ( | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
"code.gitea.io/gitea/modules/setting" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestRouteMock(t *testing.T) { | |||
setting.IsInTesting = true | |||
r := NewRoute() | |||
middleware1 := func(resp http.ResponseWriter, req *http.Request) { | |||
resp.Header().Set("X-Test-Middleware1", "m1") | |||
} | |||
middleware2 := func(resp http.ResponseWriter, req *http.Request) { | |||
resp.Header().Set("X-Test-Middleware2", "m2") | |||
} | |||
handler := func(resp http.ResponseWriter, req *http.Request) { | |||
resp.Header().Set("X-Test-Handler", "h") | |||
} | |||
r.Get("/foo", middleware1, RouteMockPoint("mock-point"), middleware2, handler) | |||
// normal request | |||
recorder := httptest.NewRecorder() | |||
req, err := http.NewRequest("GET", "http://localhost:8000/foo", nil) | |||
assert.NoError(t, err) | |||
r.ServeHTTP(recorder, req) | |||
assert.Len(t, recorder.Header(), 3) | |||
assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) | |||
assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) | |||
assert.EqualValues(t, "h", recorder.Header().Get("X-Test-Handler")) | |||
RouteMockReset() | |||
// mock at "mock-point" | |||
RouteMock("mock-point", func(resp http.ResponseWriter, req *http.Request) { | |||
resp.Header().Set("X-Test-MockPoint", "a") | |||
resp.WriteHeader(http.StatusOK) | |||
}) | |||
recorder = httptest.NewRecorder() | |||
req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) | |||
assert.NoError(t, err) | |||
r.ServeHTTP(recorder, req) | |||
assert.Len(t, recorder.Header(), 2) | |||
assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) | |||
assert.EqualValues(t, "a", recorder.Header().Get("X-Test-MockPoint")) | |||
RouteMockReset() | |||
// mock at MockAfterMiddlewares | |||
RouteMock(MockAfterMiddlewares, func(resp http.ResponseWriter, req *http.Request) { | |||
resp.Header().Set("X-Test-MockPoint", "b") | |||
resp.WriteHeader(http.StatusOK) | |||
}) | |||
recorder = httptest.NewRecorder() | |||
req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) | |||
assert.NoError(t, err) | |||
r.ServeHTTP(recorder, req) | |||
assert.Len(t, recorder.Header(), 3) | |||
assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) | |||
assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) | |||
assert.EqualValues(t, "b", recorder.Header().Get("X-Test-MockPoint")) | |||
RouteMockReset() | |||
} |