You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

server.go 8.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. // This code is highly inspired by endless go
  4. package graceful
  import (
	"crypto/tls"
	"fmt"
	"net"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"code.gitea.io/gitea/modules/log"
	"code.gitea.io/gitea/modules/proxyprotocol"
	"code.gitea.io/gitea/modules/setting"
)
  // Server-wide defaults. All of them are left at their zero values here.
var (
	// DefaultReadTimeOut is the default read timeout.
	DefaultReadTimeOut time.Duration
	// DefaultWriteTimeOut is the default write timeout.
	DefaultWriteTimeOut time.Duration
	// DefaultMaxHeaderBytes is the default max header bytes
	// (NOTE(review): presumably fed to http.Server.MaxHeaderBytes — confirm at the caller).
	DefaultMaxHeaderBytes int
)

// Per-write timeouts.
//
// NOTE(review): neither variable is referenced in the code visible here;
// NewServer reads setting.PerWriteTimeout / setting.PerWritePerKbTimeout instead.
var (
	// PerWriteWriteTimeout is the base timeout allowed for a single write.
	PerWriteWriteTimeout = 30 * time.Second
	// PerWriteWriteTimeoutKbTime is a timeout that takes into account how much
	// data there is to be written.
	PerWriteWriteTimeoutKbTime = 10 * time.Second
)
  30. // GetListener returns a listener from a GetListener function, which must have the
  31. // signature: `func FunctioName(network, address string) (net.Listener, error)`.
  32. // This determines the implementation of net.Listener which the server will use.`
  33. // It is implemented in this way so that downstreams may specify the type of listener
  34. // they want to provide Gitea on by default, such as with a hidden service or a p2p network
  35. // No need to worry about "breaking" if there would be a refactoring for the Listeners. No compatibility-guarantee for this mechanism
  36. var GetListener = DefaultGetListener
  37. func init() {
  38. DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
  39. }
  // ServeFunction is an accept loop run over a listener, such as
// (*http.Server).Serve. Server.Serve invokes it with the server's listener and
// post-processes the error it returns.
type ServeFunction = func(listener net.Listener) error
  42. // Server represents our graceful server
  43. type Server struct {
  44. network string
  45. address string
  46. listener net.Listener
  47. wg sync.WaitGroup
  48. state state
  49. lock *sync.RWMutex
  50. BeforeBegin func(network, address string)
  51. OnShutdown func()
  52. PerWriteTimeout time.Duration
  53. PerWritePerKbTimeout time.Duration
  54. }
  55. // NewServer creates a server on network at provided address
  56. func NewServer(network, address, name string) *Server {
  57. if GetManager().IsChild() {
  58. log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
  59. } else {
  60. log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
  61. }
  62. srv := &Server{
  63. wg: sync.WaitGroup{},
  64. state: stateInit,
  65. lock: &sync.RWMutex{},
  66. network: network,
  67. address: address,
  68. PerWriteTimeout: setting.PerWriteTimeout,
  69. PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
  70. }
  71. srv.BeforeBegin = func(network, addr string) {
  72. log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
  73. }
  74. return srv
  75. }
  76. // ListenAndServe listens on the provided network address and then calls Serve
  77. // to handle requests on incoming connections.
  78. func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
  79. go srv.awaitShutdown()
  80. listener, err := GetListener(srv.network, srv.address)
  81. if err != nil {
  82. log.Error("Unable to GetListener: %v", err)
  83. return err
  84. }
  85. // we need to wrap the listener to take account of our lifecycle
  86. listener = newWrappedListener(listener, srv)
  87. // Now we need to take account of ProxyProtocol settings...
  88. if useProxyProtocol {
  89. listener = &proxyprotocol.Listener{
  90. Listener: listener,
  91. ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
  92. AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
  93. }
  94. }
  95. srv.listener = listener
  96. srv.BeforeBegin(srv.network, srv.address)
  97. return srv.Serve(serve)
  98. }
  99. // ListenAndServeTLSConfig listens on the provided network address and then calls
  100. // Serve to handle requests on incoming TLS connections.
  101. func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
  102. go srv.awaitShutdown()
  103. if tlsConfig.MinVersion == 0 {
  104. tlsConfig.MinVersion = tls.VersionTLS12
  105. }
  106. listener, err := GetListener(srv.network, srv.address)
  107. if err != nil {
  108. log.Error("Unable to get Listener: %v", err)
  109. return err
  110. }
  111. // we need to wrap the listener to take account of our lifecycle
  112. listener = newWrappedListener(listener, srv)
  113. // Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us
  114. if useProxyProtocol && !proxyProtocolTLSBridging {
  115. listener = &proxyprotocol.Listener{
  116. Listener: listener,
  117. ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
  118. AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
  119. }
  120. }
  121. // Now handle the tls protocol
  122. listener = tls.NewListener(listener, tlsConfig)
  123. // Now if we're bridging then we need the proxy to tell us who we're bridging for...
  124. if useProxyProtocol && proxyProtocolTLSBridging {
  125. listener = &proxyprotocol.Listener{
  126. Listener: listener,
  127. ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
  128. AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
  129. }
  130. }
  131. srv.listener = listener
  132. srv.BeforeBegin(srv.network, srv.address)
  133. return srv.Serve(serve)
  134. }
  135. // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
  136. // service goroutine for each. The service goroutines read requests and then call
  137. // handler to reply to them. Handler is typically nil, in which case the
  138. // DefaultServeMux is used.
  139. //
  140. // In addition to the standard Serve behaviour each connection is added to a
  141. // sync.Waitgroup so that all outstanding connections can be served before shutting
  142. // down the server.
  143. func (srv *Server) Serve(serve ServeFunction) error {
  144. defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
  145. srv.setState(stateRunning)
  146. GetManager().RegisterServer()
  147. err := serve(srv.listener)
  148. log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
  149. srv.wg.Wait()
  150. srv.setState(stateTerminate)
  151. GetManager().ServerDone()
  152. // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
  153. if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
  154. return nil
  155. }
  156. return err
  157. }
  158. func (srv *Server) getState() state {
  159. srv.lock.RLock()
  160. defer srv.lock.RUnlock()
  161. return srv.state
  162. }
  163. func (srv *Server) setState(st state) {
  164. srv.lock.Lock()
  165. defer srv.lock.Unlock()
  166. srv.state = st
  167. }
  // filer is implemented by listeners that can expose their socket as an
// *os.File, for example *net.TCPListener and *net.UnixListener.
type filer interface {
	File() (f *os.File, err error)
}
  171. type wrappedListener struct {
  172. net.Listener
  173. stopped bool
  174. server *Server
  175. }
  176. func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
  177. return &wrappedListener{
  178. Listener: l,
  179. server: srv,
  180. }
  181. }
  182. func (wl *wrappedListener) Accept() (net.Conn, error) {
  183. var c net.Conn
  184. // Set keepalive on TCPListeners connections.
  185. if tcl, ok := wl.Listener.(*net.TCPListener); ok {
  186. tc, err := tcl.AcceptTCP()
  187. if err != nil {
  188. return nil, err
  189. }
  190. _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
  191. _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
  192. c = tc
  193. } else {
  194. var err error
  195. c, err = wl.Listener.Accept()
  196. if err != nil {
  197. return nil, err
  198. }
  199. }
  200. closed := int32(0)
  201. c = &wrappedConn{
  202. Conn: c,
  203. server: wl.server,
  204. closed: &closed,
  205. perWriteTimeout: wl.server.PerWriteTimeout,
  206. perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
  207. }
  208. wl.server.wg.Add(1)
  209. return c, nil
  210. }
  211. func (wl *wrappedListener) Close() error {
  212. if wl.stopped {
  213. return syscall.EINVAL
  214. }
  215. wl.stopped = true
  216. return wl.Listener.Close()
  217. }
  218. func (wl *wrappedListener) File() (*os.File, error) {
  219. // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
  220. return wl.Listener.(filer).File()
  221. }
  222. type wrappedConn struct {
  223. net.Conn
  224. server *Server
  225. closed *int32
  226. deadline time.Time
  227. perWriteTimeout time.Duration
  228. perWritePerKbTimeout time.Duration
  229. }
  230. func (w *wrappedConn) Write(p []byte) (n int, err error) {
  231. if w.perWriteTimeout > 0 {
  232. minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout
  233. minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout)
  234. w.deadline = w.deadline.Add(minTimeout)
  235. if minDeadline.After(w.deadline) {
  236. w.deadline = minDeadline
  237. }
  238. _ = w.Conn.SetWriteDeadline(w.deadline)
  239. }
  240. return w.Conn.Write(p)
  241. }
  242. func (w *wrappedConn) Close() error {
  243. if atomic.CompareAndSwapInt32(w.closed, 0, 1) {
  244. defer func() {
  245. if err := recover(); err != nil {
  246. select {
  247. case <-GetManager().IsHammer():
  248. // Likely deadlocked request released at hammertime
  249. log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err)
  250. default:
  251. log.Error("Panic during connection close! %v", err)
  252. }
  253. }
  254. }()
  255. w.server.wg.Done()
  256. }
  257. return w.Conn.Close()
  258. }