You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

server.go 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. package ssh
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "sync"
  8. "time"
  9. gossh "golang.org/x/crypto/ssh"
  10. )
  11. // ErrServerClosed is returned by the Server's Serve, ListenAndServe,
  12. // and ListenAndServeTLS methods after a call to Shutdown or Close.
  13. var ErrServerClosed = errors.New("ssh: Server closed")
  14. type SubsystemHandler func(s Session)
  15. var DefaultSubsystemHandlers = map[string]SubsystemHandler{}
  16. type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
  17. var DefaultRequestHandlers = map[string]RequestHandler{}
  18. type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
  19. var DefaultChannelHandlers = map[string]ChannelHandler{
  20. "session": DefaultSessionHandler,
  21. }
  22. // Server defines parameters for running an SSH server. The zero value for
  23. // Server is a valid configuration. When both PasswordHandler and
  24. // PublicKeyHandler are nil, no client authentication is performed.
  25. type Server struct {
  26. Addr string // TCP address to listen on, ":22" if empty
  27. Handler Handler // handler to invoke, ssh.DefaultHandler if nil
  28. HostSigners []Signer // private keys for the host key, must have at least one
  29. Version string // server version to be sent before the initial handshake
  30. KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
  31. PasswordHandler PasswordHandler // password authentication handler
  32. PublicKeyHandler PublicKeyHandler // public key authentication handler
  33. PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
  34. ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
  35. LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
  36. ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
  37. ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
  38. SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions
  39. ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures
  40. IdleTimeout time.Duration // connection timeout when no activity, none if empty
  41. MaxTimeout time.Duration // absolute connection timeout, none if empty
  42. // ChannelHandlers allow overriding the built-in session handlers or provide
  43. // extensions to the protocol, such as tcpip forwarding. By default only the
  44. // "session" handler is enabled.
  45. ChannelHandlers map[string]ChannelHandler
  46. // RequestHandlers allow overriding the server-level request handlers or
  47. // provide extensions to the protocol, such as tcpip forwarding. By default
  48. // no handlers are enabled.
  49. RequestHandlers map[string]RequestHandler
  50. // SubsystemHandlers are handlers which are similar to the usual SSH command
  51. // handlers, but handle named subsystems.
  52. SubsystemHandlers map[string]SubsystemHandler
  53. listenerWg sync.WaitGroup
  54. mu sync.RWMutex
  55. listeners map[net.Listener]struct{}
  56. conns map[*gossh.ServerConn]struct{}
  57. connWg sync.WaitGroup
  58. doneChan chan struct{}
  59. }
  60. func (srv *Server) ensureHostSigner() error {
  61. srv.mu.Lock()
  62. defer srv.mu.Unlock()
  63. if len(srv.HostSigners) == 0 {
  64. signer, err := generateSigner()
  65. if err != nil {
  66. return err
  67. }
  68. srv.HostSigners = append(srv.HostSigners, signer)
  69. }
  70. return nil
  71. }
  72. func (srv *Server) ensureHandlers() {
  73. srv.mu.Lock()
  74. defer srv.mu.Unlock()
  75. if srv.RequestHandlers == nil {
  76. srv.RequestHandlers = map[string]RequestHandler{}
  77. for k, v := range DefaultRequestHandlers {
  78. srv.RequestHandlers[k] = v
  79. }
  80. }
  81. if srv.ChannelHandlers == nil {
  82. srv.ChannelHandlers = map[string]ChannelHandler{}
  83. for k, v := range DefaultChannelHandlers {
  84. srv.ChannelHandlers[k] = v
  85. }
  86. }
  87. if srv.SubsystemHandlers == nil {
  88. srv.SubsystemHandlers = map[string]SubsystemHandler{}
  89. for k, v := range DefaultSubsystemHandlers {
  90. srv.SubsystemHandlers[k] = v
  91. }
  92. }
  93. }
  94. func (srv *Server) config(ctx Context) *gossh.ServerConfig {
  95. srv.mu.RLock()
  96. defer srv.mu.RUnlock()
  97. var config *gossh.ServerConfig
  98. if srv.ServerConfigCallback == nil {
  99. config = &gossh.ServerConfig{}
  100. } else {
  101. config = srv.ServerConfigCallback(ctx)
  102. }
  103. for _, signer := range srv.HostSigners {
  104. config.AddHostKey(signer)
  105. }
  106. if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil {
  107. config.NoClientAuth = true
  108. }
  109. if srv.Version != "" {
  110. config.ServerVersion = "SSH-2.0-" + srv.Version
  111. }
  112. if srv.PasswordHandler != nil {
  113. config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
  114. applyConnMetadata(ctx, conn)
  115. if ok := srv.PasswordHandler(ctx, string(password)); !ok {
  116. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  117. }
  118. return ctx.Permissions().Permissions, nil
  119. }
  120. }
  121. if srv.PublicKeyHandler != nil {
  122. config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
  123. applyConnMetadata(ctx, conn)
  124. if ok := srv.PublicKeyHandler(ctx, key); !ok {
  125. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  126. }
  127. ctx.SetValue(ContextKeyPublicKey, key)
  128. return ctx.Permissions().Permissions, nil
  129. }
  130. }
  131. if srv.KeyboardInteractiveHandler != nil {
  132. config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
  133. applyConnMetadata(ctx, conn)
  134. if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok {
  135. return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
  136. }
  137. return ctx.Permissions().Permissions, nil
  138. }
  139. }
  140. return config
  141. }
  142. // Handle sets the Handler for the server.
  143. func (srv *Server) Handle(fn Handler) {
  144. srv.mu.Lock()
  145. defer srv.mu.Unlock()
  146. srv.Handler = fn
  147. }
  148. // Close immediately closes all active listeners and all active
  149. // connections.
  150. //
  151. // Close returns any error returned from closing the Server's
  152. // underlying Listener(s).
  153. func (srv *Server) Close() error {
  154. srv.mu.Lock()
  155. defer srv.mu.Unlock()
  156. srv.closeDoneChanLocked()
  157. err := srv.closeListenersLocked()
  158. for c := range srv.conns {
  159. c.Close()
  160. delete(srv.conns, c)
  161. }
  162. return err
  163. }
  164. // Shutdown gracefully shuts down the server without interrupting any
  165. // active connections. Shutdown works by first closing all open
  166. // listeners, and then waiting indefinitely for connections to close.
  167. // If the provided context expires before the shutdown is complete,
  168. // then the context's error is returned.
  169. func (srv *Server) Shutdown(ctx context.Context) error {
  170. srv.mu.Lock()
  171. lnerr := srv.closeListenersLocked()
  172. srv.closeDoneChanLocked()
  173. srv.mu.Unlock()
  174. finished := make(chan struct{}, 1)
  175. go func() {
  176. srv.listenerWg.Wait()
  177. srv.connWg.Wait()
  178. finished <- struct{}{}
  179. }()
  180. select {
  181. case <-ctx.Done():
  182. return ctx.Err()
  183. case <-finished:
  184. return lnerr
  185. }
  186. }
  187. // Serve accepts incoming connections on the Listener l, creating a new
  188. // connection goroutine for each. The connection goroutines read requests and then
  189. // calls srv.Handler to handle sessions.
  190. //
  191. // Serve always returns a non-nil error.
  192. func (srv *Server) Serve(l net.Listener) error {
  193. srv.ensureHandlers()
  194. defer l.Close()
  195. if err := srv.ensureHostSigner(); err != nil {
  196. return err
  197. }
  198. if srv.Handler == nil {
  199. srv.Handler = DefaultHandler
  200. }
  201. var tempDelay time.Duration
  202. srv.trackListener(l, true)
  203. defer srv.trackListener(l, false)
  204. for {
  205. conn, e := l.Accept()
  206. if e != nil {
  207. select {
  208. case <-srv.getDoneChan():
  209. return ErrServerClosed
  210. default:
  211. }
  212. if ne, ok := e.(net.Error); ok && ne.Temporary() {
  213. if tempDelay == 0 {
  214. tempDelay = 5 * time.Millisecond
  215. } else {
  216. tempDelay *= 2
  217. }
  218. if max := 1 * time.Second; tempDelay > max {
  219. tempDelay = max
  220. }
  221. time.Sleep(tempDelay)
  222. continue
  223. }
  224. return e
  225. }
  226. go srv.HandleConn(conn)
  227. }
  228. }
  229. func (srv *Server) HandleConn(newConn net.Conn) {
  230. ctx, cancel := newContext(srv)
  231. if srv.ConnCallback != nil {
  232. cbConn := srv.ConnCallback(ctx, newConn)
  233. if cbConn == nil {
  234. newConn.Close()
  235. return
  236. }
  237. newConn = cbConn
  238. }
  239. conn := &serverConn{
  240. Conn: newConn,
  241. idleTimeout: srv.IdleTimeout,
  242. closeCanceler: cancel,
  243. }
  244. if srv.MaxTimeout > 0 {
  245. conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
  246. }
  247. defer conn.Close()
  248. sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
  249. if err != nil {
  250. if srv.ConnectionFailedCallback != nil {
  251. srv.ConnectionFailedCallback(conn, err)
  252. }
  253. return
  254. }
  255. srv.trackConn(sshConn, true)
  256. defer srv.trackConn(sshConn, false)
  257. ctx.SetValue(ContextKeyConn, sshConn)
  258. applyConnMetadata(ctx, sshConn)
  259. //go gossh.DiscardRequests(reqs)
  260. go srv.handleRequests(ctx, reqs)
  261. for ch := range chans {
  262. handler := srv.ChannelHandlers[ch.ChannelType()]
  263. if handler == nil {
  264. handler = srv.ChannelHandlers["default"]
  265. }
  266. if handler == nil {
  267. ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
  268. continue
  269. }
  270. go handler(srv, sshConn, ch, ctx)
  271. }
  272. }
  273. func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
  274. for req := range in {
  275. handler := srv.RequestHandlers[req.Type]
  276. if handler == nil {
  277. handler = srv.RequestHandlers["default"]
  278. }
  279. if handler == nil {
  280. req.Reply(false, nil)
  281. continue
  282. }
  283. /*reqCtx, cancel := context.WithCancel(ctx)
  284. defer cancel() */
  285. ret, payload := handler(ctx, srv, req)
  286. req.Reply(ret, payload)
  287. }
  288. }
  289. // ListenAndServe listens on the TCP network address srv.Addr and then calls
  290. // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used.
  291. // ListenAndServe always returns a non-nil error.
  292. func (srv *Server) ListenAndServe() error {
  293. addr := srv.Addr
  294. if addr == "" {
  295. addr = ":22"
  296. }
  297. ln, err := net.Listen("tcp", addr)
  298. if err != nil {
  299. return err
  300. }
  301. return srv.Serve(ln)
  302. }
  303. // AddHostKey adds a private key as a host key. If an existing host key exists
  304. // with the same algorithm, it is overwritten. Each server config must have at
  305. // least one host key.
  306. func (srv *Server) AddHostKey(key Signer) {
  307. srv.mu.Lock()
  308. defer srv.mu.Unlock()
  309. // these are later added via AddHostKey on ServerConfig, which performs the
  310. // check for one of every algorithm.
  311. // This check is based on the AddHostKey method from the x/crypto/ssh
  312. // library. This allows us to only keep one active key for each type on a
  313. // server at once. So, if you're dynamically updating keys at runtime, this
  314. // list will not keep growing.
  315. for i, k := range srv.HostSigners {
  316. if k.PublicKey().Type() == key.PublicKey().Type() {
  317. srv.HostSigners[i] = key
  318. return
  319. }
  320. }
  321. srv.HostSigners = append(srv.HostSigners, key)
  322. }
  323. // SetOption runs a functional option against the server.
  324. func (srv *Server) SetOption(option Option) error {
  325. // NOTE: there is a potential race here for any option that doesn't call an
  326. // internal method. We can't actually lock here because if something calls
  327. // (as an example) AddHostKey, it will deadlock.
  328. //srv.mu.Lock()
  329. //defer srv.mu.Unlock()
  330. return option(srv)
  331. }
  332. func (srv *Server) getDoneChan() <-chan struct{} {
  333. srv.mu.Lock()
  334. defer srv.mu.Unlock()
  335. return srv.getDoneChanLocked()
  336. }
  337. func (srv *Server) getDoneChanLocked() chan struct{} {
  338. if srv.doneChan == nil {
  339. srv.doneChan = make(chan struct{})
  340. }
  341. return srv.doneChan
  342. }
  343. func (srv *Server) closeDoneChanLocked() {
  344. ch := srv.getDoneChanLocked()
  345. select {
  346. case <-ch:
  347. // Already closed. Don't close again.
  348. default:
  349. // Safe to close here. We're the only closer, guarded
  350. // by srv.mu.
  351. close(ch)
  352. }
  353. }
  354. func (srv *Server) closeListenersLocked() error {
  355. var err error
  356. for ln := range srv.listeners {
  357. if cerr := ln.Close(); cerr != nil && err == nil {
  358. err = cerr
  359. }
  360. delete(srv.listeners, ln)
  361. }
  362. return err
  363. }
  364. func (srv *Server) trackListener(ln net.Listener, add bool) {
  365. srv.mu.Lock()
  366. defer srv.mu.Unlock()
  367. if srv.listeners == nil {
  368. srv.listeners = make(map[net.Listener]struct{})
  369. }
  370. if add {
  371. // If the *Server is being reused after a previous
  372. // Close or Shutdown, reset its doneChan:
  373. if len(srv.listeners) == 0 && len(srv.conns) == 0 {
  374. srv.doneChan = nil
  375. }
  376. srv.listeners[ln] = struct{}{}
  377. srv.listenerWg.Add(1)
  378. } else {
  379. delete(srv.listeners, ln)
  380. srv.listenerWg.Done()
  381. }
  382. }
  383. func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
  384. srv.mu.Lock()
  385. defer srv.mu.Unlock()
  386. if srv.conns == nil {
  387. srv.conns = make(map[*gossh.ServerConn]struct{})
  388. }
  389. if add {
  390. srv.conns[c] = struct{}{}
  391. srv.connWg.Add(1)
  392. } else {
  393. delete(srv.conns, c)
  394. srv.connWg.Done()
  395. }
  396. }