Просмотр исходного кода

Refactor web route (#24080)

The old code is unnecessarily complex, and has many misuses.

Old code "wraps" a lot, wrap wrap wrap, it's difficult to understand
which kind of handler is used.

The new code uses a general approach, we do not need to write all kinds
of handlers into the "wrapper", do not need to wrap them again and
again.

New code, there are only 2 concepts:

1. HandlerProvider: `func (h any) (handlerProvider func (next)
http.Handler)`, it can be used as middleware
2. Use HandlerProvider to get the final HandlerFunc, and use it for
`r.Get()`


And we can decouple the route package from context package (see the
TODO).

# FAQ

## Is `reflect` safe?

Yes, all handlers are checked during startup, see the `preCheckHandler`
comment. If any handler is wrong, developers could know it in the first
time.

## Does `reflect` affect performance?

No. https://github.com/go-gitea/gitea/pull/24080#discussion_r1164825901

1. This reflect code only runs for each web handler call, handler is far
more slower: 10ms-50ms
2. The reflect is pretty fast (comparing to other code): 0.000265ms
3. XORM has more reflect operations already
tags/v1.20.0-rc0
wxiaoguang 1 год назад
Родитель
Сommit
b9a97ccd0e
Аккаунт пользователя с таким Email не найден

+ 200
- 0
modules/web/handler.go Просмотреть файл

// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package web

import (
goctx "context"
"fmt"
"net/http"
"reflect"
"strings"

"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/web/routing"
)

// ResponseStatusProvider is an interface to check whether the response has been written by the handler
type ResponseStatusProvider interface {
Written() bool
}

// TODO: decouple this from the context package, let the context package register these providers
var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusProvider{
reflect.TypeOf(&context.APIContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetAPIContext(req) },
reflect.TypeOf(&context.Context{}): func(req *http.Request) ResponseStatusProvider { return context.GetContext(req) },
reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) },
}

// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written
type responseWriter struct {
respWriter http.ResponseWriter
status int
}

var _ ResponseStatusProvider = (*responseWriter)(nil)

func (r *responseWriter) Written() bool {
return r.status > 0
}

func (r *responseWriter) Header() http.Header {
return r.respWriter.Header()
}

func (r *responseWriter) Write(bytes []byte) (int, error) {
if r.status == 0 {
r.status = http.StatusOK
}
return r.respWriter.Write(bytes)
}

func (r *responseWriter) WriteHeader(statusCode int) {
r.status = statusCode
r.respWriter.WriteHeader(statusCode)
}

var (
httpReqType = reflect.TypeOf((*http.Request)(nil))
respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem()
)

// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
hasStatusProvider := false
for _, argIn := range argsIn {
if _, hasStatusProvider = argIn.Interface().(ResponseStatusProvider); hasStatusProvider {
break
}
}
if !hasStatusProvider {
panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
}
if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 {
panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type()))
}
if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType {
panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type()))
}
}

func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value) []reflect.Value {
isPreCheck := req == nil

argsIn := make([]reflect.Value, fn.Type().NumIn())
for i := 0; i < fn.Type().NumIn(); i++ {
argTyp := fn.Type().In(i)
switch argTyp {
case respWriterType:
argsIn[i] = reflect.ValueOf(resp)
case httpReqType:
argsIn[i] = reflect.ValueOf(req)
default:
if argFn, ok := argTypeProvider[argTyp]; ok {
if isPreCheck {
argsIn[i] = reflect.ValueOf(&responseWriter{})
} else {
argsIn[i] = reflect.ValueOf(argFn(req))
}
} else {
panic(fmt.Sprintf("unsupported argument type: %s", argTyp))
}
}
}
return argsIn
}

func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc {
if len(ret) == 1 {
if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok {
return cancelFunc
}
panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type()))
} else if len(ret) > 1 {
panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
}
return nil
}

func hasResponseBeenWritten(argsIn []reflect.Value) bool {
for _, argIn := range argsIn {
if statusProvider, ok := argIn.Interface().(ResponseStatusProvider); ok {
if statusProvider.Written() {
return true
}
}
}
return false
}

// toHandlerProvider converts a handler to a handler provider
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
return hp
}

funcInfo := routing.GetFuncInfo(handler)
fn := reflect.ValueOf(handler)
if fn.Type().Kind() != reflect.Func {
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
}

provider := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
// wrap the response writer to check whether the response has been written
resp := respOrig
if _, ok := resp.(ResponseStatusProvider); !ok {
resp = &responseWriter{respWriter: resp}
}

// prepare the arguments for the handler and do pre-check
argsIn := prepareHandleArgsIn(resp, req, fn)
if req == nil {
preCheckHandler(fn, argsIn)
return // it's doing pre-check, just return
}

routing.UpdateFuncInfo(req.Context(), funcInfo)
ret := fn.Call(argsIn)

// handle the return value, and defer the cancel function if there is one
cancelFunc := handleResponse(fn, ret)
if cancelFunc != nil {
defer cancelFunc()
}

// if the response has not been written, call the next handler
if next != nil && !hasResponseBeenWritten(argsIn) {
next.ServeHTTP(resp, req)
}
})
}

provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported
return provider
}

// MiddlewareWithPrefix wraps a handler function at a prefix, and make it as a middleware
// TODO: this design is incorrect, the asset handler should not be a middleware
func MiddlewareWithPrefix(pathPrefix string, middleware func(handler http.Handler) http.Handler, handlerFunc http.HandlerFunc) func(next http.Handler) http.Handler {
funcInfo := routing.GetFuncInfo(handlerFunc)
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
handlerFunc(resp, req)
})
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
if !strings.HasPrefix(req.URL.Path, pathPrefix) {
next.ServeHTTP(resp, req)
return
}
if middleware != nil {
middleware(handler).ServeHTTP(resp, req)
} else {
handler.ServeHTTP(resp, req)
}
})
}
}

+ 53
- 73
modules/web/route.go Просмотреть файл

package web package web


import ( import (
goctx "context"
"fmt"
"net/http" "net/http"
"strings" "strings"


) )


// Bind binding an obj to a handler // Bind binding an obj to a handler
func Bind[T any](obj T) http.HandlerFunc {
return Wrap(func(ctx *context.Context) {
func Bind[T any](_ T) any {
return func(ctx *context.Context) {
theObj := new(T) // create a new form obj for every request but not use obj directly theObj := new(T) // create a new form obj for every request but not use obj directly
binding.Bind(ctx.Req, theObj) binding.Bind(ctx.Req, theObj)
SetForm(ctx, theObj) SetForm(ctx, theObj)
middleware.AssignForm(theObj, ctx.Data) middleware.AssignForm(theObj, ctx.Data)
})
}
} }


// SetForm set the form object // SetForm set the form object
// Use supports two middlewares // Use supports two middlewares
func (r *Route) Use(middlewares ...interface{}) { func (r *Route) Use(middlewares ...interface{}) {
if r.curGroupPrefix != "" { if r.curGroupPrefix != "" {
// FIXME: this behavior is incorrect, should use "With" instead
r.curMiddlewares = append(r.curMiddlewares, middlewares...) r.curMiddlewares = append(r.curMiddlewares, middlewares...)
} else { } else {
for _, middle := range middlewares {
switch t := middle.(type) {
case func(http.Handler) http.Handler:
r.R.Use(t)
case func(*context.Context):
r.R.Use(Middle(t))
case func(*context.Context) goctx.CancelFunc:
r.R.Use(MiddleCancel(t))
case func(*context.APIContext):
r.R.Use(MiddleAPI(t))
default:
panic(fmt.Sprintf("Unsupported middleware type: %#v", t))
}
// FIXME: another misuse, the "Use" with empty middlewares is called after "Mount"
for _, m := range middlewares {
r.R.Use(toHandlerProvider(m))
} }
} }
} }
return strings.TrimSuffix(newPattern, "/") return strings.TrimSuffix(newPattern, "/")
} }


func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h))
for _, m := range r.curMiddlewares {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
}
for _, m := range h {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
}
middlewares := handlerProviders[:len(handlerProviders)-1]
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
return middlewares, handlerFunc
}

func (r *Route) Methods(method, pattern string, h []any) {
middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h)
fullPattern := r.getPattern(pattern)
if strings.Contains(method, ",") {
methods := strings.Split(method, ",")
for _, method := range methods {
r.R.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc)
}
} else {
r.R.With(middlewares...).Method(method, fullPattern, handlerFunc)
}
}

// Mount attaches another Route along ./pattern/* // Mount attaches another Route along ./pattern/*
func (r *Route) Mount(pattern string, subR *Route) { func (r *Route) Mount(pattern string, subR *Route) {
middlewares := make([]interface{}, len(r.curMiddlewares)) middlewares := make([]interface{}, len(r.curMiddlewares))


// Any delegate requests for all methods // Any delegate requests for all methods
func (r *Route) Any(pattern string, h ...interface{}) { func (r *Route) Any(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.HandleFunc(r.getPattern(pattern), Wrap(middlewares...))
middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h)
r.R.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
} }


// Route delegate special methods
func (r *Route) Route(pattern, methods string, h ...interface{}) {
p := r.getPattern(pattern)
ms := strings.Split(methods, ",")
middlewares := r.getMiddlewares(h)
for _, method := range ms {
r.R.MethodFunc(strings.TrimSpace(method), p, Wrap(middlewares...))
}
// RouteMethods delegate special methods, it is an alias of "Methods", while the "pattern" is the first parameter
func (r *Route) RouteMethods(pattern, methods string, h ...interface{}) {
r.Methods(methods, pattern, h)
} }


// Delete delegate delete method // Delete delegate delete method
func (r *Route) Delete(pattern string, h ...interface{}) { func (r *Route) Delete(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Delete(r.getPattern(pattern), Wrap(middlewares...))
}

func (r *Route) getMiddlewares(h []interface{}) []interface{} {
middlewares := make([]interface{}, len(r.curMiddlewares), len(r.curMiddlewares)+len(h))
copy(middlewares, r.curMiddlewares)
middlewares = append(middlewares, h...)
return middlewares
r.Methods("DELETE", pattern, h)
} }


// Get delegate get method // Get delegate get method
func (r *Route) Get(pattern string, h ...interface{}) { func (r *Route) Get(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Get(r.getPattern(pattern), Wrap(middlewares...))
}

// Options delegate options method
func (r *Route) Options(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Options(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("GET", pattern, h)
} }


// GetOptions delegate get and options method // GetOptions delegate get and options method
func (r *Route) GetOptions(pattern string, h ...interface{}) { func (r *Route) GetOptions(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Get(r.getPattern(pattern), Wrap(middlewares...))
r.R.Options(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("GET,OPTIONS", pattern, h)
} }


// PostOptions delegate post and options method // PostOptions delegate post and options method
func (r *Route) PostOptions(pattern string, h ...interface{}) { func (r *Route) PostOptions(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Post(r.getPattern(pattern), Wrap(middlewares...))
r.R.Options(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("POST,OPTIONS", pattern, h)
} }


// Head delegate head method // Head delegate head method
func (r *Route) Head(pattern string, h ...interface{}) { func (r *Route) Head(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Head(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("HEAD", pattern, h)
} }


// Post delegate post method // Post delegate post method
func (r *Route) Post(pattern string, h ...interface{}) { func (r *Route) Post(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Post(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("POST", pattern, h)
} }


// Put delegate put method // Put delegate put method
func (r *Route) Put(pattern string, h ...interface{}) { func (r *Route) Put(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Put(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("PUT", pattern, h)
} }


// Patch delegate patch method // Patch delegate patch method
func (r *Route) Patch(pattern string, h ...interface{}) { func (r *Route) Patch(pattern string, h ...interface{}) {
middlewares := r.getMiddlewares(h)
r.R.Patch(r.getPattern(pattern), Wrap(middlewares...))
r.Methods("PATCH", pattern, h)
} }


// ServeHTTP implements http.Handler // ServeHTTP implements http.Handler
r.R.ServeHTTP(w, req) r.R.ServeHTTP(w, req)
} }


// NotFound defines a handler to respond whenever a route could
// not be found.
// NotFound defines a handler to respond whenever a route could not be found.
func (r *Route) NotFound(h http.HandlerFunc) { func (r *Route) NotFound(h http.HandlerFunc) {
r.R.NotFound(h) r.R.NotFound(h)
} }


// MethodNotAllowed defines a handler to respond whenever a method is
// not allowed.
func (r *Route) MethodNotAllowed(h http.HandlerFunc) {
r.R.MethodNotAllowed(h)
}

// Combo deletegate requests to Combo
// Combo delegates requests to Combo
func (r *Route) Combo(pattern string, h ...interface{}) *Combo { func (r *Route) Combo(pattern string, h ...interface{}) *Combo {
return &Combo{r, pattern, h} return &Combo{r, pattern, h}
} }
h []interface{} h []interface{}
} }


// Get deletegate Get method
// Get delegates Get method
func (c *Combo) Get(h ...interface{}) *Combo { func (c *Combo) Get(h ...interface{}) *Combo {
c.r.Get(c.pattern, append(c.h, h...)...) c.r.Get(c.pattern, append(c.h, h...)...)
return c return c
} }


// Post deletegate Post method
// Post delegates Post method
func (c *Combo) Post(h ...interface{}) *Combo { func (c *Combo) Post(h ...interface{}) *Combo {
c.r.Post(c.pattern, append(c.h, h...)...) c.r.Post(c.pattern, append(c.h, h...)...)
return c return c
} }


// Delete deletegate Delete method
// Delete delegates Delete method
func (c *Combo) Delete(h ...interface{}) *Combo { func (c *Combo) Delete(h ...interface{}) *Combo {
c.r.Delete(c.pattern, append(c.h, h...)...) c.r.Delete(c.pattern, append(c.h, h...)...)
return c return c
} }


// Put deletegate Put method
// Put delegates Put method
func (c *Combo) Put(h ...interface{}) *Combo { func (c *Combo) Put(h ...interface{}) *Combo {
c.r.Put(c.pattern, append(c.h, h...)...) c.r.Put(c.pattern, append(c.h, h...)...)
return c return c
} }


// Patch deletegate Patch method
// Patch delegates Patch method
func (c *Combo) Patch(h ...interface{}) *Combo { func (c *Combo) Patch(h ...interface{}) *Combo {
c.r.Patch(c.pattern, append(c.h, h...)...) c.r.Patch(c.pattern, append(c.h, h...)...)
return c return c

+ 29
- 19
modules/web/route_test.go Просмотреть файл

"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"testing" "testing"


chi "github.com/go-chi/chi/v5" chi "github.com/go-chi/chi/v5"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
recorder.Body = buff recorder.Body = buff


var route int
hit := -1


r := NewRoute() r := NewRoute()
r.Group("/{username}/{reponame}", func() { r.Group("/{username}/{reponame}", func() {
assert.EqualValues(t, "gitea", reponame) assert.EqualValues(t, "gitea", reponame)
tp := chi.URLParam(req, "type") tp := chi.URLParam(req, "type")
assert.EqualValues(t, "issues", tp) assert.EqualValues(t, "issues", tp)
route = 0
hit = 0
}) })


r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) {
assert.EqualValues(t, "issues", tp) assert.EqualValues(t, "issues", tp)
index := chi.URLParam(req, "index") index := chi.URLParam(req, "index")
assert.EqualValues(t, "1", index) assert.EqualValues(t, "1", index)
route = 1
hit = 1
}) })
}, func(resp http.ResponseWriter, req *http.Request) { }, func(resp http.ResponseWriter, req *http.Request) {
resp.WriteHeader(http.StatusOK)
if stop, err := strconv.Atoi(req.FormValue("stop")); err == nil {
hit = stop
resp.WriteHeader(http.StatusOK)
}
}) })


r.Group("/issues/{index}", func() { r.Group("/issues/{index}", func() {
assert.EqualValues(t, "gitea", reponame) assert.EqualValues(t, "gitea", reponame)
index := chi.URLParam(req, "index") index := chi.URLParam(req, "index")
assert.EqualValues(t, "1", index) assert.EqualValues(t, "1", index)
route = 2
hit = 2
}) })
}) })
}) })
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 0, route)
assert.EqualValues(t, 0, hit)


req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 1, route)
assert.EqualValues(t, 1, hit)

req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1?stop=100", nil)
assert.NoError(t, err)
r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 100, hit)


req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 2, route)
assert.EqualValues(t, 2, hit)
} }


func TestRoute3(t *testing.T) { func TestRoute3(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
recorder.Body = buff recorder.Body = buff


var route int
hit := -1


m := NewRoute() m := NewRoute()
r := NewRoute() r := NewRoute()
m.Group("/{username}/{reponame}", func() { m.Group("/{username}/{reponame}", func() {
m.Group("/branch_protections", func() { m.Group("/branch_protections", func() {
m.Get("", func(resp http.ResponseWriter, req *http.Request) { m.Get("", func(resp http.ResponseWriter, req *http.Request) {
route = 0
hit = 0
}) })
m.Post("", func(resp http.ResponseWriter, req *http.Request) { m.Post("", func(resp http.ResponseWriter, req *http.Request) {
route = 1
hit = 1
}) })
m.Group("/{name}", func() { m.Group("/{name}", func() {
m.Get("", func(resp http.ResponseWriter, req *http.Request) { m.Get("", func(resp http.ResponseWriter, req *http.Request) {
route = 2
hit = 2
}) })
m.Patch("", func(resp http.ResponseWriter, req *http.Request) { m.Patch("", func(resp http.ResponseWriter, req *http.Request) {
route = 3
hit = 3
}) })
m.Delete("", func(resp http.ResponseWriter, req *http.Request) { m.Delete("", func(resp http.ResponseWriter, req *http.Request) {
route = 4
hit = 4
}) })
}) })
}) })
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 0, route)
assert.EqualValues(t, 0, hit)


req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK)
assert.EqualValues(t, 1, route)
assert.EqualValues(t, 1, hit)


req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 2, route)
assert.EqualValues(t, 2, hit)


req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 3, route)
assert.EqualValues(t, 3, hit)


req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil)
assert.NoError(t, err) assert.NoError(t, err)
r.ServeHTTP(recorder, req) r.ServeHTTP(recorder, req)
assert.EqualValues(t, http.StatusOK, recorder.Code) assert.EqualValues(t, http.StatusOK, recorder.Code)
assert.EqualValues(t, 4, route)
assert.EqualValues(t, 4, hit)
} }

+ 0
- 116
modules/web/wrap.go Просмотреть файл

// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package web

import (
goctx "context"
"net/http"
"strings"

"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/web/routing"
)

// Wrap converts all kinds of routes to standard library one
func Wrap(handlers ...interface{}) http.HandlerFunc {
if len(handlers) == 0 {
panic("No handlers found")
}

ourHandlers := make([]wrappedHandlerFunc, 0, len(handlers))

for _, handler := range handlers {
ourHandlers = append(ourHandlers, convertHandler(handler))
}
return wrapInternal(ourHandlers)
}

func wrapInternal(handlers []wrappedHandlerFunc) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
var defers []func()
defer func() {
for i := len(defers) - 1; i >= 0; i-- {
defers[i]()
}
}()
for i := 0; i < len(handlers); i++ {
handler := handlers[i]
others := handlers[i+1:]
done, deferrable := handler(resp, req, others...)
if deferrable != nil {
defers = append(defers, deferrable)
}
if done {
return
}
}
}
}

// Middle wrap a context function as a chi middleware
func Middle(f func(ctx *context.Context)) func(next http.Handler) http.Handler {
funcInfo := routing.GetFuncInfo(f)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetContext(req)
f(ctx)
if ctx.Written() {
return
}
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}

// MiddleCancel wrap a context function as a chi middleware
func MiddleCancel(f func(ctx *context.Context) goctx.CancelFunc) func(netx http.Handler) http.Handler {
funcInfo := routing.GetFuncInfo(f)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetContext(req)
cancel := f(ctx)
if cancel != nil {
defer cancel()
}
if ctx.Written() {
return
}
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}

// MiddleAPI wrap a context function as a chi middleware
func MiddleAPI(f func(ctx *context.APIContext)) func(next http.Handler) http.Handler {
funcInfo := routing.GetFuncInfo(f)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetAPIContext(req)
f(ctx)
if ctx.Written() {
return
}
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}

// WrapWithPrefix wraps a provided handler function at a prefix
func WrapWithPrefix(pathPrefix string, handler http.HandlerFunc, friendlyName ...string) func(next http.Handler) http.Handler {
funcInfo := routing.GetFuncInfo(handler, friendlyName...)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
if !strings.HasPrefix(req.URL.Path, pathPrefix) {
next.ServeHTTP(resp, req)
return
}
routing.UpdateFuncInfo(req.Context(), funcInfo)
handler(resp, req)
})
}
}

+ 0
- 109
modules/web/wrap_convert.go Просмотреть файл

// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package web

import (
goctx "context"
"fmt"
"net/http"

"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/web/routing"
)

type wrappedHandlerFunc func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func())

func convertHandler(handler interface{}) wrappedHandlerFunc {
funcInfo := routing.GetFuncInfo(handler)
switch t := handler.(type) {
case http.HandlerFunc:
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
if _, ok := resp.(context.ResponseWriter); !ok {
resp = context.NewResponse(resp)
}
t(resp, req)
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 {
done = true
}
return done, deferrable
}
case func(http.ResponseWriter, *http.Request):
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
t(resp, req)
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 {
done = true
}
return done, deferrable
}

case func(ctx *context.Context):
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetContext(req)
t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(ctx *context.Context) goctx.CancelFunc:
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetContext(req)
deferrable = t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(*context.APIContext):
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetAPIContext(req)
t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(*context.APIContext) goctx.CancelFunc:
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetAPIContext(req)
deferrable = t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(*context.PrivateContext):
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetPrivateContext(req)
t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(*context.PrivateContext) goctx.CancelFunc:
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
routing.UpdateFuncInfo(req.Context(), funcInfo)
ctx := context.GetPrivateContext(req)
deferrable = t(ctx)
done = ctx.Written()
return done, deferrable
}
case func(http.Handler) http.Handler:
return func(resp http.ResponseWriter, req *http.Request, others ...wrappedHandlerFunc) (done bool, deferrable func()) {
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
if len(others) > 0 {
next = wrapInternal(others)
}
routing.UpdateFuncInfo(req.Context(), funcInfo)
if _, ok := resp.(context.ResponseWriter); !ok {
resp = context.NewResponse(resp)
}
t(next).ServeHTTP(resp, req)
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 {
done = true
}
return done, deferrable
}
default:
panic(fmt.Sprintf("Unsupported handler type: %#v", t))
}
}

+ 1
- 1
routers/api/packages/api.go Просмотреть файл

) )


// Manual mapping of routes because {image} can contain slashes which chi does not support // Manual mapping of routes because {image} can contain slashes which chi does not support
r.Route("/*", "HEAD,GET,POST,PUT,PATCH,DELETE", func(ctx *context.Context) {
r.RouteMethods("/*", "HEAD,GET,POST,PUT,PATCH,DELETE", func(ctx *context.Context) {
path := ctx.Params("*") path := ctx.Params("*")
isHead := ctx.Req.Method == "HEAD" isHead := ctx.Req.Method == "HEAD"
isGet := ctx.Req.Method == "GET" isGet := ctx.Req.Method == "GET"

+ 3
- 3
routers/api/v1/api.go Просмотреть файл

} }


// bind binding an obj to a func(ctx *context.APIContext) // bind binding an obj to a func(ctx *context.APIContext)
func bind[T any](obj T) http.HandlerFunc {
return web.Wrap(func(ctx *context.APIContext) {
func bind[T any](_ T) any {
return func(ctx *context.APIContext) {
theObj := new(T) // create a new form obj for every request but not use obj directly theObj := new(T) // create a new form obj for every request but not use obj directly
errs := binding.Bind(ctx.Req, theObj) errs := binding.Bind(ctx.Req, theObj)
if len(errs) > 0 { if len(errs) > 0 {
return return
} }
web.SetForm(ctx, theObj) web.SetForm(ctx, theObj)
})
}
} }


// The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored // The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored

+ 3
- 3
routers/install/routes.go Просмотреть файл

} }
} }


// Routes registers the install routes
// Routes registers the installation routes
func Routes(ctx goctx.Context) *web.Route { func Routes(ctx goctx.Context) *web.Route {
r := web.NewRoute() r := web.NewRoute()
for _, middle := range common.Middlewares() { for _, middle := range common.Middlewares() {
r.Use(middle) r.Use(middle)
} }


r.Use(web.WrapWithPrefix("/assets/", public.AssetsHandlerFunc("/assets/"), "AssetsHandler"))
r.Use(web.MiddlewareWithPrefix("/assets/", nil, public.AssetsHandlerFunc("/assets/")))


r.Use(session.Sessioner(session.Options{ r.Use(session.Sessioner(session.Options{
Provider: setting.SessionConfig.Provider, Provider: setting.SessionConfig.Provider,
r.Get("/post-install", InstallDone) r.Get("/post-install", InstallDone)
r.Get("/api/healthz", healthcheck.Check) r.Get("/api/healthz", healthcheck.Check)


r.NotFound(web.Wrap(installNotFound))
r.NotFound(installNotFound)
return r return r
} }



+ 3
- 3
routers/private/internal.go Просмотреть файл

} }


// bind binding an obj to a handler // bind binding an obj to a handler
func bind[T any](obj T) http.HandlerFunc {
return web.Wrap(func(ctx *context.PrivateContext) {
func bind[T any](_ T) any {
return func(ctx *context.PrivateContext) {
theObj := new(T) // create a new form obj for every request but not use obj directly theObj := new(T) // create a new form obj for every request but not use obj directly
binding.Bind(ctx.Req, theObj) binding.Bind(ctx.Req, theObj)
web.SetForm(ctx, theObj) web.SetForm(ctx, theObj)
})
}
} }


// Routes registers all internal APIs routes to web application. // Routes registers all internal APIs routes to web application.

+ 5
- 5
routers/web/web.go Просмотреть файл

func Routes(ctx gocontext.Context) *web.Route { func Routes(ctx gocontext.Context) *web.Route {
routes := web.NewRoute() routes := web.NewRoute()


routes.Use(web.WrapWithPrefix("/assets/", web.Wrap(CorsHandler(), public.AssetsHandlerFunc("/assets/")), "AssetsHandler"))
routes.Use(web.MiddlewareWithPrefix("/assets/", CorsHandler(), public.AssetsHandlerFunc("/assets/")))


sessioner := session.Sessioner(session.Options{ sessioner := session.Sessioner(session.Options{
Provider: setting.SessionConfig.Provider, Provider: setting.SessionConfig.Provider,
routes.Use(Recovery(ctx)) routes.Use(Recovery(ctx))


// We use r.Route here over r.Use because this prevents requests that are not for avatars having to go through this additional handler // We use r.Route here over r.Use because this prevents requests that are not for avatars having to go through this additional handler
routes.Route("/avatars/*", "GET, HEAD", storageHandler(setting.Avatar.Storage, "avatars", storage.Avatars))
routes.Route("/repo-avatars/*", "GET, HEAD", storageHandler(setting.RepoAvatar.Storage, "repo-avatars", storage.RepoAvatars))
routes.RouteMethods("/avatars/*", "GET, HEAD", storageHandler(setting.Avatar.Storage, "avatars", storage.Avatars))
routes.RouteMethods("/repo-avatars/*", "GET, HEAD", storageHandler(setting.RepoAvatar.Storage, "repo-avatars", storage.RepoAvatars))


// for health check - doesn't need to be passed through gzip handler // for health check - doesn't need to be passed through gzip handler
routes.Head("/", func(w http.ResponseWriter, req *http.Request) { routes.Head("/", func(w http.ResponseWriter, req *http.Request) {


if setting.Service.EnableCaptcha { if setting.Service.EnableCaptcha {
// The captcha http.Handler should only fire on /captcha/* so we can just mount this on that url // The captcha http.Handler should only fire on /captcha/* so we can just mount this on that url
routes.Route("/captcha/*", "GET,HEAD", append(common, captcha.Captchaer(context.GetImageCaptcha()))...)
routes.RouteMethods("/captcha/*", "GET,HEAD", append(common, captcha.Captchaer(context.GetImageCaptcha()))...)
} }


if setting.HasRobotsTxt { if setting.HasRobotsTxt {
m.Post("/delete", org.SecretsDelete) m.Post("/delete", org.SecretsDelete)
}) })


m.Route("/delete", "GET,POST", org.SettingsDelete)
m.RouteMethods("/delete", "GET,POST", org.SettingsDelete)


m.Group("/packages", func() { m.Group("/packages", func() {
m.Get("", org.Packages) m.Get("", org.Packages)

Загрузка…
Отмена
Сохранить