aboutsummaryrefslogtreecommitdiffstats
path: root/modules/ssh/ssh.go
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ssh/ssh.go')
-rw-r--r--modules/ssh/ssh.go291
1 files changed, 135 insertions, 156 deletions
diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go
index c5251ef23a..1818f33306 100644
--- a/modules/ssh/ssh.go
+++ b/modules/ssh/ssh.go
@@ -1,4 +1,3 @@
-// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2017 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
@@ -10,178 +9,157 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
+ "fmt"
"io"
- "io/ioutil"
- "net"
"os"
"os/exec"
"path/filepath"
"strings"
-
- "github.com/Unknwon/com"
- "golang.org/x/crypto/ssh"
+ "sync"
+ "syscall"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
+
+ "github.com/Unknwon/com"
+ "github.com/gliderlabs/ssh"
+ gossh "golang.org/x/crypto/ssh"
)
-func cleanCommand(cmd string) string {
- i := strings.Index(cmd, "git")
- if i == -1 {
- return cmd
+type contextKey string
+
+const giteaKeyID = contextKey("gitea-key-id")
+
+func getExitStatusFromError(err error) int {
+ if err == nil {
+ return 0
}
- return cmd[i:]
-}
-func handleServerConn(keyID string, chans <-chan ssh.NewChannel) {
- for newChan := range chans {
- if newChan.ChannelType() != "session" {
- err := newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
- if err != nil {
- log.Error("Error rejecting channel: %v", err)
- }
- continue
- }
+ exitErr, ok := err.(*exec.ExitError)
+ if !ok {
+ return 1
+ }
- ch, reqs, err := newChan.Accept()
- if err != nil {
- log.Error("Error accepting channel: %v", err)
- continue
+ waitStatus, ok := exitErr.Sys().(syscall.WaitStatus)
+ if !ok {
+ // This is a fallback and should at least let us return something useful
+ // when running on Windows, even if it isn't completely accurate.
+ if exitErr.Success() {
+ return 0
}
- go func(in <-chan *ssh.Request) {
- defer func() {
- if err = ch.Close(); err != nil {
- log.Error("Close: %v", err)
- }
- }()
- for req := range in {
- payload := cleanCommand(string(req.Payload))
- switch req.Type {
- case "exec":
- cmdName := strings.TrimLeft(payload, "'()")
- log.Trace("SSH: Payload: %v", cmdName)
-
- args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf}
- log.Trace("SSH: Arguments: %v", args)
- cmd := exec.Command(setting.AppPath, args...)
- cmd.Env = append(
- os.Environ(),
- "SSH_ORIGINAL_COMMAND="+cmdName,
- "SKIP_MINWINSVC=1",
- )
-
- stdout, err := cmd.StdoutPipe()
- if err != nil {
- log.Error("SSH: StdoutPipe: %v", err)
- return
- }
- stderr, err := cmd.StderrPipe()
- if err != nil {
- log.Error("SSH: StderrPipe: %v", err)
- return
- }
- input, err := cmd.StdinPipe()
- if err != nil {
- log.Error("SSH: StdinPipe: %v", err)
- return
- }
-
- // FIXME: check timeout
- if err = cmd.Start(); err != nil {
- log.Error("SSH: Start: %v", err)
- return
- }
-
- err = req.Reply(true, nil)
- if err != nil {
- log.Error("SSH: Reply: %v", err)
- }
- go func() {
- _, err = io.Copy(input, ch)
- if err != nil {
- log.Error("SSH: Copy: %v", err)
- }
- }()
- _, err = io.Copy(ch, stdout)
- if err != nil {
- log.Error("SSH: Copy: %v", err)
- }
- _, err = io.Copy(ch.Stderr(), stderr)
- if err != nil {
- log.Error("SSH: Copy: %v", err)
- }
-
- if err = cmd.Wait(); err != nil {
- log.Error("SSH: Wait: %v", err)
- return
- }
-
- _, err = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
- if err != nil {
- log.Error("SSH: SendRequest: %v", err)
- }
- return
- default:
- }
- }
- }(reqs)
+ return 1
}
+
+ return waitStatus.ExitStatus()
}
-func listen(config *ssh.ServerConfig, host string, port int) {
- listener, err := net.Listen("tcp", host+":"+com.ToStr(port))
+func sessionHandler(session ssh.Session) {
+ keyID := session.Context().Value(giteaKeyID).(int64)
+
+ command := session.RawCommand()
+
+ log.Trace("SSH: Payload: %v", command)
+
+ args := []string{"serv", "key-" + com.ToStr(keyID), "--config=" + setting.CustomConf}
+ log.Trace("SSH: Arguments: %v", args)
+ cmd := exec.Command(setting.AppPath, args...)
+ cmd.Env = append(
+ os.Environ(),
+ "SSH_ORIGINAL_COMMAND="+command,
+ "SKIP_MINWINSVC=1",
+ )
+
+ stdout, err := cmd.StdoutPipe()
if err != nil {
- log.Fatal("Failed to start SSH server: %v", err)
+ log.Error("SSH: StdoutPipe: %v", err)
+ return
}
- for {
- // Once a ServerConfig has been configured, connections can be accepted.
- conn, err := listener.Accept()
- if err != nil {
- log.Error("SSH: Error accepting incoming connection: %v", err)
- continue
+ stderr, err := cmd.StderrPipe()
+ if err != nil {
+ log.Error("SSH: StderrPipe: %v", err)
+ return
+ }
+ stdin, err := cmd.StdinPipe()
+ if err != nil {
+ log.Error("SSH: StdinPipe: %v", err)
+ return
+ }
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ if err = cmd.Start(); err != nil {
+ log.Error("SSH: Start: %v", err)
+ return
+ }
+
+ go func() {
+ defer stdin.Close()
+ if _, err := io.Copy(stdin, session); err != nil {
+ log.Error("Failed to write session to stdin. %s", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ if _, err := io.Copy(session, stdout); err != nil {
+ log.Error("Failed to write stdout to session. %s", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ if _, err := io.Copy(session.Stderr(), stderr); err != nil {
+ log.Error("Failed to write stderr to session. %s", err)
}
+ }()
+
+ // Ensure all the output has been written before we wait on the command
+ // to exit.
+ wg.Wait()
+
+ // Wait for the command to exit and log any errors we get
+ err = cmd.Wait()
+ if err != nil {
+ log.Error("SSH: Wait: %v", err)
+ }
+
+ if err := session.Exit(getExitStatusFromError(err)); err != nil {
+ log.Error("Session failed to exit. %s", err)
+ }
+}
+
+func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
+ if ctx.User() != setting.SSH.BuiltinServerUser {
+ return false
+ }
- // Before use, a handshake must be performed on the incoming net.Conn.
- // It must be handled in a separate goroutine,
- // otherwise one user could easily block entire loop.
- // For example, user could be asked to trust server key fingerprint and hangs.
- go func() {
- log.Trace("SSH: Handshaking for %s", conn.RemoteAddr())
- sConn, chans, reqs, err := ssh.NewServerConn(conn, config)
- if err != nil {
- if err == io.EOF {
- log.Warn("SSH: Handshaking with %s was terminated: %v", conn.RemoteAddr(), err)
- } else {
- log.Error("SSH: Error on handshaking with %s: %v", conn.RemoteAddr(), err)
- }
- return
- }
-
- log.Trace("SSH: Connection from %s (%s)", sConn.RemoteAddr(), sConn.ClientVersion())
- // The incoming Request channel must be serviced.
- go ssh.DiscardRequests(reqs)
- go handleServerConn(sConn.Permissions.Extensions["key-id"], chans)
- }()
+ pkey, err := models.SearchPublicKeyByContent(strings.TrimSpace(string(gossh.MarshalAuthorizedKey(key))))
+ if err != nil {
+ log.Error("SearchPublicKeyByContent: %v", err)
+ return false
}
+
+ ctx.SetValue(giteaKeyID, pkey.ID)
+
+ return true
}
// Listen starts a SSH server listens on given port.
func Listen(host string, port int, ciphers []string, keyExchanges []string, macs []string) {
- config := &ssh.ServerConfig{
- Config: ssh.Config{
- Ciphers: ciphers,
- KeyExchanges: keyExchanges,
- MACs: macs,
- },
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- pkey, err := models.SearchPublicKeyByContent(strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key))))
- if err != nil {
- log.Error("SearchPublicKeyByContent: %v", err)
- return nil, err
- }
- return &ssh.Permissions{Extensions: map[string]string{"key-id": com.ToStr(pkey.ID)}}, nil
+ // TODO: Handle ciphers, keyExchanges, and macs
+
+ srv := ssh.Server{
+ Addr: fmt.Sprintf("%s:%d", host, port),
+ PublicKeyHandler: publicKeyHandler,
+ Handler: sessionHandler,
+
+ // We need to explicitly disable the PtyCallback so text displays
+ // properly.
+ PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
+ return false
},
}
@@ -197,20 +175,21 @@ func Listen(host string, port int, ciphers []string, keyExchanges []string, macs
if err != nil {
log.Fatal("Failed to generate private key: %v", err)
}
- log.Trace("SSH: New private key is generateed: %s", keyPath)
+ log.Trace("New private key is generated: %s", keyPath)
}
- privateBytes, err := ioutil.ReadFile(keyPath)
+ err := srv.SetOption(ssh.HostKeyFile(keyPath))
if err != nil {
- log.Fatal("SSH: Failed to load private key")
+ log.Error("Failed to set Host Key. %s", err)
}
- private, err := ssh.ParsePrivateKey(privateBytes)
- if err != nil {
- log.Fatal("SSH: Failed to parse private key")
- }
- config.AddHostKey(private)
- go listen(config, host, port)
+ go func() {
+ err := srv.ListenAndServe()
+ if err != nil {
+ log.Error("Failed to serve with builtin SSH server. %s", err)
+ }
+ }()
+
}
// GenKeyPair make a pair of public and private keys for SSH access.
@@ -238,12 +217,12 @@ func GenKeyPair(keyPath string) error {
}
// generate public key
- pub, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ pub, err := gossh.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return err
}
- public := ssh.MarshalAuthorizedKey(pub)
+ public := gossh.MarshalAuthorizedKey(pub)
p, err := os.OpenFile(keyPath+".pub", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err