You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

net_unix.go 9.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // SPDX-License-Identifier: MIT
  3. // This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
  4. //go:build !windows
  5. package graceful
  6. import (
  7. "fmt"
  8. "net"
  9. "os"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "time"
  14. "code.gitea.io/gitea/modules/log"
  15. "code.gitea.io/gitea/modules/setting"
  16. "code.gitea.io/gitea/modules/util"
  17. )
  // Environment variables read by getProvidedFDs, plus the descriptor number
  // at which inherited listeners begin.
  const (
  	listenFDsEnv       = "LISTEN_FDS"
  	unlinkFDsEnv       = "GITEA_UNLINK_FDS"
  	notifySocketEnv    = "NOTIFY_SOCKET"
  	watchdogTimeoutEnv = "WATCHDOG_USEC"

  	startFD = 3
  )

  // originalWD is the working directory captured at startup so that it can be
  // kept the same later on. An os.Getwd error is deliberately discarded here.
  var originalWD, _ = os.Getwd()

  var (
  	// once guards the one-time inheritance done by getProvidedFDs; mutex guards
  	// the listener bookkeeping below.
  	once  sync.Once
  	mutex sync.Mutex

  	// Inherited listeners that have not been claimed yet, together with a
  	// parallel slice of "unlink on close" flags.
  	providedListeners         = []net.Listener{}
  	providedListenersToUnlink = []bool{}

  	// Listeners handed out by GetListenerTCP/GetListenerUnix, together with a
  	// parallel slice of "unlink on close" flags.
  	activeListeners         = []net.Listener{}
  	activeListenersToUnlink = []bool{}

  	// Systemd notification settings captured by getProvidedFDs.
  	notifySocketAddr string
  	watchdogTimeout  time.Duration
  )
  38. func getProvidedFDs() (savedErr error) {
  39. // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
  40. once.Do(func() {
  41. mutex.Lock()
  42. defer mutex.Unlock()
  43. // now handle some additional systemd provided things
  44. notifySocketAddr = os.Getenv(notifySocketEnv)
  45. if notifySocketAddr != "" {
  46. log.Debug("Systemd Notify Socket provided: %s", notifySocketAddr)
  47. savedErr = os.Unsetenv(notifySocketEnv)
  48. if savedErr != nil {
  49. log.Warn("Unable to Unset the NOTIFY_SOCKET environment variable: %v", savedErr)
  50. return
  51. }
  52. // FIXME: We don't handle WATCHDOG_PID
  53. timeoutStr := os.Getenv(watchdogTimeoutEnv)
  54. if timeoutStr != "" {
  55. savedErr = os.Unsetenv(watchdogTimeoutEnv)
  56. if savedErr != nil {
  57. log.Warn("Unable to Unset the WATCHDOG_USEC environment variable: %v", savedErr)
  58. return
  59. }
  60. s, err := strconv.ParseInt(timeoutStr, 10, 64)
  61. if err != nil {
  62. log.Error("Unable to parse the provided WATCHDOG_USEC: %v", err)
  63. savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %w", err)
  64. return
  65. }
  66. if s <= 0 {
  67. log.Error("Unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
  68. savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
  69. return
  70. }
  71. watchdogTimeout = time.Duration(s) * time.Microsecond
  72. }
  73. } else {
  74. log.Trace("No Systemd Notify Socket provided")
  75. }
  76. numFDs := os.Getenv(listenFDsEnv)
  77. if numFDs == "" {
  78. return
  79. }
  80. n, err := strconv.Atoi(numFDs)
  81. if err != nil {
  82. savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDsEnv, numFDs, err)
  83. return
  84. }
  85. fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDsEnv), ",")
  86. providedListenersToUnlink = make([]bool, n)
  87. for _, fdStr := range fdsToUnlinkStr {
  88. i, err := strconv.Atoi(fdStr)
  89. if err != nil || i < 0 || i >= n {
  90. continue
  91. }
  92. providedListenersToUnlink[i] = true
  93. }
  94. for i := startFD; i < n+startFD; i++ {
  95. file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
  96. l, err := net.FileListener(file)
  97. if err == nil {
  98. // Close the inherited file if it's a listener
  99. if err = file.Close(); err != nil {
  100. savedErr = fmt.Errorf("error closing provided socket fd %d: %w", i, err)
  101. return
  102. }
  103. providedListeners = append(providedListeners, l)
  104. continue
  105. }
  106. // If needed we can handle packetconns here.
  107. savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err)
  108. return
  109. }
  110. })
  111. return savedErr
  112. }
  113. // CloseProvidedListeners closes all unused provided listeners.
  114. func CloseProvidedListeners() error {
  115. mutex.Lock()
  116. defer mutex.Unlock()
  117. var returnableError error
  118. for _, l := range providedListeners {
  119. err := l.Close()
  120. if err != nil {
  121. log.Error("Error in closing unused provided listener: %v", err)
  122. if returnableError != nil {
  123. returnableError = fmt.Errorf("%v & %w", returnableError, err)
  124. } else {
  125. returnableError = err
  126. }
  127. }
  128. }
  129. providedListeners = []net.Listener{}
  130. return returnableError
  131. }
  132. // DefaultGetListener obtains a listener for the local network address. The network must be
  133. // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
  134. // returns an provided net.Listener for the matching network and address, or
  135. // creates a new one using net.Listen. This function can be replaced by changing the
  136. // GetListener variable at the top of this file, for example to listen on an onion service using
  137. // github.com/cretz/bine
  138. func DefaultGetListener(network, address string) (net.Listener, error) {
  139. // Add a deferral to say that we've tried to grab a listener
  140. defer GetManager().InformCleanup()
  141. switch network {
  142. case "tcp", "tcp4", "tcp6":
  143. tcpAddr, err := net.ResolveTCPAddr(network, address)
  144. if err != nil {
  145. return nil, err
  146. }
  147. return GetListenerTCP(network, tcpAddr)
  148. case "unix", "unixpacket":
  149. unixAddr, err := net.ResolveUnixAddr(network, address)
  150. if err != nil {
  151. return nil, err
  152. }
  153. return GetListenerUnix(network, unixAddr)
  154. default:
  155. return nil, net.UnknownNetworkError(network)
  156. }
  157. }
  158. // GetListenerTCP announces on the local network address. The network must be:
  159. // "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
  160. // matching network and address, or creates a new one using net.ListenTCP.
  161. func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
  162. if err := getProvidedFDs(); err != nil {
  163. return nil, err
  164. }
  165. mutex.Lock()
  166. defer mutex.Unlock()
  167. // look for a provided listener
  168. for i, l := range providedListeners {
  169. if isSameAddr(l.Addr(), address) {
  170. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  171. needsUnlink := providedListenersToUnlink[i]
  172. providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
  173. activeListeners = append(activeListeners, l)
  174. activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
  175. return l.(*net.TCPListener), nil
  176. }
  177. }
  178. // no provided listener for this address -> make a fresh listener
  179. l, err := net.ListenTCP(network, address)
  180. if err != nil {
  181. return nil, err
  182. }
  183. activeListeners = append(activeListeners, l)
  184. activeListenersToUnlink = append(activeListenersToUnlink, false)
  185. return l, nil
  186. }
  187. // GetListenerUnix announces on the local network address. The network must be:
  188. // "unix" or "unixpacket". It returns a provided net.Listener for the
  189. // matching network and address, or creates a new one using net.ListenUnix.
  190. func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
  191. if err := getProvidedFDs(); err != nil {
  192. return nil, err
  193. }
  194. mutex.Lock()
  195. defer mutex.Unlock()
  196. // look for a provided listener
  197. for i, l := range providedListeners {
  198. if isSameAddr(l.Addr(), address) {
  199. providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
  200. needsUnlink := providedListenersToUnlink[i]
  201. providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
  202. activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
  203. activeListeners = append(activeListeners, l)
  204. unixListener := l.(*net.UnixListener)
  205. if needsUnlink {
  206. unixListener.SetUnlinkOnClose(true)
  207. }
  208. return unixListener, nil
  209. }
  210. }
  211. // make a fresh listener
  212. if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
  213. return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err)
  214. }
  215. l, err := net.ListenUnix(network, address)
  216. if err != nil {
  217. return nil, err
  218. }
  219. fileMode := os.FileMode(setting.UnixSocketPermission)
  220. if err = os.Chmod(address.Name, fileMode); err != nil {
  221. return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err)
  222. }
  223. activeListeners = append(activeListeners, l)
  224. activeListenersToUnlink = append(activeListenersToUnlink, true)
  225. return l, nil
  226. }
  // isSameAddr reports whether two addresses refer to the same endpoint. They
  // must share a network. Beyond an exact string match, a leading "[::]" and
  // then a leading "0.0.0.0" are stripped from both strings. This lets an IPv6
  // wildcard, an IPv4 wildcard and a host-less ":port" compare equal, which is
  // common when listening on localhost.
  func isSameAddr(a1, a2 net.Addr) bool {
  	if a1.Network() != a2.Network() {
  		return false
  	}

  	first, second := a1.String(), a2.String()
  	if first == second {
  		return true
  	}

  	normalize := func(addr string) string {
  		withoutV6 := strings.TrimPrefix(addr, "[::]")
  		return strings.TrimPrefix(withoutV6, "0.0.0.0")
  	}
  	return normalize(first) == normalize(second)
  }
  248. func getActiveListeners() []net.Listener {
  249. mutex.Lock()
  250. defer mutex.Unlock()
  251. listeners := make([]net.Listener, len(activeListeners))
  252. copy(listeners, activeListeners)
  253. return listeners
  254. }
  255. func getActiveListenersToUnlink() []bool {
  256. mutex.Lock()
  257. defer mutex.Unlock()
  258. listenersToUnlink := make([]bool, len(activeListenersToUnlink))
  259. copy(listenersToUnlink, activeListenersToUnlink)
  260. return listenersToUnlink
  261. }
  262. func getNotifySocket() (*net.UnixConn, error) {
  263. if err := getProvidedFDs(); err != nil {
  264. // This error will be logged elsewhere
  265. return nil, nil
  266. }
  267. if notifySocketAddr == "" {
  268. return nil, nil
  269. }
  270. socketAddr := &net.UnixAddr{
  271. Name: notifySocketAddr,
  272. Net: "unixgram",
  273. }
  274. notifySocket, err := net.DialUnix(socketAddr.Net, nil, socketAddr)
  275. if err != nil {
  276. log.Warn("failed to dial NOTIFY_SOCKET %s: %v", socketAddr, err)
  277. return nil, err
  278. }
  279. return notifySocket, nil
  280. }
  281. func getWatchdogTimeout() time.Duration {
  282. if err := getProvidedFDs(); err != nil {
  283. // This error will be logged elsewhere
  284. return 0
  285. }
  286. return watchdogTimeout
  287. }