You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

agent.go 2.0KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package ssh
import (
	"io"
	"io/ioutil"
	"net"
	"os"
	"path"
	"sync"

	gossh "golang.org/x/crypto/ssh"
)
  10. const (
  11. agentRequestType = "auth-agent-req@openssh.com"
  12. agentChannelType = "auth-agent@openssh.com"
  13. agentTempDir = "auth-agent"
  14. agentListenFile = "listener.sock"
  15. )
// contextKeyAgentRequest is the internal context key under which
// SetAgentRequested stores the value true. AgentRequested looks up the same
// key to report whether the client requested agent forwarding.
var contextKeyAgentRequest = &contextKey{"auth-agent-req"}
  19. // SetAgentRequested sets up the session context so that AgentRequested
  20. // returns true.
  21. func SetAgentRequested(ctx Context) {
  22. ctx.SetValue(contextKeyAgentRequest, true)
  23. }
  24. // AgentRequested returns true if the client requested agent forwarding.
  25. func AgentRequested(sess Session) bool {
  26. return sess.Context().Value(contextKeyAgentRequest) == true
  27. }
  28. // NewAgentListener sets up a temporary Unix socket that can be communicated
  29. // to the session environment and used for forwarding connections.
  30. func NewAgentListener() (net.Listener, error) {
  31. dir, err := ioutil.TempDir("", agentTempDir)
  32. if err != nil {
  33. return nil, err
  34. }
  35. l, err := net.Listen("unix", path.Join(dir, agentListenFile))
  36. if err != nil {
  37. return nil, err
  38. }
  39. return l, nil
  40. }
  41. // ForwardAgentConnections takes connections from a listener to proxy into the
  42. // session on the OpenSSH channel for agent connections. It blocks and services
  43. // connections until the listener stop accepting.
  44. func ForwardAgentConnections(l net.Listener, s Session) {
  45. sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn)
  46. for {
  47. conn, err := l.Accept()
  48. if err != nil {
  49. return
  50. }
  51. go func(conn net.Conn) {
  52. defer conn.Close()
  53. channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil)
  54. if err != nil {
  55. return
  56. }
  57. defer channel.Close()
  58. go gossh.DiscardRequests(reqs)
  59. var wg sync.WaitGroup
  60. wg.Add(2)
  61. go func() {
  62. io.Copy(conn, channel)
  63. conn.(*net.UnixConn).CloseWrite()
  64. wg.Done()
  65. }()
  66. go func() {
  67. io.Copy(channel, conn)
  68. channel.CloseWrite()
  69. wg.Done()
  70. }()
  71. wg.Wait()
  72. }(conn)
  73. }
  74. }