diff options
author | wxiaoguang <wxiaoguang@gmail.com> | 2023-04-21 02:49:06 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-20 14:49:06 -0400 |
commit | b9a97ccd0ea1ee44db85b0fbb80b75255af7c742 (patch) | |
tree | 300578dc3c3e62a4cf956ccdc22b8b0ad0cc6036 /modules | |
parent | 70fc47a22a0bfaef7fb16dcc8a6a2e011b10f8d4 (diff) | |
download | gitea-b9a97ccd0ea1ee44db85b0fbb80b75255af7c742.tar.gz gitea-b9a97ccd0ea1ee44db85b0fbb80b75255af7c742.zip |
Refactor web route (#24080)
The old code is unnecessarily complex, and has many misuses.
Old code "wraps" a lot, wrap wrap wrap, it's difficult to understand
which kind of handler is used.
The new code uses a general approach, we do not need to write all kinds
of handlers into the "wrapper", do not need to wrap them again and
again.
New code, there are only 2 concepts:
1. HandlerProvider: `func (h any) (handlerProvider func (next)
http.Handler)`, it can be used as middleware
2. Use HandlerProvider to get the final HandlerFunc, and use it for
`r.Get()`
And we can decouple the route package from context package (see the
TODO).
# FAQ
## Is `reflect` safe?
Yes, all handlers are checked during startup, see the `preCheckHandler`
comment. If any handler is wrong, developers could know it in the first
time.
## Does `reflect` affect performance?
No. https://github.com/go-gitea/gitea/pull/24080#discussion_r1164825901
1. This reflect code only runs for each web handler call, handler is far
more slower: 10ms-50ms
2. The reflect is pretty fast (comparing to other code): 0.000265ms
3. XORM has more reflect operations already
Diffstat (limited to 'modules')
-rw-r--r-- | modules/web/handler.go | 200 | ||||
-rw-r--r-- | modules/web/route.go | 126 | ||||
-rw-r--r-- | modules/web/route_test.go | 48 | ||||
-rw-r--r-- | modules/web/wrap.go | 116 | ||||
-rw-r--r-- | modules/web/wrap_convert.go | 109 |
5 files changed, 282 insertions, 317 deletions
diff --git a/modules/web/handler.go b/modules/web/handler.go new file mode 100644 index 0000000000..8a44673f12 --- /dev/null +++ b/modules/web/handler.go @@ -0,0 +1,200 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + goctx "context" + "fmt" + "net/http" + "reflect" + "strings" + + "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/web/routing" +) + +// ResponseStatusProvider is an interface to check whether the response has been written by the handler +type ResponseStatusProvider interface { + Written() bool +} + +// TODO: decouple this from the context package, let the context package register these providers +var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusProvider{ + reflect.TypeOf(&context.APIContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetAPIContext(req) }, + reflect.TypeOf(&context.Context{}): func(req *http.Request) ResponseStatusProvider { return context.GetContext(req) }, + reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) }, +} + +// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written +type responseWriter struct { + respWriter http.ResponseWriter + status int +} + +var _ ResponseStatusProvider = (*responseWriter)(nil) + +func (r *responseWriter) Written() bool { + return r.status > 0 +} + +func (r *responseWriter) Header() http.Header { + return r.respWriter.Header() +} + +func (r *responseWriter) Write(bytes []byte) (int, error) { + if r.status == 0 { + r.status = http.StatusOK + } + return r.respWriter.Write(bytes) +} + +func (r *responseWriter) WriteHeader(statusCode int) { + r.status = statusCode + r.respWriter.WriteHeader(statusCode) +} + +var ( + httpReqType = reflect.TypeOf((*http.Request)(nil)) + respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() + cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem() +) + +// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup +func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { + hasStatusProvider := false + for _, argIn := range argsIn { + if _, hasStatusProvider = argIn.Interface().(ResponseStatusProvider); hasStatusProvider { + break + } + } + if !hasStatusProvider { + panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type())) + } + if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 { + panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type())) + } + if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType { + panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type())) + } +} + +func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value) []reflect.Value { + isPreCheck := req == nil + + argsIn := make([]reflect.Value, fn.Type().NumIn()) + for i := 0; i < fn.Type().NumIn(); i++ { + argTyp := fn.Type().In(i) + switch argTyp { + case respWriterType: + argsIn[i] = reflect.ValueOf(resp) + case httpReqType: + argsIn[i] = reflect.ValueOf(req) + default: + if argFn, ok := argTypeProvider[argTyp]; ok { + if isPreCheck { + argsIn[i] = reflect.ValueOf(&responseWriter{}) + } else { + argsIn[i] = reflect.ValueOf(argFn(req)) + } + } else { + panic(fmt.Sprintf("unsupported argument type: %s", argTyp)) + } + } + } + return argsIn +} + +func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc { + if len(ret) == 1 { + if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok { + return cancelFunc + } + panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type())) + } else if len(ret) > 1 { + panic(fmt.Sprintf("unsupported return values: %s", fn.Type())) + } + return nil +} + +func hasResponseBeenWritten(argsIn []reflect.Value) bool { + for _, argIn := range argsIn { + if statusProvider, ok := argIn.Interface().(ResponseStatusProvider); ok { + if statusProvider.Written() { + return true + } + } + } + return false +} + +// toHandlerProvider converts a handler to a handler provider +// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware +func toHandlerProvider(handler any) func(next http.Handler) http.Handler { + if hp, ok := handler.(func(next http.Handler) http.Handler); ok { + return hp + } + + funcInfo := routing.GetFuncInfo(handler) + fn := reflect.ValueOf(handler) + if fn.Type().Kind() != reflect.Func { + panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type())) + } + + provider := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) { + // wrap the response writer to check whether the response has been written + resp := respOrig + if _, ok := resp.(ResponseStatusProvider); !ok { + resp = &responseWriter{respWriter: resp} + } + + // prepare the arguments for the handler and do pre-check + argsIn := prepareHandleArgsIn(resp, req, fn) + if req == nil { + preCheckHandler(fn, argsIn) + return // it's doing pre-check, just return + } + + routing.UpdateFuncInfo(req.Context(), funcInfo) + ret := fn.Call(argsIn) + + // handle the return value, and defer the cancel function if there is one + cancelFunc := handleResponse(fn, ret) + if cancelFunc != nil { + defer cancelFunc() + } + + // if the response has not been written, call the next handler + if next != nil && !hasResponseBeenWritten(argsIn) { + next.ServeHTTP(resp, req) + } + }) + } + + provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported + return provider +} + +// MiddlewareWithPrefix wraps a handler function at a prefix, and make it as a middleware +// TODO: this design is incorrect, the asset handler should not be a middleware +func MiddlewareWithPrefix(pathPrefix string, middleware func(handler http.Handler) http.Handler, handlerFunc http.HandlerFunc) func(next http.Handler) http.Handler { + funcInfo := routing.GetFuncInfo(handlerFunc) + handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + routing.UpdateFuncInfo(req.Context(), funcInfo) + handlerFunc(resp, req) + }) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if !strings.HasPrefix(req.URL.Path, pathPrefix) { + next.ServeHTTP(resp, req) + return + } + if middleware != nil { + middleware(handler).ServeHTTP(resp, req) + } else { + handler.ServeHTTP(resp, req) + } + }) + } +} diff --git a/modules/web/route.go b/modules/web/route.go index 0f2fdc33b5..fe35880849 100644 --- a/modules/web/route.go +++ b/modules/web/route.go @@ -4,8 +4,6 @@ package web import ( - goctx "context" - "fmt" "net/http" "strings" @@ -17,13 +15,13 @@ import ( ) // Bind binding an obj to a handler -func Bind[T any](obj T) http.HandlerFunc { - return Wrap(func(ctx *context.Context) { +func Bind[T any](_ T) any { + return func(ctx *context.Context) { theObj := new(T) // create a new form obj for every request but not use obj directly binding.Bind(ctx.Req, theObj) SetForm(ctx, theObj) middleware.AssignForm(theObj, ctx.Data) - }) + } } // SetForm set the form object @@ -56,21 +54,12 @@ func NewRoute() *Route { // Use supports two middlewares func (r *Route) Use(middlewares ...interface{}) { if r.curGroupPrefix != "" { + // FIXME: this behavior is incorrect, should use "With" instead r.curMiddlewares = append(r.curMiddlewares, middlewares...) } else { - for _, middle := range middlewares { - switch t := middle.(type) { - case func(http.Handler) http.Handler: - r.R.Use(t) - case func(*context.Context): - r.R.Use(Middle(t)) - case func(*context.Context) goctx.CancelFunc: - r.R.Use(MiddleCancel(t)) - case func(*context.APIContext): - r.R.Use(MiddleAPI(t)) - default: - panic(fmt.Sprintf("Unsupported middleware type: %#v", t)) - } + // FIXME: another misuse, the "Use" with empty middlewares is called after "Mount" + for _, m := range middlewares { + r.R.Use(toHandlerProvider(m)) } } } @@ -99,6 +88,32 @@ func (r *Route) getPattern(pattern string) string { return strings.TrimSuffix(newPattern, "/") } +func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { + handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)) + for _, m := range r.curMiddlewares { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } + for _, m := range h { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } + middlewares := handlerProviders[:len(handlerProviders)-1] + handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP + return middlewares, handlerFunc +} + +func (r *Route) Methods(method, pattern string, h []any) { + middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) + fullPattern := r.getPattern(pattern) + if strings.Contains(method, ",") { + methods := strings.Split(method, ",") + for _, method := range methods { + r.R.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc) + } + } else { + r.R.With(middlewares...).Method(method, fullPattern, handlerFunc) + } +} + // Mount attaches another Route along ./pattern/* func (r *Route) Mount(pattern string, subR *Route) { middlewares := make([]interface{}, len(r.curMiddlewares)) @@ -109,81 +124,53 @@ func (r *Route) Mount(pattern string, subR *Route) { // Any delegate requests for all methods func (r *Route) Any(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.HandleFunc(r.getPattern(pattern), Wrap(middlewares...)) + middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) + r.R.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc) } -// Route delegate special methods -func (r *Route) Route(pattern, methods string, h ...interface{}) { - p := r.getPattern(pattern) - ms := strings.Split(methods, ",") - middlewares := r.getMiddlewares(h) - for _, method := range ms { - r.R.MethodFunc(strings.TrimSpace(method), p, Wrap(middlewares...)) - } +// RouteMethods delegate special methods, it is an alias of "Methods", while the "pattern" is the first parameter +func (r *Route) RouteMethods(pattern, methods string, h ...interface{}) { + r.Methods(methods, pattern, h) } // Delete delegate delete method func (r *Route) Delete(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Delete(r.getPattern(pattern), Wrap(middlewares...)) -} - -func (r *Route) getMiddlewares(h []interface{}) []interface{} { - middlewares := make([]interface{}, len(r.curMiddlewares), len(r.curMiddlewares)+len(h)) - copy(middlewares, r.curMiddlewares) - middlewares = append(middlewares, h...) - return middlewares + r.Methods("DELETE", pattern, h) } // Get delegate get method func (r *Route) Get(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Get(r.getPattern(pattern), Wrap(middlewares...)) -} - -// Options delegate options method -func (r *Route) Options(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("GET", pattern, h) } // GetOptions delegate get and options method func (r *Route) GetOptions(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Get(r.getPattern(pattern), Wrap(middlewares...)) - r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("GET,OPTIONS", pattern, h) } // PostOptions delegate post and options method func (r *Route) PostOptions(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Post(r.getPattern(pattern), Wrap(middlewares...)) - r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("POST,OPTIONS", pattern, h) } // Head delegate head method func (r *Route) Head(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Head(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("HEAD", pattern, h) } // Post delegate post method func (r *Route) Post(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Post(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("POST", pattern, h) } // Put delegate put method func (r *Route) Put(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Put(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("PUT", pattern, h) } // Patch delegate patch method func (r *Route) Patch(pattern string, h ...interface{}) { - middlewares := r.getMiddlewares(h) - r.R.Patch(r.getPattern(pattern), Wrap(middlewares...)) + r.Methods("PATCH", pattern, h) } // ServeHTTP implements http.Handler @@ -191,19 +178,12 @@ func (r *Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.R.ServeHTTP(w, req) } -// NotFound defines a handler to respond whenever a route could -// not be found. +// NotFound defines a handler to respond whenever a route could not be found. func (r *Route) NotFound(h http.HandlerFunc) { r.R.NotFound(h) } -// MethodNotAllowed defines a handler to respond whenever a method is -// not allowed. -func (r *Route) MethodNotAllowed(h http.HandlerFunc) { - r.R.MethodNotAllowed(h) -} - -// Combo deletegate requests to Combo +// Combo delegates requests to Combo func (r *Route) Combo(pattern string, h ...interface{}) *Combo { return &Combo{r, pattern, h} } @@ -215,31 +195,31 @@ type Combo struct { h []interface{} } -// Get deletegate Get method +// Get delegates Get method func (c *Combo) Get(h ...interface{}) *Combo { c.r.Get(c.pattern, append(c.h, h...)...) return c } -// Post deletegate Post method +// Post delegates Post method func (c *Combo) Post(h ...interface{}) *Combo { c.r.Post(c.pattern, append(c.h, h...)...) return c } -// Delete deletegate Delete method +// Delete delegates Delete method func (c *Combo) Delete(h ...interface{}) *Combo { c.r.Delete(c.pattern, append(c.h, h...)...) return c } -// Put deletegate Put method +// Put delegates Put method func (c *Combo) Put(h ...interface{}) *Combo { c.r.Put(c.pattern, append(c.h, h...)...) return c } -// Patch deletegate Patch method +// Patch delegates Patch method func (c *Combo) Patch(h ...interface{}) *Combo { c.r.Patch(c.pattern, append(c.h, h...)...) return c diff --git a/modules/web/route_test.go b/modules/web/route_test.go index 232444cb83..cc0e26a12e 100644 --- a/modules/web/route_test.go +++ b/modules/web/route_test.go @@ -7,6 +7,7 @@ import ( "bytes" "net/http" "net/http/httptest" + "strconv" "testing" chi "github.com/go-chi/chi/v5" @@ -39,7 +40,7 @@ func TestRoute2(t *testing.T) { recorder := httptest.NewRecorder() recorder.Body = buff - var route int + hit := -1 r := NewRoute() r.Group("/{username}/{reponame}", func() { @@ -51,7 +52,7 @@ func TestRoute2(t *testing.T) { assert.EqualValues(t, "gitea", reponame) tp := chi.URLParam(req, "type") assert.EqualValues(t, "issues", tp) - route = 0 + hit = 0 }) r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { @@ -63,10 +64,13 @@ func TestRoute2(t *testing.T) { assert.EqualValues(t, "issues", tp) index := chi.URLParam(req, "index") assert.EqualValues(t, "1", index) - route = 1 + hit = 1 }) }, func(resp http.ResponseWriter, req *http.Request) { - resp.WriteHeader(http.StatusOK) + if stop, err := strconv.Atoi(req.FormValue("stop")); err == nil { + hit = stop + resp.WriteHeader(http.StatusOK) + } }) r.Group("/issues/{index}", func() { @@ -77,7 +81,7 @@ func TestRoute2(t *testing.T) { assert.EqualValues(t, "gitea", reponame) index := chi.URLParam(req, "index") assert.EqualValues(t, "1", index) - route = 2 + hit = 2 }) }) }) @@ -86,19 +90,25 @@ func TestRoute2(t *testing.T) { assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 0, route) + assert.EqualValues(t, 0, hit) req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 1, route) + assert.EqualValues(t, 1, hit) + + req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1?stop=100", nil) + assert.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 100, hit) req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 2, route) + assert.EqualValues(t, 2, hit) } func TestRoute3(t *testing.T) { @@ -106,7 +116,7 @@ func TestRoute3(t *testing.T) { recorder := httptest.NewRecorder() recorder.Body = buff - var route int + hit := -1 m := NewRoute() r := NewRoute() @@ -116,20 +126,20 @@ func TestRoute3(t *testing.T) { m.Group("/{username}/{reponame}", func() { m.Group("/branch_protections", func() { m.Get("", func(resp http.ResponseWriter, req *http.Request) { - route = 0 + hit = 0 }) m.Post("", func(resp http.ResponseWriter, req *http.Request) { - route = 1 + hit = 1 }) m.Group("/{name}", func() { m.Get("", func(resp http.ResponseWriter, req *http.Request) { - route = 2 + hit = 2 }) m.Patch("", func(resp http.ResponseWriter, req *http.Request) { - route = 3 + hit = 3 }) m.Delete("", func(resp http.ResponseWriter, req *http.Request) { - route = 4 + hit = 4 }) }) }) @@ -140,29 +150,29 @@ func TestRoute3(t *testing.T) { assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 0, route) + assert.EqualValues(t, 0, hit) req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) - assert.EqualValues(t, 1, route) + assert.EqualValues(t, 1, hit) req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 2, route) + assert.EqualValues(t, 2, hit) req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 3, route) + assert.EqualValues(t, 3, hit) req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) assert.NoError(t, err) r.ServeHTTP(recorder, req) assert.EqualValues(t, http.StatusOK, recorder.Code) - assert.EqualValues(t, 4, route) + assert.EqualValues(t, 4, hit) } diff --git a/modules/web/wrap.go b/modules/web/wrap.go deleted file mode 100644 index 0ff9529fae..0000000000 --- a/modules/web/wrap.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2021 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package web - -import ( - goctx "context" - "net/http" - "strings" - - "code.gitea.io/gitea/modules/context" - "code.gitea.io/gitea/modules/web/routing" -) - -// Wrap converts all kinds of routes to standard library one -func Wrap(handlers ...interface{}) http.HandlerFunc { - if len(handlers) == 0 { - panic("No handlers found") - } - - ourHandlers := make([]wrappedHandlerFunc, 0, len(handlers)) - - for _, handler := range handlers { - ourHandlers = append(ourHandlers, convertHandler(handler)) - } - return wrapInternal(ourHandlers) -} - -func wrapInternal(handlers []wrappedHandlerFunc) http.HandlerFunc { - return func(resp http.ResponseWriter, req *http.Request) { - var defers []func() - defer func() { - for i := len(defers) - 1; i >= 0; i-- { - defers[i]() - } - }() - for i := 0; i < len(handlers); i++ { - handler := handlers[i] - others := handlers[i+1:] - done, deferrable := handler(resp, req, others...) - if deferrable != nil { - defers = append(defers, deferrable) - } - if done { - return - } - } - } -} - -// Middle wrap a context function as a chi middleware -func Middle(f func(ctx *context.Context)) func(next http.Handler) http.Handler { - funcInfo := routing.GetFuncInfo(f) - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetContext(req) - f(ctx) - if ctx.Written() { - return - } - next.ServeHTTP(ctx.Resp, ctx.Req) - }) - } -} - -// MiddleCancel wrap a context function as a chi middleware -func MiddleCancel(f func(ctx *context.Context) goctx.CancelFunc) func(netx http.Handler) http.Handler { - funcInfo := routing.GetFuncInfo(f) - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetContext(req) - cancel := f(ctx) - if cancel != nil { - defer cancel() - } - if ctx.Written() { - return - } - next.ServeHTTP(ctx.Resp, ctx.Req) - }) - } -} - -// MiddleAPI wrap a context function as a chi middleware -func MiddleAPI(f func(ctx *context.APIContext)) func(next http.Handler) http.Handler { - funcInfo := routing.GetFuncInfo(f) - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetAPIContext(req) - f(ctx) - if ctx.Written() { - return - } - next.ServeHTTP(ctx.Resp, ctx.Req) - }) - } -} - -// WrapWithPrefix wraps a provided handler function at a prefix -func WrapWithPrefix(pathPrefix string, handler http.HandlerFunc, friendlyName ...string) func(next http.Handler) http.Handler { - funcInfo := routing.GetFuncInfo(handler, friendlyName...) - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - if !strings.HasPrefix(req.URL.Path, pathPrefix) { - next.ServeHTTP(resp, req) - return - } - routing.UpdateFuncInfo(req.Context(), funcInfo) - handler(resp, req) - }) - } -} diff --git a/modules/web/wrap_convert.go b/modules/web/wrap_convert.go deleted file mode 100644 index 6778e208cf..0000000000 --- a/modules/web/wrap_convert.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2021 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package web - -import ( - goctx "context" - "fmt" - "net/http" - - "code.gitea.io/gitea/modules/context" - "code.gitea.io/gitea/modules/web/routing" -) - -type wrappedHandlerFunc func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) - -func convertHandler(handler interface{}) wrappedHandlerFunc { - funcInfo := routing.GetFuncInfo(handler) - switch t := handler.(type) { - case http.HandlerFunc: - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - if _, ok := resp.(context.ResponseWriter); !ok { - resp = context.NewResponse(resp) - } - t(resp, req) - if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { - done = true - } - return done, deferrable - } - case func(http.ResponseWriter, *http.Request): - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - t(resp, req) - if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { - done = true - } - return done, deferrable - } - - case func(ctx *context.Context): - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetContext(req) - t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(ctx *context.Context) goctx.CancelFunc: - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetContext(req) - deferrable = t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(*context.APIContext): - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetAPIContext(req) - t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(*context.APIContext) goctx.CancelFunc: - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetAPIContext(req) - deferrable = t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(*context.PrivateContext): - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetPrivateContext(req) - t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(*context.PrivateContext) goctx.CancelFunc: - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - routing.UpdateFuncInfo(req.Context(), funcInfo) - ctx := context.GetPrivateContext(req) - deferrable = t(ctx) - done = ctx.Written() - return done, deferrable - } - case func(http.Handler) http.Handler: - return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { - next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) - if len(others) > 0 { - next = wrapInternal(others) - } - routing.UpdateFuncInfo(req.Context(), funcInfo) - if _, ok := resp.(context.ResponseWriter); !ok { - resp = context.NewResponse(resp) - } - t(next).ServeHTTP(resp, req) - if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { - done = true - } - return done, deferrable - } - default: - panic(fmt.Sprintf("Unsupported handler type: %#v", t)) - } -} |