The old code is unnecessarily complex, and has many misuses. Old code "wraps" a lot, wrap wrap wrap, it's difficult to understand which kind of handler is used. The new code uses a general approach, we do not need to write all kinds of handlers into the "wrapper", do not need to wrap them again and again. New code, there are only 2 concepts: 1. HandlerProvider: `func (h any) (handlerProvider func (next) http.Handler)`, it can be used as middleware 2. Use HandlerProvider to get the final HandlerFunc, and use it for `r.Get()` And we can decouple the route package from context package (see the TODO). # FAQ ## Is `reflect` safe? Yes, all handlers are checked during startup, see the `preCheckHandler` comment. If any handler is wrong, developers could know it in the first time. ## Does `reflect` affect performance? No. https://github.com/go-gitea/gitea/pull/24080#discussion_r1164825901 1. This reflect code only runs for each web handler call, handler is far more slower: 10ms-50ms 2. The reflect is pretty fast (comparing to other code): 0.000265ms 3. XORM has more reflect operations alreadytags/v1.20.0-rc0
// Copyright 2023 The Gitea Authors. All rights reserved. | |||||
// SPDX-License-Identifier: MIT | |||||
package web | |||||
import ( | |||||
goctx "context" | |||||
"fmt" | |||||
"net/http" | |||||
"reflect" | |||||
"strings" | |||||
"code.gitea.io/gitea/modules/context" | |||||
"code.gitea.io/gitea/modules/web/routing" | |||||
) | |||||
// ResponseStatusProvider is an interface to check whether the response has been written by the handler | |||||
type ResponseStatusProvider interface { | |||||
Written() bool | |||||
} | |||||
// TODO: decouple this from the context package, let the context package register these providers | |||||
var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusProvider{ | |||||
reflect.TypeOf(&context.APIContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetAPIContext(req) }, | |||||
reflect.TypeOf(&context.Context{}): func(req *http.Request) ResponseStatusProvider { return context.GetContext(req) }, | |||||
reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) }, | |||||
} | |||||
// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written | |||||
type responseWriter struct { | |||||
respWriter http.ResponseWriter | |||||
status int | |||||
} | |||||
var _ ResponseStatusProvider = (*responseWriter)(nil) | |||||
func (r *responseWriter) Written() bool { | |||||
return r.status > 0 | |||||
} | |||||
func (r *responseWriter) Header() http.Header { | |||||
return r.respWriter.Header() | |||||
} | |||||
func (r *responseWriter) Write(bytes []byte) (int, error) { | |||||
if r.status == 0 { | |||||
r.status = http.StatusOK | |||||
} | |||||
return r.respWriter.Write(bytes) | |||||
} | |||||
func (r *responseWriter) WriteHeader(statusCode int) { | |||||
r.status = statusCode | |||||
r.respWriter.WriteHeader(statusCode) | |||||
} | |||||
var ( | |||||
httpReqType = reflect.TypeOf((*http.Request)(nil)) | |||||
respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() | |||||
cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem() | |||||
) | |||||
// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup | |||||
func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { | |||||
hasStatusProvider := false | |||||
for _, argIn := range argsIn { | |||||
if _, hasStatusProvider = argIn.Interface().(ResponseStatusProvider); hasStatusProvider { | |||||
break | |||||
} | |||||
} | |||||
if !hasStatusProvider { | |||||
panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type())) | |||||
} | |||||
if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 { | |||||
panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type())) | |||||
} | |||||
if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType { | |||||
panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type())) | |||||
} | |||||
} | |||||
func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value) []reflect.Value { | |||||
isPreCheck := req == nil | |||||
argsIn := make([]reflect.Value, fn.Type().NumIn()) | |||||
for i := 0; i < fn.Type().NumIn(); i++ { | |||||
argTyp := fn.Type().In(i) | |||||
switch argTyp { | |||||
case respWriterType: | |||||
argsIn[i] = reflect.ValueOf(resp) | |||||
case httpReqType: | |||||
argsIn[i] = reflect.ValueOf(req) | |||||
default: | |||||
if argFn, ok := argTypeProvider[argTyp]; ok { | |||||
if isPreCheck { | |||||
argsIn[i] = reflect.ValueOf(&responseWriter{}) | |||||
} else { | |||||
argsIn[i] = reflect.ValueOf(argFn(req)) | |||||
} | |||||
} else { | |||||
panic(fmt.Sprintf("unsupported argument type: %s", argTyp)) | |||||
} | |||||
} | |||||
} | |||||
return argsIn | |||||
} | |||||
func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc { | |||||
if len(ret) == 1 { | |||||
if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok { | |||||
return cancelFunc | |||||
} | |||||
panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type())) | |||||
} else if len(ret) > 1 { | |||||
panic(fmt.Sprintf("unsupported return values: %s", fn.Type())) | |||||
} | |||||
return nil | |||||
} | |||||
func hasResponseBeenWritten(argsIn []reflect.Value) bool { | |||||
for _, argIn := range argsIn { | |||||
if statusProvider, ok := argIn.Interface().(ResponseStatusProvider); ok { | |||||
if statusProvider.Written() { | |||||
return true | |||||
} | |||||
} | |||||
} | |||||
return false | |||||
} | |||||
// toHandlerProvider converts a handler to a handler provider | |||||
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware | |||||
func toHandlerProvider(handler any) func(next http.Handler) http.Handler { | |||||
if hp, ok := handler.(func(next http.Handler) http.Handler); ok { | |||||
return hp | |||||
} | |||||
funcInfo := routing.GetFuncInfo(handler) | |||||
fn := reflect.ValueOf(handler) | |||||
if fn.Type().Kind() != reflect.Func { | |||||
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type())) | |||||
} | |||||
provider := func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) { | |||||
// wrap the response writer to check whether the response has been written | |||||
resp := respOrig | |||||
if _, ok := resp.(ResponseStatusProvider); !ok { | |||||
resp = &responseWriter{respWriter: resp} | |||||
} | |||||
// prepare the arguments for the handler and do pre-check | |||||
argsIn := prepareHandleArgsIn(resp, req, fn) | |||||
if req == nil { | |||||
preCheckHandler(fn, argsIn) | |||||
return // it's doing pre-check, just return | |||||
} | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ret := fn.Call(argsIn) | |||||
// handle the return value, and defer the cancel function if there is one | |||||
cancelFunc := handleResponse(fn, ret) | |||||
if cancelFunc != nil { | |||||
defer cancelFunc() | |||||
} | |||||
// if the response has not been written, call the next handler | |||||
if next != nil && !hasResponseBeenWritten(argsIn) { | |||||
next.ServeHTTP(resp, req) | |||||
} | |||||
}) | |||||
} | |||||
provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported | |||||
return provider | |||||
} | |||||
// MiddlewareWithPrefix wraps a handler function at a prefix, and make it as a middleware | |||||
// TODO: this design is incorrect, the asset handler should not be a middleware | |||||
func MiddlewareWithPrefix(pathPrefix string, middleware func(handler http.Handler) http.Handler, handlerFunc http.HandlerFunc) func(next http.Handler) http.Handler { | |||||
funcInfo := routing.GetFuncInfo(handlerFunc) | |||||
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
handlerFunc(resp, req) | |||||
}) | |||||
return func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
if !strings.HasPrefix(req.URL.Path, pathPrefix) { | |||||
next.ServeHTTP(resp, req) | |||||
return | |||||
} | |||||
if middleware != nil { | |||||
middleware(handler).ServeHTTP(resp, req) | |||||
} else { | |||||
handler.ServeHTTP(resp, req) | |||||
} | |||||
}) | |||||
} | |||||
} |
package web | package web | ||||
import ( | import ( | ||||
goctx "context" | |||||
"fmt" | |||||
"net/http" | "net/http" | ||||
"strings" | "strings" | ||||
) | ) | ||||
// Bind binding an obj to a handler | // Bind binding an obj to a handler | ||||
func Bind[T any](obj T) http.HandlerFunc { | |||||
return Wrap(func(ctx *context.Context) { | |||||
func Bind[T any](_ T) any { | |||||
return func(ctx *context.Context) { | |||||
theObj := new(T) // create a new form obj for every request but not use obj directly | theObj := new(T) // create a new form obj for every request but not use obj directly | ||||
binding.Bind(ctx.Req, theObj) | binding.Bind(ctx.Req, theObj) | ||||
SetForm(ctx, theObj) | SetForm(ctx, theObj) | ||||
middleware.AssignForm(theObj, ctx.Data) | middleware.AssignForm(theObj, ctx.Data) | ||||
}) | |||||
} | |||||
} | } | ||||
// SetForm set the form object | // SetForm set the form object | ||||
// Use supports two middlewares | // Use supports two middlewares | ||||
func (r *Route) Use(middlewares ...interface{}) { | func (r *Route) Use(middlewares ...interface{}) { | ||||
if r.curGroupPrefix != "" { | if r.curGroupPrefix != "" { | ||||
// FIXME: this behavior is incorrect, should use "With" instead | |||||
r.curMiddlewares = append(r.curMiddlewares, middlewares...) | r.curMiddlewares = append(r.curMiddlewares, middlewares...) | ||||
} else { | } else { | ||||
for _, middle := range middlewares { | |||||
switch t := middle.(type) { | |||||
case func(http.Handler) http.Handler: | |||||
r.R.Use(t) | |||||
case func(*context.Context): | |||||
r.R.Use(Middle(t)) | |||||
case func(*context.Context) goctx.CancelFunc: | |||||
r.R.Use(MiddleCancel(t)) | |||||
case func(*context.APIContext): | |||||
r.R.Use(MiddleAPI(t)) | |||||
default: | |||||
panic(fmt.Sprintf("Unsupported middleware type: %#v", t)) | |||||
} | |||||
// FIXME: another misuse, the "Use" with empty middlewares is called after "Mount" | |||||
for _, m := range middlewares { | |||||
r.R.Use(toHandlerProvider(m)) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
return strings.TrimSuffix(newPattern, "/") | return strings.TrimSuffix(newPattern, "/") | ||||
} | } | ||||
func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { | |||||
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)) | |||||
for _, m := range r.curMiddlewares { | |||||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||||
} | |||||
for _, m := range h { | |||||
handlerProviders = append(handlerProviders, toHandlerProvider(m)) | |||||
} | |||||
middlewares := handlerProviders[:len(handlerProviders)-1] | |||||
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP | |||||
return middlewares, handlerFunc | |||||
} | |||||
func (r *Route) Methods(method, pattern string, h []any) { | |||||
middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) | |||||
fullPattern := r.getPattern(pattern) | |||||
if strings.Contains(method, ",") { | |||||
methods := strings.Split(method, ",") | |||||
for _, method := range methods { | |||||
r.R.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc) | |||||
} | |||||
} else { | |||||
r.R.With(middlewares...).Method(method, fullPattern, handlerFunc) | |||||
} | |||||
} | |||||
// Mount attaches another Route along ./pattern/* | // Mount attaches another Route along ./pattern/* | ||||
func (r *Route) Mount(pattern string, subR *Route) { | func (r *Route) Mount(pattern string, subR *Route) { | ||||
middlewares := make([]interface{}, len(r.curMiddlewares)) | middlewares := make([]interface{}, len(r.curMiddlewares)) | ||||
// Any delegate requests for all methods | // Any delegate requests for all methods | ||||
func (r *Route) Any(pattern string, h ...interface{}) { | func (r *Route) Any(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.HandleFunc(r.getPattern(pattern), Wrap(middlewares...)) | |||||
middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) | |||||
r.R.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc) | |||||
} | } | ||||
// Route delegate special methods | |||||
func (r *Route) Route(pattern, methods string, h ...interface{}) { | |||||
p := r.getPattern(pattern) | |||||
ms := strings.Split(methods, ",") | |||||
middlewares := r.getMiddlewares(h) | |||||
for _, method := range ms { | |||||
r.R.MethodFunc(strings.TrimSpace(method), p, Wrap(middlewares...)) | |||||
} | |||||
// RouteMethods delegate special methods, it is an alias of "Methods", while the "pattern" is the first parameter | |||||
func (r *Route) RouteMethods(pattern, methods string, h ...interface{}) { | |||||
r.Methods(methods, pattern, h) | |||||
} | } | ||||
// Delete delegate delete method | // Delete delegate delete method | ||||
func (r *Route) Delete(pattern string, h ...interface{}) { | func (r *Route) Delete(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Delete(r.getPattern(pattern), Wrap(middlewares...)) | |||||
} | |||||
func (r *Route) getMiddlewares(h []interface{}) []interface{} { | |||||
middlewares := make([]interface{}, len(r.curMiddlewares), len(r.curMiddlewares)+len(h)) | |||||
copy(middlewares, r.curMiddlewares) | |||||
middlewares = append(middlewares, h...) | |||||
return middlewares | |||||
r.Methods("DELETE", pattern, h) | |||||
} | } | ||||
// Get delegate get method | // Get delegate get method | ||||
func (r *Route) Get(pattern string, h ...interface{}) { | func (r *Route) Get(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Get(r.getPattern(pattern), Wrap(middlewares...)) | |||||
} | |||||
// Options delegate options method | |||||
func (r *Route) Options(pattern string, h ...interface{}) { | |||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("GET", pattern, h) | |||||
} | } | ||||
// GetOptions delegate get and options method | // GetOptions delegate get and options method | ||||
func (r *Route) GetOptions(pattern string, h ...interface{}) { | func (r *Route) GetOptions(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Get(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("GET,OPTIONS", pattern, h) | |||||
} | } | ||||
// PostOptions delegate post and options method | // PostOptions delegate post and options method | ||||
func (r *Route) PostOptions(pattern string, h ...interface{}) { | func (r *Route) PostOptions(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Post(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.R.Options(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("POST,OPTIONS", pattern, h) | |||||
} | } | ||||
// Head delegate head method | // Head delegate head method | ||||
func (r *Route) Head(pattern string, h ...interface{}) { | func (r *Route) Head(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Head(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("HEAD", pattern, h) | |||||
} | } | ||||
// Post delegate post method | // Post delegate post method | ||||
func (r *Route) Post(pattern string, h ...interface{}) { | func (r *Route) Post(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Post(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("POST", pattern, h) | |||||
} | } | ||||
// Put delegate put method | // Put delegate put method | ||||
func (r *Route) Put(pattern string, h ...interface{}) { | func (r *Route) Put(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Put(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("PUT", pattern, h) | |||||
} | } | ||||
// Patch delegate patch method | // Patch delegate patch method | ||||
func (r *Route) Patch(pattern string, h ...interface{}) { | func (r *Route) Patch(pattern string, h ...interface{}) { | ||||
middlewares := r.getMiddlewares(h) | |||||
r.R.Patch(r.getPattern(pattern), Wrap(middlewares...)) | |||||
r.Methods("PATCH", pattern, h) | |||||
} | } | ||||
// ServeHTTP implements http.Handler | // ServeHTTP implements http.Handler | ||||
r.R.ServeHTTP(w, req) | r.R.ServeHTTP(w, req) | ||||
} | } | ||||
// NotFound defines a handler to respond whenever a route could | |||||
// not be found. | |||||
// NotFound defines a handler to respond whenever a route could not be found. | |||||
func (r *Route) NotFound(h http.HandlerFunc) { | func (r *Route) NotFound(h http.HandlerFunc) { | ||||
r.R.NotFound(h) | r.R.NotFound(h) | ||||
} | } | ||||
// MethodNotAllowed defines a handler to respond whenever a method is | |||||
// not allowed. | |||||
func (r *Route) MethodNotAllowed(h http.HandlerFunc) { | |||||
r.R.MethodNotAllowed(h) | |||||
} | |||||
// Combo deletegate requests to Combo | |||||
// Combo delegates requests to Combo | |||||
func (r *Route) Combo(pattern string, h ...interface{}) *Combo { | func (r *Route) Combo(pattern string, h ...interface{}) *Combo { | ||||
return &Combo{r, pattern, h} | return &Combo{r, pattern, h} | ||||
} | } | ||||
h []interface{} | h []interface{} | ||||
} | } | ||||
// Get deletegate Get method | |||||
// Get delegates Get method | |||||
func (c *Combo) Get(h ...interface{}) *Combo { | func (c *Combo) Get(h ...interface{}) *Combo { | ||||
c.r.Get(c.pattern, append(c.h, h...)...) | c.r.Get(c.pattern, append(c.h, h...)...) | ||||
return c | return c | ||||
} | } | ||||
// Post deletegate Post method | |||||
// Post delegates Post method | |||||
func (c *Combo) Post(h ...interface{}) *Combo { | func (c *Combo) Post(h ...interface{}) *Combo { | ||||
c.r.Post(c.pattern, append(c.h, h...)...) | c.r.Post(c.pattern, append(c.h, h...)...) | ||||
return c | return c | ||||
} | } | ||||
// Delete deletegate Delete method | |||||
// Delete delegates Delete method | |||||
func (c *Combo) Delete(h ...interface{}) *Combo { | func (c *Combo) Delete(h ...interface{}) *Combo { | ||||
c.r.Delete(c.pattern, append(c.h, h...)...) | c.r.Delete(c.pattern, append(c.h, h...)...) | ||||
return c | return c | ||||
} | } | ||||
// Put deletegate Put method | |||||
// Put delegates Put method | |||||
func (c *Combo) Put(h ...interface{}) *Combo { | func (c *Combo) Put(h ...interface{}) *Combo { | ||||
c.r.Put(c.pattern, append(c.h, h...)...) | c.r.Put(c.pattern, append(c.h, h...)...) | ||||
return c | return c | ||||
} | } | ||||
// Patch deletegate Patch method | |||||
// Patch delegates Patch method | |||||
func (c *Combo) Patch(h ...interface{}) *Combo { | func (c *Combo) Patch(h ...interface{}) *Combo { | ||||
c.r.Patch(c.pattern, append(c.h, h...)...) | c.r.Patch(c.pattern, append(c.h, h...)...) | ||||
return c | return c |
"bytes" | "bytes" | ||||
"net/http" | "net/http" | ||||
"net/http/httptest" | "net/http/httptest" | ||||
"strconv" | |||||
"testing" | "testing" | ||||
chi "github.com/go-chi/chi/v5" | chi "github.com/go-chi/chi/v5" | ||||
recorder := httptest.NewRecorder() | recorder := httptest.NewRecorder() | ||||
recorder.Body = buff | recorder.Body = buff | ||||
var route int | |||||
hit := -1 | |||||
r := NewRoute() | r := NewRoute() | ||||
r.Group("/{username}/{reponame}", func() { | r.Group("/{username}/{reponame}", func() { | ||||
assert.EqualValues(t, "gitea", reponame) | assert.EqualValues(t, "gitea", reponame) | ||||
tp := chi.URLParam(req, "type") | tp := chi.URLParam(req, "type") | ||||
assert.EqualValues(t, "issues", tp) | assert.EqualValues(t, "issues", tp) | ||||
route = 0 | |||||
hit = 0 | |||||
}) | }) | ||||
r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { | r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { | ||||
assert.EqualValues(t, "issues", tp) | assert.EqualValues(t, "issues", tp) | ||||
index := chi.URLParam(req, "index") | index := chi.URLParam(req, "index") | ||||
assert.EqualValues(t, "1", index) | assert.EqualValues(t, "1", index) | ||||
route = 1 | |||||
hit = 1 | |||||
}) | }) | ||||
}, func(resp http.ResponseWriter, req *http.Request) { | }, func(resp http.ResponseWriter, req *http.Request) { | ||||
resp.WriteHeader(http.StatusOK) | |||||
if stop, err := strconv.Atoi(req.FormValue("stop")); err == nil { | |||||
hit = stop | |||||
resp.WriteHeader(http.StatusOK) | |||||
} | |||||
}) | }) | ||||
r.Group("/issues/{index}", func() { | r.Group("/issues/{index}", func() { | ||||
assert.EqualValues(t, "gitea", reponame) | assert.EqualValues(t, "gitea", reponame) | ||||
index := chi.URLParam(req, "index") | index := chi.URLParam(req, "index") | ||||
assert.EqualValues(t, "1", index) | assert.EqualValues(t, "1", index) | ||||
route = 2 | |||||
hit = 2 | |||||
}) | }) | ||||
}) | }) | ||||
}) | }) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 0, route) | |||||
assert.EqualValues(t, 0, hit) | |||||
req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) | req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 1, route) | |||||
assert.EqualValues(t, 1, hit) | |||||
req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1?stop=100", nil) | |||||
assert.NoError(t, err) | |||||
r.ServeHTTP(recorder, req) | |||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | |||||
assert.EqualValues(t, 100, hit) | |||||
req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) | req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 2, route) | |||||
assert.EqualValues(t, 2, hit) | |||||
} | } | ||||
func TestRoute3(t *testing.T) { | func TestRoute3(t *testing.T) { | ||||
recorder := httptest.NewRecorder() | recorder := httptest.NewRecorder() | ||||
recorder.Body = buff | recorder.Body = buff | ||||
var route int | |||||
hit := -1 | |||||
m := NewRoute() | m := NewRoute() | ||||
r := NewRoute() | r := NewRoute() | ||||
m.Group("/{username}/{reponame}", func() { | m.Group("/{username}/{reponame}", func() { | ||||
m.Group("/branch_protections", func() { | m.Group("/branch_protections", func() { | ||||
m.Get("", func(resp http.ResponseWriter, req *http.Request) { | m.Get("", func(resp http.ResponseWriter, req *http.Request) { | ||||
route = 0 | |||||
hit = 0 | |||||
}) | }) | ||||
m.Post("", func(resp http.ResponseWriter, req *http.Request) { | m.Post("", func(resp http.ResponseWriter, req *http.Request) { | ||||
route = 1 | |||||
hit = 1 | |||||
}) | }) | ||||
m.Group("/{name}", func() { | m.Group("/{name}", func() { | ||||
m.Get("", func(resp http.ResponseWriter, req *http.Request) { | m.Get("", func(resp http.ResponseWriter, req *http.Request) { | ||||
route = 2 | |||||
hit = 2 | |||||
}) | }) | ||||
m.Patch("", func(resp http.ResponseWriter, req *http.Request) { | m.Patch("", func(resp http.ResponseWriter, req *http.Request) { | ||||
route = 3 | |||||
hit = 3 | |||||
}) | }) | ||||
m.Delete("", func(resp http.ResponseWriter, req *http.Request) { | m.Delete("", func(resp http.ResponseWriter, req *http.Request) { | ||||
route = 4 | |||||
hit = 4 | |||||
}) | }) | ||||
}) | }) | ||||
}) | }) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 0, route) | |||||
assert.EqualValues(t, 0, hit) | |||||
req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) | req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) | assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) | ||||
assert.EqualValues(t, 1, route) | |||||
assert.EqualValues(t, 1, hit) | |||||
req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 2, route) | |||||
assert.EqualValues(t, 2, hit) | |||||
req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 3, route) | |||||
assert.EqualValues(t, 3, hit) | |||||
req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) | ||||
assert.NoError(t, err) | assert.NoError(t, err) | ||||
r.ServeHTTP(recorder, req) | r.ServeHTTP(recorder, req) | ||||
assert.EqualValues(t, http.StatusOK, recorder.Code) | assert.EqualValues(t, http.StatusOK, recorder.Code) | ||||
assert.EqualValues(t, 4, route) | |||||
assert.EqualValues(t, 4, hit) | |||||
} | } |
// Copyright 2021 The Gitea Authors. All rights reserved. | |||||
// SPDX-License-Identifier: MIT | |||||
package web | |||||
import ( | |||||
goctx "context" | |||||
"net/http" | |||||
"strings" | |||||
"code.gitea.io/gitea/modules/context" | |||||
"code.gitea.io/gitea/modules/web/routing" | |||||
) | |||||
// Wrap converts all kinds of routes to standard library one | |||||
func Wrap(handlers ...interface{}) http.HandlerFunc { | |||||
if len(handlers) == 0 { | |||||
panic("No handlers found") | |||||
} | |||||
ourHandlers := make([]wrappedHandlerFunc, 0, len(handlers)) | |||||
for _, handler := range handlers { | |||||
ourHandlers = append(ourHandlers, convertHandler(handler)) | |||||
} | |||||
return wrapInternal(ourHandlers) | |||||
} | |||||
func wrapInternal(handlers []wrappedHandlerFunc) http.HandlerFunc { | |||||
return func(resp http.ResponseWriter, req *http.Request) { | |||||
var defers []func() | |||||
defer func() { | |||||
for i := len(defers) - 1; i >= 0; i-- { | |||||
defers[i]() | |||||
} | |||||
}() | |||||
for i := 0; i < len(handlers); i++ { | |||||
handler := handlers[i] | |||||
others := handlers[i+1:] | |||||
done, deferrable := handler(resp, req, others...) | |||||
if deferrable != nil { | |||||
defers = append(defers, deferrable) | |||||
} | |||||
if done { | |||||
return | |||||
} | |||||
} | |||||
} | |||||
} | |||||
// Middle wrap a context function as a chi middleware | |||||
func Middle(f func(ctx *context.Context)) func(next http.Handler) http.Handler { | |||||
funcInfo := routing.GetFuncInfo(f) | |||||
return func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetContext(req) | |||||
f(ctx) | |||||
if ctx.Written() { | |||||
return | |||||
} | |||||
next.ServeHTTP(ctx.Resp, ctx.Req) | |||||
}) | |||||
} | |||||
} | |||||
// MiddleCancel wrap a context function as a chi middleware | |||||
func MiddleCancel(f func(ctx *context.Context) goctx.CancelFunc) func(netx http.Handler) http.Handler { | |||||
funcInfo := routing.GetFuncInfo(f) | |||||
return func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetContext(req) | |||||
cancel := f(ctx) | |||||
if cancel != nil { | |||||
defer cancel() | |||||
} | |||||
if ctx.Written() { | |||||
return | |||||
} | |||||
next.ServeHTTP(ctx.Resp, ctx.Req) | |||||
}) | |||||
} | |||||
} | |||||
// MiddleAPI wrap a context function as a chi middleware | |||||
func MiddleAPI(f func(ctx *context.APIContext)) func(next http.Handler) http.Handler { | |||||
funcInfo := routing.GetFuncInfo(f) | |||||
return func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetAPIContext(req) | |||||
f(ctx) | |||||
if ctx.Written() { | |||||
return | |||||
} | |||||
next.ServeHTTP(ctx.Resp, ctx.Req) | |||||
}) | |||||
} | |||||
} | |||||
// WrapWithPrefix wraps a provided handler function at a prefix | |||||
func WrapWithPrefix(pathPrefix string, handler http.HandlerFunc, friendlyName ...string) func(next http.Handler) http.Handler { | |||||
funcInfo := routing.GetFuncInfo(handler, friendlyName...) | |||||
return func(next http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
if !strings.HasPrefix(req.URL.Path, pathPrefix) { | |||||
next.ServeHTTP(resp, req) | |||||
return | |||||
} | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
handler(resp, req) | |||||
}) | |||||
} | |||||
} |
// Copyright 2021 The Gitea Authors. All rights reserved. | |||||
// SPDX-License-Identifier: MIT | |||||
package web | |||||
import ( | |||||
goctx "context" | |||||
"fmt" | |||||
"net/http" | |||||
"code.gitea.io/gitea/modules/context" | |||||
"code.gitea.io/gitea/modules/web/routing" | |||||
) | |||||
type wrappedHandlerFunc func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) | |||||
func convertHandler(handler interface{}) wrappedHandlerFunc { | |||||
funcInfo := routing.GetFuncInfo(handler) | |||||
switch t := handler.(type) { | |||||
case http.HandlerFunc: | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
if _, ok := resp.(context.ResponseWriter); !ok { | |||||
resp = context.NewResponse(resp) | |||||
} | |||||
t(resp, req) | |||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { | |||||
done = true | |||||
} | |||||
return done, deferrable | |||||
} | |||||
case func(http.ResponseWriter, *http.Request): | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
t(resp, req) | |||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { | |||||
done = true | |||||
} | |||||
return done, deferrable | |||||
} | |||||
case func(ctx *context.Context): | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetContext(req) | |||||
t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(ctx *context.Context) goctx.CancelFunc: | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetContext(req) | |||||
deferrable = t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(*context.APIContext): | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetAPIContext(req) | |||||
t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(*context.APIContext) goctx.CancelFunc: | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetAPIContext(req) | |||||
deferrable = t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(*context.PrivateContext): | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetPrivateContext(req) | |||||
t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(*context.PrivateContext) goctx.CancelFunc: | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
ctx := context.GetPrivateContext(req) | |||||
deferrable = t(ctx) | |||||
done = ctx.Written() | |||||
return done, deferrable | |||||
} | |||||
case func(http.Handler) http.Handler: | |||||
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) { | |||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) | |||||
if len(others) > 0 { | |||||
next = wrapInternal(others) | |||||
} | |||||
routing.UpdateFuncInfo(req.Context(), funcInfo) | |||||
if _, ok := resp.(context.ResponseWriter); !ok { | |||||
resp = context.NewResponse(resp) | |||||
} | |||||
t(next).ServeHTTP(resp, req) | |||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { | |||||
done = true | |||||
} | |||||
return done, deferrable | |||||
} | |||||
default: | |||||
panic(fmt.Sprintf("Unsupported handler type: %#v", t)) | |||||
} | |||||
} |
) | ) | ||||
// Manual mapping of routes because {image} can contain slashes which chi does not support | // Manual mapping of routes because {image} can contain slashes which chi does not support | ||||
r.Route("/*", "HEAD,GET,POST,PUT,PATCH,DELETE", func(ctx *context.Context) { | |||||
r.RouteMethods("/*", "HEAD,GET,POST,PUT,PATCH,DELETE", func(ctx *context.Context) { | |||||
path := ctx.Params("*") | path := ctx.Params("*") | ||||
isHead := ctx.Req.Method == "HEAD" | isHead := ctx.Req.Method == "HEAD" | ||||
isGet := ctx.Req.Method == "GET" | isGet := ctx.Req.Method == "GET" |
} | } | ||||
// bind binding an obj to a func(ctx *context.APIContext) | // bind binding an obj to a func(ctx *context.APIContext) | ||||
func bind[T any](obj T) http.HandlerFunc { | |||||
return web.Wrap(func(ctx *context.APIContext) { | |||||
func bind[T any](_ T) any { | |||||
return func(ctx *context.APIContext) { | |||||
theObj := new(T) // create a new form obj for every request but not use obj directly | theObj := new(T) // create a new form obj for every request but not use obj directly | ||||
errs := binding.Bind(ctx.Req, theObj) | errs := binding.Bind(ctx.Req, theObj) | ||||
if len(errs) > 0 { | if len(errs) > 0 { | ||||
return | return | ||||
} | } | ||||
web.SetForm(ctx, theObj) | web.SetForm(ctx, theObj) | ||||
}) | |||||
} | |||||
} | } | ||||
// The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored | // The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored |
} | } | ||||
} | } | ||||
// Routes registers the install routes | |||||
// Routes registers the installation routes | |||||
func Routes(ctx goctx.Context) *web.Route { | func Routes(ctx goctx.Context) *web.Route { | ||||
r := web.NewRoute() | r := web.NewRoute() | ||||
for _, middle := range common.Middlewares() { | for _, middle := range common.Middlewares() { | ||||
r.Use(middle) | r.Use(middle) | ||||
} | } | ||||
r.Use(web.WrapWithPrefix("/assets/", public.AssetsHandlerFunc("/assets/"), "AssetsHandler")) | |||||
r.Use(web.MiddlewareWithPrefix("/assets/", nil, public.AssetsHandlerFunc("/assets/"))) | |||||
r.Use(session.Sessioner(session.Options{ | r.Use(session.Sessioner(session.Options{ | ||||
Provider: setting.SessionConfig.Provider, | Provider: setting.SessionConfig.Provider, | ||||
r.Get("/post-install", InstallDone) | r.Get("/post-install", InstallDone) | ||||
r.Get("/api/healthz", healthcheck.Check) | r.Get("/api/healthz", healthcheck.Check) | ||||
r.NotFound(web.Wrap(installNotFound)) | |||||
r.NotFound(installNotFound) | |||||
return r | return r | ||||
} | } | ||||
} | } | ||||
// bind binding an obj to a handler | // bind binding an obj to a handler | ||||
func bind[T any](obj T) http.HandlerFunc { | |||||
return web.Wrap(func(ctx *context.PrivateContext) { | |||||
func bind[T any](_ T) any { | |||||
return func(ctx *context.PrivateContext) { | |||||
theObj := new(T) // create a new form obj for every request but not use obj directly | theObj := new(T) // create a new form obj for every request but not use obj directly | ||||
binding.Bind(ctx.Req, theObj) | binding.Bind(ctx.Req, theObj) | ||||
web.SetForm(ctx, theObj) | web.SetForm(ctx, theObj) | ||||
}) | |||||
} | |||||
} | } | ||||
// Routes registers all internal APIs routes to web application. | // Routes registers all internal APIs routes to web application. |
func Routes(ctx gocontext.Context) *web.Route { | func Routes(ctx gocontext.Context) *web.Route { | ||||
routes := web.NewRoute() | routes := web.NewRoute() | ||||
routes.Use(web.WrapWithPrefix("/assets/", web.Wrap(CorsHandler(), public.AssetsHandlerFunc("/assets/")), "AssetsHandler")) | |||||
routes.Use(web.MiddlewareWithPrefix("/assets/", CorsHandler(), public.AssetsHandlerFunc("/assets/"))) | |||||
sessioner := session.Sessioner(session.Options{ | sessioner := session.Sessioner(session.Options{ | ||||
Provider: setting.SessionConfig.Provider, | Provider: setting.SessionConfig.Provider, | ||||
routes.Use(Recovery(ctx)) | routes.Use(Recovery(ctx)) | ||||
// We use r.Route here over r.Use because this prevents requests that are not for avatars having to go through this additional handler | // We use r.Route here over r.Use because this prevents requests that are not for avatars having to go through this additional handler | ||||
routes.Route("/avatars/*", "GET, HEAD", storageHandler(setting.Avatar.Storage, "avatars", storage.Avatars)) | |||||
routes.Route("/repo-avatars/*", "GET, HEAD", storageHandler(setting.RepoAvatar.Storage, "repo-avatars", storage.RepoAvatars)) | |||||
routes.RouteMethods("/avatars/*", "GET, HEAD", storageHandler(setting.Avatar.Storage, "avatars", storage.Avatars)) | |||||
routes.RouteMethods("/repo-avatars/*", "GET, HEAD", storageHandler(setting.RepoAvatar.Storage, "repo-avatars", storage.RepoAvatars)) | |||||
// for health check - doesn't need to be passed through gzip handler | // for health check - doesn't need to be passed through gzip handler | ||||
routes.Head("/", func(w http.ResponseWriter, req *http.Request) { | routes.Head("/", func(w http.ResponseWriter, req *http.Request) { | ||||
if setting.Service.EnableCaptcha { | if setting.Service.EnableCaptcha { | ||||
// The captcha http.Handler should only fire on /captcha/* so we can just mount this on that url | // The captcha http.Handler should only fire on /captcha/* so we can just mount this on that url | ||||
routes.Route("/captcha/*", "GET,HEAD", append(common, captcha.Captchaer(context.GetImageCaptcha()))...) | |||||
routes.RouteMethods("/captcha/*", "GET,HEAD", append(common, captcha.Captchaer(context.GetImageCaptcha()))...) | |||||
} | } | ||||
if setting.HasRobotsTxt { | if setting.HasRobotsTxt { | ||||
m.Post("/delete", org.SecretsDelete) | m.Post("/delete", org.SecretsDelete) | ||||
}) | }) | ||||
m.Route("/delete", "GET,POST", org.SettingsDelete) | |||||
m.RouteMethods("/delete", "GET,POST", org.SettingsDelete) | |||||
m.Group("/packages", func() { | m.Group("/packages", func() { | ||||
m.Get("", org.Packages) | m.Get("", org.Packages) |