You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2023 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. package web
  4. import (
  5. goctx "context"
  6. "fmt"
  7. "net/http"
  8. "reflect"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/web/routing"
  11. "code.gitea.io/gitea/modules/web/types"
  12. )
  13. var responseStatusProviders = map[reflect.Type]func(req *http.Request) types.ResponseStatusProvider{}
  14. func RegisterResponseStatusProvider[T any](fn func(req *http.Request) types.ResponseStatusProvider) {
  15. responseStatusProviders[reflect.TypeOf((*T)(nil)).Elem()] = fn
  16. }
  17. // responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written
  18. type responseWriter struct {
  19. respWriter http.ResponseWriter
  20. status int
  21. }
  22. var _ types.ResponseStatusProvider = (*responseWriter)(nil)
  23. func (r *responseWriter) WrittenStatus() int {
  24. return r.status
  25. }
  26. func (r *responseWriter) Header() http.Header {
  27. return r.respWriter.Header()
  28. }
  29. func (r *responseWriter) Write(bytes []byte) (int, error) {
  30. if r.status == 0 {
  31. r.status = http.StatusOK
  32. }
  33. return r.respWriter.Write(bytes)
  34. }
  35. func (r *responseWriter) WriteHeader(statusCode int) {
  36. r.status = statusCode
  37. r.respWriter.WriteHeader(statusCode)
  38. }
  39. var (
  40. httpReqType = reflect.TypeOf((*http.Request)(nil))
  41. respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
  42. cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem()
  43. )
  44. // preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
  45. func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
  46. hasStatusProvider := false
  47. for _, argIn := range argsIn {
  48. if _, hasStatusProvider = argIn.Interface().(types.ResponseStatusProvider); hasStatusProvider {
  49. break
  50. }
  51. }
  52. if !hasStatusProvider {
  53. panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
  54. }
  55. if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 {
  56. panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type()))
  57. }
  58. if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType {
  59. panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type()))
  60. }
  61. }
  62. func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
  63. defer func() {
  64. if err := recover(); err != nil {
  65. log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
  66. panic(err)
  67. }
  68. }()
  69. isPreCheck := req == nil
  70. argsIn := make([]reflect.Value, fn.Type().NumIn())
  71. for i := 0; i < fn.Type().NumIn(); i++ {
  72. argTyp := fn.Type().In(i)
  73. switch argTyp {
  74. case respWriterType:
  75. argsIn[i] = reflect.ValueOf(resp)
  76. case httpReqType:
  77. argsIn[i] = reflect.ValueOf(req)
  78. default:
  79. if argFn, ok := responseStatusProviders[argTyp]; ok {
  80. if isPreCheck {
  81. argsIn[i] = reflect.ValueOf(&responseWriter{})
  82. } else {
  83. argsIn[i] = reflect.ValueOf(argFn(req))
  84. }
  85. } else {
  86. panic(fmt.Sprintf("unsupported argument type: %s", argTyp))
  87. }
  88. }
  89. }
  90. return argsIn
  91. }
  92. func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc {
  93. if len(ret) == 1 {
  94. if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok {
  95. return cancelFunc
  96. }
  97. panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type()))
  98. } else if len(ret) > 1 {
  99. panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
  100. }
  101. return nil
  102. }
  103. func hasResponseBeenWritten(argsIn []reflect.Value) bool {
  104. for _, argIn := range argsIn {
  105. if statusProvider, ok := argIn.Interface().(types.ResponseStatusProvider); ok {
  106. if statusProvider.WrittenStatus() != 0 {
  107. return true
  108. }
  109. }
  110. }
  111. return false
  112. }
  113. // toHandlerProvider converts a handler to a handler provider
  114. // A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
  115. func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
  116. funcInfo := routing.GetFuncInfo(handler)
  117. fn := reflect.ValueOf(handler)
  118. if fn.Type().Kind() != reflect.Func {
  119. panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
  120. }
  121. if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
  122. return func(next http.Handler) http.Handler {
  123. h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
  124. return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
  125. routing.UpdateFuncInfo(req.Context(), funcInfo)
  126. h.ServeHTTP(resp, req)
  127. })
  128. }
  129. }
  130. provider := func(next http.Handler) http.Handler {
  131. return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
  132. // wrap the response writer to check whether the response has been written
  133. resp := respOrig
  134. if _, ok := resp.(types.ResponseStatusProvider); !ok {
  135. resp = &responseWriter{respWriter: resp}
  136. }
  137. // prepare the arguments for the handler and do pre-check
  138. argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo)
  139. if req == nil {
  140. preCheckHandler(fn, argsIn)
  141. return // it's doing pre-check, just return
  142. }
  143. routing.UpdateFuncInfo(req.Context(), funcInfo)
  144. ret := fn.Call(argsIn)
  145. // handle the return value, and defer the cancel function if there is one
  146. cancelFunc := handleResponse(fn, ret)
  147. if cancelFunc != nil {
  148. defer cancelFunc()
  149. }
  150. // if the response has not been written, call the next handler
  151. if next != nil && !hasResponseBeenWritten(argsIn) {
  152. next.ServeHTTP(resp, req)
  153. }
  154. })
  155. }
  156. provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported
  157. return provider
  158. }