You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

net_unix.go 6.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. // This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
  5. //go:build !windows
  6. // +build !windows
  7. package graceful
  8. import (
  9. "fmt"
  10. "net"
  11. "os"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "code.gitea.io/gitea/modules/log"
  16. "code.gitea.io/gitea/modules/setting"
  17. "code.gitea.io/gitea/modules/util"
  18. )
  19. const (
  20. listenFDs = "LISTEN_FDS"
  21. startFD = 3
  22. )
  23. // In order to keep the working directory the same as when we started we record
  24. // it at startup.
  25. var originalWD, _ = os.Getwd()
  26. var (
  27. once = sync.Once{}
  28. mutex = sync.Mutex{}
  29. providedListeners = []net.Listener{}
  30. activeListeners = []net.Listener{}
  31. )
  32. func getProvidedFDs() (savedErr error) {
  33. // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
  34. once.Do(func() {
  35. mutex.Lock()
  36. defer mutex.Unlock()
  37. numFDs := os.Getenv(listenFDs)
  38. if numFDs == "" {
  39. return
  40. }
  41. n, err := strconv.Atoi(numFDs)
  42. if err != nil {
  43. savedErr = fmt.Errorf("%s is not a number: %s. Err: %v", listenFDs, numFDs, err)
  44. return
  45. }
  46. for i := startFD; i < n+startFD; i++ {
  47. file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
  48. l, err := net.FileListener(file)
  49. if err == nil {
  50. // Close the inherited file if it's a listener
  51. if err = file.Close(); err != nil {
  52. savedErr = fmt.Errorf("error closing provided socket fd %d: %s", i, err)
  53. return
  54. }
  55. providedListeners = append(providedListeners, l)
  56. continue
  57. }
  58. // If needed we can handle packetconns here.
  59. savedErr = fmt.Errorf("Error getting provided socket fd %d: %v", i, err)
  60. return
  61. }
  62. })
  63. return savedErr
  64. }
  65. // CloseProvidedListeners closes all unused provided listeners.
  66. func CloseProvidedListeners() error {
  67. mutex.Lock()
  68. defer mutex.Unlock()
  69. var returnableError error
  70. for _, l := range providedListeners {
  71. err := l.Close()
  72. if err != nil {
  73. log.Error("Error in closing unused provided listener: %v", err)
  74. if returnableError != nil {
  75. returnableError = fmt.Errorf("%v & %v", returnableError, err)
  76. } else {
  77. returnableError = err
  78. }
  79. }
  80. }
  81. providedListeners = []net.Listener{}
  82. return returnableError
  83. }
  84. // GetListener obtains a listener for the local network address. The network must be
  85. // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
  86. // returns an provided net.Listener for the matching network and address, or
  87. // creates a new one using net.Listen.
  88. func GetListener(network, address string) (net.Listener, error) {
  89. // Add a deferral to say that we've tried to grab a listener
  90. defer GetManager().InformCleanup()
  91. switch network {
  92. case "tcp", "tcp4", "tcp6":
  93. tcpAddr, err := net.ResolveTCPAddr(network, address)
  94. if err != nil {
  95. return nil, err
  96. }
  97. return GetListenerTCP(network, tcpAddr)
  98. case "unix", "unixpacket":
  99. unixAddr, err := net.ResolveUnixAddr(network, address)
  100. if err != nil {
  101. return nil, err
  102. }
  103. return GetListenerUnix(network, unixAddr)
  104. default:
  105. return nil, net.UnknownNetworkError(network)
  106. }
  107. }
  108. // GetListenerTCP announces on the local network address. The network must be:
  109. // "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
  110. // matching network and address, or creates a new one using net.ListenTCP.
  111. func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
  112. if err := getProvidedFDs(); err != nil {
  113. return nil, err
  114. }
  115. mutex.Lock()
  116. defer mutex.Unlock()
  117. // look for a provided listener
  118. for i, l := range providedListeners {
  119. if isSameAddr(l.Addr(), address) {
  120. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  121. activeListeners = append(activeListeners, l)
  122. return l.(*net.TCPListener), nil
  123. }
  124. }
  125. // no provided listener for this address -> make a fresh listener
  126. l, err := net.ListenTCP(network, address)
  127. if err != nil {
  128. return nil, err
  129. }
  130. activeListeners = append(activeListeners, l)
  131. return l, nil
  132. }
  133. // GetListenerUnix announces on the local network address. The network must be:
  134. // "unix" or "unixpacket". It returns a provided net.Listener for the
  135. // matching network and address, or creates a new one using net.ListenUnix.
  136. func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
  137. if err := getProvidedFDs(); err != nil {
  138. return nil, err
  139. }
  140. mutex.Lock()
  141. defer mutex.Unlock()
  142. // look for a provided listener
  143. for i, l := range providedListeners {
  144. if isSameAddr(l.Addr(), address) {
  145. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  146. activeListeners = append(activeListeners, l)
  147. unixListener := l.(*net.UnixListener)
  148. unixListener.SetUnlinkOnClose(true)
  149. return unixListener, nil
  150. }
  151. }
  152. // make a fresh listener
  153. if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
  154. return nil, fmt.Errorf("Failed to remove unix socket %s: %v", address.Name, err)
  155. }
  156. l, err := net.ListenUnix(network, address)
  157. if err != nil {
  158. return nil, err
  159. }
  160. fileMode := os.FileMode(setting.UnixSocketPermission)
  161. if err = os.Chmod(address.Name, fileMode); err != nil {
  162. return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %v", fileMode.String(), err)
  163. }
  164. activeListeners = append(activeListeners, l)
  165. return l, nil
  166. }
  167. func isSameAddr(a1, a2 net.Addr) bool {
  168. // If the addresses are not on the same network fail.
  169. if a1.Network() != a2.Network() {
  170. return false
  171. }
  172. // If the two addresses have the same string representation they're equal
  173. a1s := a1.String()
  174. a2s := a2.String()
  175. if a1s == a2s {
  176. return true
  177. }
  178. // This allows for ipv6 vs ipv4 local addresses to compare as equal. This
  179. // scenario is common when listening on localhost.
  180. const ipv6prefix = "[::]"
  181. a1s = strings.TrimPrefix(a1s, ipv6prefix)
  182. a2s = strings.TrimPrefix(a2s, ipv6prefix)
  183. const ipv4prefix = "0.0.0.0"
  184. a1s = strings.TrimPrefix(a1s, ipv4prefix)
  185. a2s = strings.TrimPrefix(a2s, ipv4prefix)
  186. return a1s == a2s
  187. }
  188. func getActiveListeners() []net.Listener {
  189. mutex.Lock()
  190. defer mutex.Unlock()
  191. listeners := make([]net.Listener, len(activeListeners))
  192. copy(listeners, activeListeners)
  193. return listeners
  194. }