diff options
Diffstat (limited to 'modules/ssh/ssh.go')
-rw-r--r-- | modules/ssh/ssh.go | 291 |
1 files changed, 135 insertions, 156 deletions
diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go index c5251ef23a..1818f33306 100644 --- a/modules/ssh/ssh.go +++ b/modules/ssh/ssh.go @@ -1,4 +1,3 @@ -// Copyright 2014 The Gogs Authors. All rights reserved. // Copyright 2017 The Gitea Authors. All rights reserved. // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. @@ -10,178 +9,157 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "fmt" "io" - "io/ioutil" - "net" "os" "os/exec" "path/filepath" "strings" - - "github.com/Unknwon/com" - "golang.org/x/crypto/ssh" + "sync" + "syscall" "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + + "github.com/Unknwon/com" + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" ) -func cleanCommand(cmd string) string { - i := strings.Index(cmd, "git") - if i == -1 { - return cmd +type contextKey string + +const giteaKeyID = contextKey("gitea-key-id") + +func getExitStatusFromError(err error) int { + if err == nil { + return 0 } - return cmd[i:] -} -func handleServerConn(keyID string, chans <-chan ssh.NewChannel) { - for newChan := range chans { - if newChan.ChannelType() != "session" { - err := newChan.Reject(ssh.UnknownChannelType, "unknown channel type") - if err != nil { - log.Error("Error rejecting channel: %v", err) - } - continue - } + exitErr, ok := err.(*exec.ExitError) + if !ok { + return 1 + } - ch, reqs, err := newChan.Accept() - if err != nil { - log.Error("Error accepting channel: %v", err) - continue + waitStatus, ok := exitErr.Sys().(syscall.WaitStatus) + if !ok { + // This is a fallback and should at least let us return something useful + // when running on Windows, even if it isn't completely accurate. + if exitErr.Success() { + return 0 } - go func(in <-chan *ssh.Request) { - defer func() { - if err = ch.Close(); err != nil { - log.Error("Close: %v", err) - } - }() - for req := range in { - payload := cleanCommand(string(req.Payload)) - switch req.Type { - case "exec": - cmdName := strings.TrimLeft(payload, "'()") - log.Trace("SSH: Payload: %v", cmdName) - - args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf} - log.Trace("SSH: Arguments: %v", args) - cmd := exec.Command(setting.AppPath, args...) - cmd.Env = append( - os.Environ(), - "SSH_ORIGINAL_COMMAND="+cmdName, - "SKIP_MINWINSVC=1", - ) - - stdout, err := cmd.StdoutPipe() - if err != nil { - log.Error("SSH: StdoutPipe: %v", err) - return - } - stderr, err := cmd.StderrPipe() - if err != nil { - log.Error("SSH: StderrPipe: %v", err) - return - } - input, err := cmd.StdinPipe() - if err != nil { - log.Error("SSH: StdinPipe: %v", err) - return - } - - // FIXME: check timeout - if err = cmd.Start(); err != nil { - log.Error("SSH: Start: %v", err) - return - } - - err = req.Reply(true, nil) - if err != nil { - log.Error("SSH: Reply: %v", err) - } - go func() { - _, err = io.Copy(input, ch) - if err != nil { - log.Error("SSH: Copy: %v", err) - } - }() - _, err = io.Copy(ch, stdout) - if err != nil { - log.Error("SSH: Copy: %v", err) - } - _, err = io.Copy(ch.Stderr(), stderr) - if err != nil { - log.Error("SSH: Copy: %v", err) - } - - if err = cmd.Wait(); err != nil { - log.Error("SSH: Wait: %v", err) - return - } - - _, err = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) - if err != nil { - log.Error("SSH: SendRequest: %v", err) - } - return - default: - } - } - }(reqs) + return 1 } + + return waitStatus.ExitStatus() } -func listen(config *ssh.ServerConfig, host string, port int) { - listener, err := net.Listen("tcp", host+":"+com.ToStr(port)) +func sessionHandler(session ssh.Session) { + keyID := session.Context().Value(giteaKeyID).(int64) + + command := session.RawCommand() + + log.Trace("SSH: Payload: %v", command) + + args := []string{"serv", "key-" + com.ToStr(keyID), "--config=" + setting.CustomConf} + log.Trace("SSH: Arguments: %v", args) + cmd := exec.Command(setting.AppPath, args...) + cmd.Env = append( + os.Environ(), + "SSH_ORIGINAL_COMMAND="+command, + "SKIP_MINWINSVC=1", + ) + + stdout, err := cmd.StdoutPipe() if err != nil { - log.Fatal("Failed to start SSH server: %v", err) + log.Error("SSH: StdoutPipe: %v", err) + return } - for { - // Once a ServerConfig has been configured, connections can be accepted. - conn, err := listener.Accept() - if err != nil { - log.Error("SSH: Error accepting incoming connection: %v", err) - continue + stderr, err := cmd.StderrPipe() + if err != nil { + log.Error("SSH: StderrPipe: %v", err) + return + } + stdin, err := cmd.StdinPipe() + if err != nil { + log.Error("SSH: StdinPipe: %v", err) + return + } + + wg := &sync.WaitGroup{} + wg.Add(2) + + if err = cmd.Start(); err != nil { + log.Error("SSH: Start: %v", err) + return + } + + go func() { + defer stdin.Close() + if _, err := io.Copy(stdin, session); err != nil { + log.Error("Failed to write session to stdin. %s", err) + } + }() + + go func() { + defer wg.Done() + if _, err := io.Copy(session, stdout); err != nil { + log.Error("Failed to write stdout to session. %s", err) + } + }() + + go func() { + defer wg.Done() + if _, err := io.Copy(session.Stderr(), stderr); err != nil { + log.Error("Failed to write stderr to session. %s", err) } + }() + + // Ensure all the output has been written before we wait on the command + // to exit. + wg.Wait() + + // Wait for the command to exit and log any errors we get + err = cmd.Wait() + if err != nil { + log.Error("SSH: Wait: %v", err) + } + + if err := session.Exit(getExitStatusFromError(err)); err != nil { + log.Error("Session failed to exit. %s", err) + } +} + +func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { + if ctx.User() != setting.SSH.BuiltinServerUser { + return false + } - // Before use, a handshake must be performed on the incoming net.Conn. - // It must be handled in a separate goroutine, - // otherwise one user could easily block entire loop. - // For example, user could be asked to trust server key fingerprint and hangs. - go func() { - log.Trace("SSH: Handshaking for %s", conn.RemoteAddr()) - sConn, chans, reqs, err := ssh.NewServerConn(conn, config) - if err != nil { - if err == io.EOF { - log.Warn("SSH: Handshaking with %s was terminated: %v", conn.RemoteAddr(), err) - } else { - log.Error("SSH: Error on handshaking with %s: %v", conn.RemoteAddr(), err) - } - return - } - - log.Trace("SSH: Connection from %s (%s)", sConn.RemoteAddr(), sConn.ClientVersion()) - // The incoming Request channel must be serviced. - go ssh.DiscardRequests(reqs) - go handleServerConn(sConn.Permissions.Extensions["key-id"], chans) - }() + pkey, err := models.SearchPublicKeyByContent(strings.TrimSpace(string(gossh.MarshalAuthorizedKey(key)))) + if err != nil { + log.Error("SearchPublicKeyByContent: %v", err) + return false } + + ctx.SetValue(giteaKeyID, pkey.ID) + + return true } // Listen starts a SSH server listens on given port. func Listen(host string, port int, ciphers []string, keyExchanges []string, macs []string) { - config := &ssh.ServerConfig{ - Config: ssh.Config{ - Ciphers: ciphers, - KeyExchanges: keyExchanges, - MACs: macs, - }, - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - pkey, err := models.SearchPublicKeyByContent(strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))) - if err != nil { - log.Error("SearchPublicKeyByContent: %v", err) - return nil, err - } - return &ssh.Permissions{Extensions: map[string]string{"key-id": com.ToStr(pkey.ID)}}, nil + // TODO: Handle ciphers, keyExchanges, and macs + + srv := ssh.Server{ + Addr: fmt.Sprintf("%s:%d", host, port), + PublicKeyHandler: publicKeyHandler, + Handler: sessionHandler, + + // We need to explicitly disable the PtyCallback so text displays + // properly. + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return false }, } @@ -197,20 +175,21 @@ func Listen(host string, port int, ciphers []string, keyExchanges []string, macs if err != nil { log.Fatal("Failed to generate private key: %v", err) } - log.Trace("SSH: New private key is generateed: %s", keyPath) + log.Trace("New private key is generated: %s", keyPath) } - privateBytes, err := ioutil.ReadFile(keyPath) + err := srv.SetOption(ssh.HostKeyFile(keyPath)) if err != nil { - log.Fatal("SSH: Failed to load private key") + log.Error("Failed to set Host Key. %s", err) } - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Fatal("SSH: Failed to parse private key") - } - config.AddHostKey(private) - go listen(config, host, port) + go func() { + err := srv.ListenAndServe() + if err != nil { + log.Error("Failed to serve with builtin SSH server. %s", err) + } + }() + } // GenKeyPair make a pair of public and private keys for SSH access. @@ -238,12 +217,12 @@ func GenKeyPair(keyPath string) error { } // generate public key - pub, err := ssh.NewPublicKey(&privateKey.PublicKey) + pub, err := gossh.NewPublicKey(&privateKey.PublicKey) if err != nil { return err } - public := ssh.MarshalAuthorizedKey(pub) + public := gossh.MarshalAuthorizedKey(pub) p, err := os.OpenFile(keyPath+".pub", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err |