Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. // This code is highly inspired by endless go
  5. package graceful
  6. import (
  7. "crypto/tls"
  8. "net"
  9. "os"
  10. "strings"
  11. "sync"
  12. "syscall"
  13. "time"
  14. "code.gitea.io/gitea/modules/log"
  15. )
  16. type state uint8
  17. const (
  18. stateInit state = iota
  19. stateRunning
  20. stateShuttingDown
  21. stateTerminate
  22. )
  23. var (
  24. // RWMutex for when adding servers or shutting down
  25. runningServerReg sync.RWMutex
  26. // ensure we only fork once
  27. runningServersForked bool
  28. // DefaultReadTimeOut default read timeout
  29. DefaultReadTimeOut time.Duration
  30. // DefaultWriteTimeOut default write timeout
  31. DefaultWriteTimeOut time.Duration
  32. // DefaultMaxHeaderBytes default max header bytes
  33. DefaultMaxHeaderBytes int
  34. // IsChild reports if we are a fork iff LISTEN_FDS is set and our parent PID is not 1
  35. IsChild = len(os.Getenv(listenFDs)) > 0 && os.Getppid() > 1
  36. )
  37. func init() {
  38. runningServerReg = sync.RWMutex{}
  39. DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
  40. }
  41. // ServeFunction represents a listen.Accept loop
  42. type ServeFunction = func(net.Listener) error
  43. // Server represents our graceful server
  44. type Server struct {
  45. network string
  46. address string
  47. listener net.Listener
  48. PreSignalHooks map[os.Signal][]func()
  49. PostSignalHooks map[os.Signal][]func()
  50. wg sync.WaitGroup
  51. sigChan chan os.Signal
  52. state state
  53. lock *sync.RWMutex
  54. BeforeBegin func(network, address string)
  55. OnShutdown func()
  56. }
  57. // NewServer creates a server on network at provided address
  58. func NewServer(network, address string) *Server {
  59. runningServerReg.Lock()
  60. defer runningServerReg.Unlock()
  61. if IsChild {
  62. log.Info("Restarting new server: %s:%s on PID: %d", network, address, os.Getpid())
  63. } else {
  64. log.Info("Starting new server: %s:%s on PID: %d", network, address, os.Getpid())
  65. }
  66. srv := &Server{
  67. wg: sync.WaitGroup{},
  68. sigChan: make(chan os.Signal),
  69. PreSignalHooks: map[os.Signal][]func(){},
  70. PostSignalHooks: map[os.Signal][]func(){},
  71. state: stateInit,
  72. lock: &sync.RWMutex{},
  73. network: network,
  74. address: address,
  75. }
  76. srv.BeforeBegin = func(network, addr string) {
  77. log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
  78. }
  79. return srv
  80. }
  81. // ListenAndServe listens on the provided network address and then calls Serve
  82. // to handle requests on incoming connections.
  83. func (srv *Server) ListenAndServe(serve ServeFunction) error {
  84. go srv.handleSignals()
  85. l, err := GetListener(srv.network, srv.address)
  86. if err != nil {
  87. log.Error("Unable to GetListener: %v", err)
  88. return err
  89. }
  90. srv.listener = newWrappedListener(l, srv)
  91. if IsChild {
  92. _ = syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  93. }
  94. srv.BeforeBegin(srv.network, srv.address)
  95. return srv.Serve(serve)
  96. }
  97. // ListenAndServeTLS listens on the provided network address and then calls
  98. // Serve to handle requests on incoming TLS connections.
  99. //
  100. // Filenames containing a certificate and matching private key for the server must
  101. // be provided. If the certificate is signed by a certificate authority, the
  102. // certFile should be the concatenation of the server's certificate followed by the
  103. // CA's certificate.
  104. func (srv *Server) ListenAndServeTLS(certFile, keyFile string, serve ServeFunction) error {
  105. config := &tls.Config{}
  106. if config.NextProtos == nil {
  107. config.NextProtos = []string{"http/1.1"}
  108. }
  109. config.Certificates = make([]tls.Certificate, 1)
  110. var err error
  111. config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  112. if err != nil {
  113. log.Error("Failed to load https cert file %s for %s:%s: %v", certFile, srv.network, srv.address, err)
  114. return err
  115. }
  116. return srv.ListenAndServeTLSConfig(config, serve)
  117. }
  118. // ListenAndServeTLSConfig listens on the provided network address and then calls
  119. // Serve to handle requests on incoming TLS connections.
  120. func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error {
  121. go srv.handleSignals()
  122. l, err := GetListener(srv.network, srv.address)
  123. if err != nil {
  124. log.Error("Unable to get Listener: %v", err)
  125. return err
  126. }
  127. wl := newWrappedListener(l, srv)
  128. srv.listener = tls.NewListener(wl, tlsConfig)
  129. if IsChild {
  130. _ = syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  131. }
  132. srv.BeforeBegin(srv.network, srv.address)
  133. return srv.Serve(serve)
  134. }
  135. // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
  136. // service goroutine for each. The service goroutines read requests and then call
  137. // handler to reply to them. Handler is typically nil, in which case the
  138. // DefaultServeMux is used.
  139. //
  140. // In addition to the standard Serve behaviour each connection is added to a
  141. // sync.Waitgroup so that all outstanding connections can be served before shutting
  142. // down the server.
  143. func (srv *Server) Serve(serve ServeFunction) error {
  144. defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
  145. srv.setState(stateRunning)
  146. err := serve(srv.listener)
  147. log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
  148. srv.wg.Wait()
  149. srv.setState(stateTerminate)
  150. // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
  151. if err != nil && strings.Contains(err.Error(), "use of closed") {
  152. return nil
  153. }
  154. return err
  155. }
  156. func (srv *Server) getState() state {
  157. srv.lock.RLock()
  158. defer srv.lock.RUnlock()
  159. return srv.state
  160. }
  161. func (srv *Server) setState(st state) {
  162. srv.lock.Lock()
  163. defer srv.lock.Unlock()
  164. srv.state = st
  165. }
  166. type wrappedListener struct {
  167. net.Listener
  168. stopped bool
  169. server *Server
  170. }
  171. func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
  172. return &wrappedListener{
  173. Listener: l,
  174. server: srv,
  175. }
  176. }
  177. func (wl *wrappedListener) Accept() (net.Conn, error) {
  178. var c net.Conn
  179. // Set keepalive on TCPListeners connections.
  180. if tcl, ok := wl.Listener.(*net.TCPListener); ok {
  181. tc, err := tcl.AcceptTCP()
  182. if err != nil {
  183. return nil, err
  184. }
  185. _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
  186. _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
  187. c = tc
  188. } else {
  189. var err error
  190. c, err = wl.Listener.Accept()
  191. if err != nil {
  192. return nil, err
  193. }
  194. }
  195. c = wrappedConn{
  196. Conn: c,
  197. server: wl.server,
  198. }
  199. wl.server.wg.Add(1)
  200. return c, nil
  201. }
  202. func (wl *wrappedListener) Close() error {
  203. if wl.stopped {
  204. return syscall.EINVAL
  205. }
  206. wl.stopped = true
  207. return wl.Listener.Close()
  208. }
  209. func (wl *wrappedListener) File() (*os.File, error) {
  210. // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
  211. return wl.Listener.(filer).File()
  212. }
  213. type wrappedConn struct {
  214. net.Conn
  215. server *Server
  216. }
  217. func (w wrappedConn) Close() error {
  218. err := w.Conn.Close()
  219. if err == nil {
  220. w.server.wg.Done()
  221. }
  222. return err
  223. }