You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

net.go 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. // This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
  5. package graceful
  6. import (
  7. "fmt"
  8. "net"
  9. "os"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "code.gitea.io/gitea/modules/log"
  14. )
  15. const (
  16. listenFDs = "LISTEN_FDS"
  17. startFD = 3
  18. )
  19. // In order to keep the working directory the same as when we started we record
  20. // it at startup.
  21. var originalWD, _ = os.Getwd()
  22. var (
  23. once = sync.Once{}
  24. mutex = sync.Mutex{}
  25. providedListeners = []net.Listener{}
  26. activeListeners = []net.Listener{}
  27. )
  28. func getProvidedFDs() (savedErr error) {
  29. // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
  30. once.Do(func() {
  31. mutex.Lock()
  32. defer mutex.Unlock()
  33. numFDs := os.Getenv(listenFDs)
  34. if numFDs == "" {
  35. return
  36. }
  37. n, err := strconv.Atoi(numFDs)
  38. if err != nil {
  39. savedErr = fmt.Errorf("%s is not a number: %s. Err: %v", listenFDs, numFDs, err)
  40. return
  41. }
  42. for i := startFD; i < n+startFD; i++ {
  43. file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
  44. l, err := net.FileListener(file)
  45. if err == nil {
  46. // Close the inherited file if it's a listener
  47. if err = file.Close(); err != nil {
  48. savedErr = fmt.Errorf("error closing provided socket fd %d: %s", i, err)
  49. return
  50. }
  51. providedListeners = append(providedListeners, l)
  52. continue
  53. }
  54. // If needed we can handle packetconns here.
  55. savedErr = fmt.Errorf("Error getting provided socket fd %d: %v", i, err)
  56. return
  57. }
  58. })
  59. return savedErr
  60. }
  61. // CloseProvidedListeners closes all unused provided listeners.
  62. func CloseProvidedListeners() error {
  63. mutex.Lock()
  64. defer mutex.Unlock()
  65. var returnableError error
  66. for _, l := range providedListeners {
  67. err := l.Close()
  68. if err != nil {
  69. log.Error("Error in closing unused provided listener: %v", err)
  70. if returnableError != nil {
  71. returnableError = fmt.Errorf("%v & %v", returnableError, err)
  72. } else {
  73. returnableError = err
  74. }
  75. }
  76. }
  77. providedListeners = []net.Listener{}
  78. return returnableError
  79. }
  80. // GetListener obtains a listener for the local network address. The network must be
  81. // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
  82. // returns an provided net.Listener for the matching network and address, or
  83. // creates a new one using net.Listen.
  84. func GetListener(network, address string) (net.Listener, error) {
  85. // Add a deferral to say that we've tried to grab a listener
  86. defer InformCleanup()
  87. switch network {
  88. case "tcp", "tcp4", "tcp6":
  89. tcpAddr, err := net.ResolveTCPAddr(network, address)
  90. if err != nil {
  91. return nil, err
  92. }
  93. return GetListenerTCP(network, tcpAddr)
  94. case "unix", "unixpacket":
  95. unixAddr, err := net.ResolveUnixAddr(network, address)
  96. if err != nil {
  97. return nil, err
  98. }
  99. return GetListenerUnix(network, unixAddr)
  100. default:
  101. return nil, net.UnknownNetworkError(network)
  102. }
  103. }
  104. // GetListenerTCP announces on the local network address. The network must be:
  105. // "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
  106. // matching network and address, or creates a new one using net.ListenTCP.
  107. func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
  108. if err := getProvidedFDs(); err != nil {
  109. return nil, err
  110. }
  111. mutex.Lock()
  112. defer mutex.Unlock()
  113. // look for a provided listener
  114. for i, l := range providedListeners {
  115. if isSameAddr(l.Addr(), address) {
  116. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  117. activeListeners = append(activeListeners, l)
  118. return l.(*net.TCPListener), nil
  119. }
  120. }
  121. // no provided listener for this address -> make a fresh listener
  122. l, err := net.ListenTCP(network, address)
  123. if err != nil {
  124. return nil, err
  125. }
  126. activeListeners = append(activeListeners, l)
  127. return l, nil
  128. }
  129. // GetListenerUnix announces on the local network address. The network must be:
  130. // "unix" or "unixpacket". It returns a provided net.Listener for the
  131. // matching network and address, or creates a new one using net.ListenUnix.
  132. func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
  133. if err := getProvidedFDs(); err != nil {
  134. return nil, err
  135. }
  136. mutex.Lock()
  137. defer mutex.Unlock()
  138. // look for a provided listener
  139. for i, l := range providedListeners {
  140. if isSameAddr(l.Addr(), address) {
  141. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  142. activeListeners = append(activeListeners, l)
  143. return l.(*net.UnixListener), nil
  144. }
  145. }
  146. // make a fresh listener
  147. l, err := net.ListenUnix(network, address)
  148. if err != nil {
  149. return nil, err
  150. }
  151. activeListeners = append(activeListeners, l)
  152. return l, nil
  153. }
  154. func isSameAddr(a1, a2 net.Addr) bool {
  155. // If the addresses are not on the same network fail.
  156. if a1.Network() != a2.Network() {
  157. return false
  158. }
  159. // If the two addresses have the same string representation they're equal
  160. a1s := a1.String()
  161. a2s := a2.String()
  162. if a1s == a2s {
  163. return true
  164. }
  165. // This allows for ipv6 vs ipv4 local addresses to compare as equal. This
  166. // scenario is common when listening on localhost.
  167. const ipv6prefix = "[::]"
  168. a1s = strings.TrimPrefix(a1s, ipv6prefix)
  169. a2s = strings.TrimPrefix(a2s, ipv6prefix)
  170. const ipv4prefix = "0.0.0.0"
  171. a1s = strings.TrimPrefix(a1s, ipv4prefix)
  172. a2s = strings.TrimPrefix(a2s, ipv4prefix)
  173. return a1s == a2s
  174. }
  175. func getActiveListeners() []net.Listener {
  176. mutex.Lock()
  177. defer mutex.Unlock()
  178. listeners := make([]net.Listener, len(activeListeners))
  179. copy(listeners, activeListeners)
  180. return listeners
  181. }