]> source.dussan.org Git - gitea.git/commitdiff
Update golang x/crypto dependencies (#2923) (#2951)
authorLauris BH <lauris@nix.lv>
Tue, 21 Nov 2017 04:25:53 +0000 (06:25 +0200)
committerGitHub <noreply@github.com>
Tue, 21 Nov 2017 04:25:53 +0000 (06:25 +0200)
* Update golang x/crypto dependencies (#2923)

* Fix govendor for x/crupto curve25519 (#2925)

31 files changed:
vendor/golang.org/x/crypto/curve25519/const_amd64.h [new file with mode: 0644]
vendor/golang.org/x/crypto/curve25519/const_amd64.s
vendor/golang.org/x/crypto/curve25519/cswap_amd64.s
vendor/golang.org/x/crypto/curve25519/curve25519.go
vendor/golang.org/x/crypto/curve25519/doc.go
vendor/golang.org/x/crypto/curve25519/freeze_amd64.s
vendor/golang.org/x/crypto/curve25519/ladderstep_amd64.s
vendor/golang.org/x/crypto/curve25519/mul_amd64.s
vendor/golang.org/x/crypto/curve25519/square_amd64.s
vendor/golang.org/x/crypto/ed25519/ed25519.go
vendor/golang.org/x/crypto/ssh/buffer.go
vendor/golang.org/x/crypto/ssh/certs.go
vendor/golang.org/x/crypto/ssh/channel.go
vendor/golang.org/x/crypto/ssh/cipher.go
vendor/golang.org/x/crypto/ssh/client.go
vendor/golang.org/x/crypto/ssh/client_auth.go
vendor/golang.org/x/crypto/ssh/common.go
vendor/golang.org/x/crypto/ssh/connection.go
vendor/golang.org/x/crypto/ssh/doc.go
vendor/golang.org/x/crypto/ssh/handshake.go
vendor/golang.org/x/crypto/ssh/kex.go
vendor/golang.org/x/crypto/ssh/keys.go
vendor/golang.org/x/crypto/ssh/mac.go
vendor/golang.org/x/crypto/ssh/messages.go
vendor/golang.org/x/crypto/ssh/mux.go
vendor/golang.org/x/crypto/ssh/server.go
vendor/golang.org/x/crypto/ssh/session.go
vendor/golang.org/x/crypto/ssh/streamlocal.go [new file with mode: 0644]
vendor/golang.org/x/crypto/ssh/tcpip.go
vendor/golang.org/x/crypto/ssh/transport.go
vendor/vendor.json

diff --git a/vendor/golang.org/x/crypto/curve25519/const_amd64.h b/vendor/golang.org/x/crypto/curve25519/const_amd64.h
new file mode 100644 (file)
index 0000000..b3f7416
--- /dev/null
@@ -0,0 +1,8 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This code was translated into a form compatible with 6a from the public
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
+
+#define REDMASK51     0x0007FFFFFFFFFFFF
index 797f9b051df959d3c206d2b789b6d1a0bef16488..ee7b4bd5f8e3303e829cd909f02cf30de1877acd 100644 (file)
@@ -3,12 +3,12 @@
 // license that can be found in the LICENSE file.
 
 // This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
 
 // +build amd64,!gccgo,!appengine
 
-DATA ·REDMASK51(SB)/8, $0x0007FFFFFFFFFFFF
-GLOBL ·REDMASK51(SB), 8, $8
+// These constants cannot be encoded in non-MOVQ immediates.
+// We access them directly from memory instead.
 
 DATA ·_121666_213(SB)/8, $996687872
 GLOBL ·_121666_213(SB), 8, $8
index 45484d1b596f07092c1db2f3b33bdb13dd95551d..cd793a5b5f2eb00a97e3a9374c03a5f71c879e19 100644 (file)
@@ -2,87 +2,64 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
-
 // +build amd64,!gccgo,!appengine
 
-// func cswap(inout *[5]uint64, v uint64)
+// func cswap(inout *[4][5]uint64, v uint64)
 TEXT ·cswap(SB),7,$0
        MOVQ inout+0(FP),DI
        MOVQ v+8(FP),SI
 
-       CMPQ SI,$1
-       MOVQ 0(DI),SI
-       MOVQ 80(DI),DX
-       MOVQ 8(DI),CX
-       MOVQ 88(DI),R8
-       MOVQ SI,R9
-       CMOVQEQ DX,SI
-       CMOVQEQ R9,DX
-       MOVQ CX,R9
-       CMOVQEQ R8,CX
-       CMOVQEQ R9,R8
-       MOVQ SI,0(DI)
-       MOVQ DX,80(DI)
-       MOVQ CX,8(DI)
-       MOVQ R8,88(DI)
-       MOVQ 16(DI),SI
-       MOVQ 96(DI),DX
-       MOVQ 24(DI),CX
-       MOVQ 104(DI),R8
-       MOVQ SI,R9
-       CMOVQEQ DX,SI
-       CMOVQEQ R9,DX
-       MOVQ CX,R9
-       CMOVQEQ R8,CX
-       CMOVQEQ R9,R8
-       MOVQ SI,16(DI)
-       MOVQ DX,96(DI)
-       MOVQ CX,24(DI)
-       MOVQ R8,104(DI)
-       MOVQ 32(DI),SI
-       MOVQ 112(DI),DX
-       MOVQ 40(DI),CX
-       MOVQ 120(DI),R8
-       MOVQ SI,R9
-       CMOVQEQ DX,SI
-       CMOVQEQ R9,DX
-       MOVQ CX,R9
-       CMOVQEQ R8,CX
-       CMOVQEQ R9,R8
-       MOVQ SI,32(DI)
-       MOVQ DX,112(DI)
-       MOVQ CX,40(DI)
-       MOVQ R8,120(DI)
-       MOVQ 48(DI),SI
-       MOVQ 128(DI),DX
-       MOVQ 56(DI),CX
-       MOVQ 136(DI),R8
-       MOVQ SI,R9
-       CMOVQEQ DX,SI
-       CMOVQEQ R9,DX
-       MOVQ CX,R9
-       CMOVQEQ R8,CX
-       CMOVQEQ R9,R8
-       MOVQ SI,48(DI)
-       MOVQ DX,128(DI)
-       MOVQ CX,56(DI)
-       MOVQ R8,136(DI)
-       MOVQ 64(DI),SI
-       MOVQ 144(DI),DX
-       MOVQ 72(DI),CX
-       MOVQ 152(DI),R8
-       MOVQ SI,R9
-       CMOVQEQ DX,SI
-       CMOVQEQ R9,DX
-       MOVQ CX,R9
-       CMOVQEQ R8,CX
-       CMOVQEQ R9,R8
-       MOVQ SI,64(DI)
-       MOVQ DX,144(DI)
-       MOVQ CX,72(DI)
-       MOVQ R8,152(DI)
-       MOVQ DI,AX
-       MOVQ SI,DX
+       SUBQ $1, SI
+       NOTQ SI
+       MOVQ SI, X15
+       PSHUFD $0x44, X15, X15
+
+       MOVOU 0(DI), X0
+       MOVOU 16(DI), X2
+       MOVOU 32(DI), X4
+       MOVOU 48(DI), X6
+       MOVOU 64(DI), X8
+       MOVOU 80(DI), X1
+       MOVOU 96(DI), X3
+       MOVOU 112(DI), X5
+       MOVOU 128(DI), X7
+       MOVOU 144(DI), X9
+
+       MOVO X1, X10
+       MOVO X3, X11
+       MOVO X5, X12
+       MOVO X7, X13
+       MOVO X9, X14
+
+       PXOR X0, X10
+       PXOR X2, X11
+       PXOR X4, X12
+       PXOR X6, X13
+       PXOR X8, X14
+       PAND X15, X10
+       PAND X15, X11
+       PAND X15, X12
+       PAND X15, X13
+       PAND X15, X14
+       PXOR X10, X0
+       PXOR X10, X1
+       PXOR X11, X2
+       PXOR X11, X3
+       PXOR X12, X4
+       PXOR X12, X5
+       PXOR X13, X6
+       PXOR X13, X7
+       PXOR X14, X8
+       PXOR X14, X9
+
+       MOVOU X0, 0(DI)
+       MOVOU X2, 16(DI)
+       MOVOU X4, 32(DI)
+       MOVOU X6, 48(DI)
+       MOVOU X8, 64(DI)
+       MOVOU X1, 80(DI)
+       MOVOU X3, 96(DI)
+       MOVOU X5, 112(DI)
+       MOVOU X7, 128(DI)
+       MOVOU X9, 144(DI)
        RET
index 6918c47fc2eceeaf07a08dc0251efcb2e9af5958..cb8fbc57b97a6ce068dd45e58a4c5c1b72209bae 100644 (file)
@@ -2,12 +2,16 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// We have a implementation in amd64 assembly so this code is only run on
+// We have an implementation in amd64 assembly so this code is only run on
 // non-amd64 platforms. The amd64 assembly does not support gccgo.
 // +build !amd64 gccgo appengine
 
 package curve25519
 
+import (
+       "encoding/binary"
+)
+
 // This code is a port of the public domain, "ref10" implementation of
 // curve25519 from SUPERCOP 20130419 by D. J. Bernstein.
 
@@ -50,17 +54,11 @@ func feCopy(dst, src *fieldElement) {
 //
 // Preconditions: b in {0,1}.
 func feCSwap(f, g *fieldElement, b int32) {
-       var x fieldElement
        b = -b
-       for i := range x {
-               x[i] = b & (f[i] ^ g[i])
-       }
-
        for i := range f {
-               f[i] ^= x[i]
-       }
-       for i := range g {
-               g[i] ^= x[i]
+               t := b & (f[i] ^ g[i])
+               f[i] ^= t
+               g[i] ^= t
        }
 }
 
@@ -75,12 +73,7 @@ func load3(in []byte) int64 {
 
 // load4 reads a 32-bit, little-endian value from in.
 func load4(in []byte) int64 {
-       var r int64
-       r = int64(in[0])
-       r |= int64(in[1]) << 8
-       r |= int64(in[2]) << 16
-       r |= int64(in[3]) << 24
-       return r
+       return int64(binary.LittleEndian.Uint32(in))
 }
 
 func feFromBytes(dst *fieldElement, src *[32]byte) {
index ebeea3c2d6a58be6357daa32d8f54a0cae76cd88..da9b10d9c1ffd204d7d4ccd683b196c6b5fb5e67 100644 (file)
@@ -3,7 +3,7 @@
 // license that can be found in the LICENSE file.
 
 // Package curve25519 provides an implementation of scalar multiplication on
-// the elliptic curve known as curve25519. See http://cr.yp.to/ecdh.html
+// the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html
 package curve25519 // import "golang.org/x/crypto/curve25519"
 
 // basePoint is the x coordinate of the generator of the curve.
index 932800b8d1b1dd144dde90d4fd15c0243318eef3..390816106ee99879f7e0effb7cbb5f2b65931ae9 100644 (file)
@@ -3,10 +3,12 @@
 // license that can be found in the LICENSE file.
 
 // This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
 
 // +build amd64,!gccgo,!appengine
 
+#include "const_amd64.h"
+
 // func freeze(inout *[5]uint64)
 TEXT ·freeze(SB),7,$0-8
        MOVQ inout+0(FP), DI
@@ -16,7 +18,7 @@ TEXT ·freeze(SB),7,$0-8
        MOVQ 16(DI),CX
        MOVQ 24(DI),R8
        MOVQ 32(DI),R9
-       MOVQ ·REDMASK51(SB),AX
+       MOVQ $REDMASK51,AX
        MOVQ AX,R10
        SUBQ $18,R10
        MOVQ $3,R11
index ee7b36c36844c6530c1c3d6ca27593bd04198f00..9e9040b2502f8e6bf156531a1cbe7feead9eee10 100644 (file)
@@ -3,10 +3,12 @@
 // license that can be found in the LICENSE file.
 
 // This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
 
 // +build amd64,!gccgo,!appengine
 
+#include "const_amd64.h"
+
 // func ladderstep(inout *[5][5]uint64)
 TEXT ·ladderstep(SB),0,$296-8
        MOVQ inout+0(FP),DI
@@ -118,7 +120,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 72(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -233,7 +235,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 32(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -438,7 +440,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 72(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -588,7 +590,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 32(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -728,7 +730,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 152(DI)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -843,7 +845,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 192(DI)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -993,7 +995,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 32(DI)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -1143,7 +1145,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 112(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
@@ -1329,7 +1331,7 @@ TEXT ·ladderstep(SB),0,$296-8
        MULQ 192(SP)
        ADDQ AX,R12
        ADCQ DX,R13
-       MOVQ ·REDMASK51(SB),DX
+       MOVQ $REDMASK51,DX
        SHLQ $13,CX:SI
        ANDQ DX,SI
        SHLQ $13,R9:R8
index 33ce57dcded44a9a7c7e457a5e6b2693d802c98e..5ce80a2e56b975ca0297dfd3cfae522b15f4c1b9 100644 (file)
@@ -3,10 +3,12 @@
 // license that can be found in the LICENSE file.
 
 // This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
 
 // +build amd64,!gccgo,!appengine
 
+#include "const_amd64.h"
+
 // func mul(dest, a, b *[5]uint64)
 TEXT ·mul(SB),0,$16-24
        MOVQ dest+0(FP), DI
@@ -121,7 +123,7 @@ TEXT ·mul(SB),0,$16-24
        MULQ 32(CX)
        ADDQ AX,R14
        ADCQ DX,R15
-       MOVQ ·REDMASK51(SB),SI
+       MOVQ $REDMASK51,SI
        SHLQ $13,R9:R8
        ANDQ SI,R8
        SHLQ $13,R11:R10
index 3a92804ddf380df22d98317ae5fa8f0d92bdb176..12f73734ff5aede956c2098faf473b862c72afe5 100644 (file)
@@ -3,10 +3,12 @@
 // license that can be found in the LICENSE file.
 
 // This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
+// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
 
 // +build amd64,!gccgo,!appengine
 
+#include "const_amd64.h"
+
 // func square(out, in *[5]uint64)
 TEXT ·square(SB),7,$0-16
        MOVQ out+0(FP), DI
@@ -84,7 +86,7 @@ TEXT ·square(SB),7,$0-16
        MULQ 32(SI)
        ADDQ AX,R13
        ADCQ DX,R14
-       MOVQ ·REDMASK51(SB),SI
+       MOVQ $REDMASK51,SI
        SHLQ $13,R8:CX
        ANDQ SI,CX
        SHLQ $13,R10:R9
index f1d95674ac3f7562ee613705d0ff281ac97e9772..4f26b49b6a872c50d55eaae07368024c0b8f50bc 100644 (file)
@@ -3,20 +3,20 @@
 // license that can be found in the LICENSE file.
 
 // Package ed25519 implements the Ed25519 signature algorithm. See
-// http://ed25519.cr.yp.to/.
+// https://ed25519.cr.yp.to/.
 //
 // These functions are also compatible with the “Ed25519” function defined in
-// https://tools.ietf.org/html/draft-irtf-cfrg-eddsa-05.
+// RFC 8032.
 package ed25519
 
 // This code is a port of the public domain, “ref10” implementation of ed25519
 // from SUPERCOP.
 
 import (
+       "bytes"
        "crypto"
        cryptorand "crypto/rand"
        "crypto/sha512"
-       "crypto/subtle"
        "errors"
        "io"
        "strconv"
@@ -177,5 +177,5 @@ func Verify(publicKey PublicKey, message, sig []byte) bool {
 
        var checkR [32]byte
        R.ToBytes(&checkR)
-       return subtle.ConstantTimeCompare(sig[:32], checkR[:]) == 1
+       return bytes.Equal(sig[:32], checkR[:])
 }
index 6931b5114fe8301e1ee955e6a2878ff10e43d4df..1ab07d078db1633be8977102b2defc0da0a6d4e8 100644 (file)
@@ -51,13 +51,12 @@ func (b *buffer) write(buf []byte) {
 }
 
 // eof closes the buffer. Reads from the buffer once all
-// the data has been consumed will receive os.EOF.
-func (b *buffer) eof() error {
+// the data has been consumed will receive io.EOF.
+func (b *buffer) eof() {
        b.Cond.L.Lock()
        b.closed = true
        b.Cond.Signal()
        b.Cond.L.Unlock()
-       return nil
 }
 
 // Read reads data from the internal buffer in buf.  Reads will block
index 6331c94d53bc991332064fb0cfbd62075337f6a5..b1f02207819f0d7a7080ba67420b9af78c5b2591 100644 (file)
@@ -251,10 +251,18 @@ type CertChecker struct {
        // for user certificates.
        SupportedCriticalOptions []string
 
-       // IsAuthority should return true if the key is recognized as
-       // an authority. This allows for certificates to be signed by other
-       // certificates.
-       IsAuthority func(auth PublicKey) bool
+       // IsUserAuthority should return true if the key is recognized as an
+       // authority for the given user certificate. This allows for
+       // certificates to be signed by other certificates. This must be set
+       // if this CertChecker will be checking user certificates.
+       IsUserAuthority func(auth PublicKey) bool
+
+       // IsHostAuthority should report whether the key is recognized as
+       // an authority for this host. This allows for certificates to be
+       // signed by other keys, and for those other keys to only be valid
+       // signers for particular hostnames. This must be set if this
+       // CertChecker will be checking host certificates.
+       IsHostAuthority func(auth PublicKey, address string) bool
 
        // Clock is used for verifying time stamps. If nil, time.Now
        // is used.
@@ -268,7 +276,7 @@ type CertChecker struct {
        // HostKeyFallback is called when CertChecker.CheckHostKey encounters a
        // public key that is not a certificate. It must implement host key
        // validation or else, if nil, all such keys are rejected.
-       HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
+       HostKeyFallback HostKeyCallback
 
        // IsRevoked is called for each certificate so that revocation checking
        // can be implemented. It should return true if the given certificate
@@ -290,8 +298,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey)
        if cert.CertType != HostCert {
                return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
        }
+       if !c.IsHostAuthority(cert.SignatureKey, addr) {
+               return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
+       }
+
+       hostname, _, err := net.SplitHostPort(addr)
+       if err != nil {
+               return err
+       }
 
-       return c.CheckCert(addr, cert)
+       // Pass hostname only as principal for host certificates (consistent with OpenSSH)
+       return c.CheckCert(hostname, cert)
 }
 
 // Authenticate checks a user certificate. Authenticate can be used as
@@ -308,6 +325,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
        if cert.CertType != UserCert {
                return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
        }
+       if !c.IsUserAuthority(cert.SignatureKey) {
+               return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
+       }
 
        if err := c.CheckCert(conn.User(), cert); err != nil {
                return nil, err
@@ -356,10 +376,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
                }
        }
 
-       if !c.IsAuthority(cert.SignatureKey) {
-               return fmt.Errorf("ssh: certificate signed by unrecognized authority")
-       }
-
        clock := c.Clock
        if clock == nil {
                clock = time.Now
index 6d709b50be094cd19897e50866cac102e9d37893..195530ea0da2fc9b602f8a17c7b514cb98263166 100644 (file)
@@ -461,8 +461,8 @@ func (m *mux) newChannel(chanType string, direction channelDirection, extraData
                pending:          newBuffer(),
                extPending:       newBuffer(),
                direction:        direction,
-               incomingRequests: make(chan *Request, 16),
-               msg:              make(chan interface{}, 16),
+               incomingRequests: make(chan *Request, chanSize),
+               msg:              make(chan interface{}, chanSize),
                chanType:         chanType,
                extraData:        extraData,
                mux:              m,
index 34d3917c4f76dc5cd245b2be9917e75e8de55959..aed2b1f017fb68ed0dd7e8a6a33f4a5b2a5316cc 100644 (file)
@@ -135,6 +135,7 @@ const prefixLen = 5
 type streamPacketCipher struct {
        mac    hash.Hash
        cipher cipher.Stream
+       etm    bool
 
        // The following members are to avoid per-packet allocations.
        prefix      [prefixLen]byte
@@ -150,7 +151,14 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
                return nil, err
        }
 
-       s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+       var encryptedPaddingLength [1]byte
+       if s.mac != nil && s.etm {
+               copy(encryptedPaddingLength[:], s.prefix[4:5])
+               s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
+       } else {
+               s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+       }
+
        length := binary.BigEndian.Uint32(s.prefix[0:4])
        paddingLength := uint32(s.prefix[4])
 
@@ -159,7 +167,12 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
                s.mac.Reset()
                binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
                s.mac.Write(s.seqNumBytes[:])
-               s.mac.Write(s.prefix[:])
+               if s.etm {
+                       s.mac.Write(s.prefix[:4])
+                       s.mac.Write(encryptedPaddingLength[:])
+               } else {
+                       s.mac.Write(s.prefix[:])
+               }
                macSize = uint32(s.mac.Size())
        }
 
@@ -184,10 +197,17 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
        }
        mac := s.packetData[length-1:]
        data := s.packetData[:length-1]
+
+       if s.mac != nil && s.etm {
+               s.mac.Write(data)
+       }
+
        s.cipher.XORKeyStream(data, data)
 
        if s.mac != nil {
-               s.mac.Write(data)
+               if !s.etm {
+                       s.mac.Write(data)
+               }
                s.macResult = s.mac.Sum(s.macResult[:0])
                if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
                        return nil, errors.New("ssh: MAC failure")
@@ -203,7 +223,13 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
                return errors.New("ssh: packet too large")
        }
 
-       paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple
+       aadlen := 0
+       if s.mac != nil && s.etm {
+               // packet length is not encrypted for EtM modes
+               aadlen = 4
+       }
+
+       paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
        if paddingLength < 4 {
                paddingLength += packetSizeMultiple
        }
@@ -220,15 +246,37 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
                s.mac.Reset()
                binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
                s.mac.Write(s.seqNumBytes[:])
+
+               if s.etm {
+                       // For EtM algorithms, the packet length must stay unencrypted,
+                       // but the following data (padding length) must be encrypted
+                       s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
+               }
+
                s.mac.Write(s.prefix[:])
-               s.mac.Write(packet)
-               s.mac.Write(padding)
+
+               if !s.etm {
+                       // For non-EtM algorithms, the algorithm is applied on unencrypted data
+                       s.mac.Write(packet)
+                       s.mac.Write(padding)
+               }
+       }
+
+       if !(s.mac != nil && s.etm) {
+               // For EtM algorithms, the padding length has already been encrypted
+               // and the packet length must remain unencrypted
+               s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
        }
 
-       s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
        s.cipher.XORKeyStream(packet, packet)
        s.cipher.XORKeyStream(padding, padding)
 
+       if s.mac != nil && s.etm {
+               // For EtM algorithms, packet and padding must be encrypted
+               s.mac.Write(packet)
+               s.mac.Write(padding)
+       }
+
        if _, err := w.Write(s.prefix[:]); err != nil {
                return err
        }
@@ -256,7 +304,7 @@ type gcmCipher struct {
        buf    []byte
 }
 
-func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) {
+func newGCMCipher(iv, key []byte) (packetCipher, error) {
        c, err := aes.NewCipher(key)
        if err != nil {
                return nil, err
@@ -344,7 +392,9 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
        c.incIV()
 
        padding := plain[0]
-       if padding < 4 || padding >= 20 {
+       if padding < 4 {
+               // padding is a byte, so it automatically satisfies
+               // the maximum size, which is 255.
                return nil, fmt.Errorf("ssh: illegal padding %d", padding)
        }
 
index 0212a20c9a135acdaf0ba8ab6b84bcf304e6b1a6..6fd1994553bc37d03ba6a5d9cbbc2f89815ba52d 100644 (file)
@@ -5,15 +5,17 @@
 package ssh
 
 import (
+       "bytes"
        "errors"
        "fmt"
        "net"
+       "os"
        "sync"
        "time"
 )
 
 // Client implements a traditional SSH client that supports shells,
-// subprocesses, port forwarding and tunneled dialing.
+// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
 type Client struct {
        Conn
 
@@ -40,7 +42,7 @@ func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
                return nil
        }
 
-       ch = make(chan NewChannel, 16)
+       ch = make(chan NewChannel, chanSize)
        c.channelHandlers[channelType] = ch
        return ch
 }
@@ -59,6 +61,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
                conn.forwards.closeAll()
        }()
        go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
+       go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
        return conn
 }
 
@@ -68,6 +71,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
 func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
        fullConf := *config
        fullConf.SetDefaults()
+       if fullConf.HostKeyCallback == nil {
+               c.Close()
+               return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
+       }
+
        conn := &connection{
                sshConn: sshConn{conn: c},
        }
@@ -97,13 +105,11 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
        c.transport = newClientTransport(
                newTransport(c.sshConn.conn, config.Rand, true /* is client */),
                c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
-       if err := c.transport.requestInitialKeyChange(); err != nil {
+       if err := c.transport.waitSession(); err != nil {
                return err
        }
 
-       // We just did the key change, so the session ID is established.
        c.sessionID = c.transport.getSessionID()
-
        return c.clientAuthenticate(config)
 }
 
@@ -175,6 +181,17 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
        return NewClient(c, chans, reqs), nil
 }
 
+// HostKeyCallback is the function type used for verifying server
+// keys.  A HostKeyCallback must return nil if the host key is OK, or
+// an error to reject it. It receives the hostname as passed to Dial
+// or NewClientConn. The remote address is the RemoteAddr of the
+// net.Conn underlying the the SSH connection.
+type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+
+// BannerCallback is the function type used for treat the banner sent by
+// the server. A BannerCallback receives the message sent by the remote server.
+type BannerCallback func(message string) error
+
 // A ClientConfig structure is used to configure a Client. It must not be
 // modified after having been passed to an SSH function.
 type ClientConfig struct {
@@ -190,10 +207,18 @@ type ClientConfig struct {
        // be used during authentication.
        Auth []AuthMethod
 
-       // HostKeyCallback, if not nil, is called during the cryptographic
-       // handshake to validate the server's host key. A nil HostKeyCallback
-       // implies that all host keys are accepted.
-       HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+       // HostKeyCallback is called during the cryptographic
+       // handshake to validate the server's host key. The client
+       // configuration must supply this callback for the connection
+       // to succeed. The functions InsecureIgnoreHostKey or
+       // FixedHostKey can be used for simplistic host key checks.
+       HostKeyCallback HostKeyCallback
+
+       // BannerCallback is called during the SSH dance to display a custom
+       // server's message. The client configuration can supply this callback to
+       // handle it as wished. The function BannerDisplayStderr can be used for
+       // simplistic display on Stderr.
+       BannerCallback BannerCallback
 
        // ClientVersion contains the version identification string that will
        // be used for the connection. If empty, a reasonable default is used.
@@ -211,3 +236,43 @@ type ClientConfig struct {
        // A Timeout of zero means no timeout.
        Timeout time.Duration
 }
+
+// InsecureIgnoreHostKey returns a function that can be used for
+// ClientConfig.HostKeyCallback to accept any host key. It should
+// not be used for production code.
+func InsecureIgnoreHostKey() HostKeyCallback {
+       return func(hostname string, remote net.Addr, key PublicKey) error {
+               return nil
+       }
+}
+
+type fixedHostKey struct {
+       key PublicKey
+}
+
+func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
+       if f.key == nil {
+               return fmt.Errorf("ssh: required host key was nil")
+       }
+       if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
+               return fmt.Errorf("ssh: host key mismatch")
+       }
+       return nil
+}
+
+// FixedHostKey returns a function for use in
+// ClientConfig.HostKeyCallback to accept only a specific host key.
+func FixedHostKey(key PublicKey) HostKeyCallback {
+       hk := &fixedHostKey{key}
+       return hk.check
+}
+
+// BannerDisplayStderr returns a function that can be used for
+// ClientConfig.BannerCallback to display banners on os.Stderr.
+func BannerDisplayStderr() BannerCallback {
+       return func(banner string) error {
+               _, err := os.Stderr.WriteString(banner)
+
+               return err
+       }
+}
index 294af0d4823f5e73fb654bb9d41e26e304356c5c..a1252cb9be12d986184168bd663ed1bc04c4273e 100644 (file)
@@ -30,8 +30,10 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
        // then any untried methods suggested by the server.
        tried := make(map[string]bool)
        var lastMethods []string
+
+       sessionID := c.transport.getSessionID()
        for auth := AuthMethod(new(noneAuth)); auth != nil; {
-               ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
+               ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
                if err != nil {
                        return err
                }
@@ -177,31 +179,26 @@ func (cb publicKeyCallback) method() string {
 }
 
 func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
-       // Authentication is performed in two stages. The first stage sends an
-       // enquiry to test if each key is acceptable to the remote. The second
-       // stage attempts to authenticate with the valid keys obtained in the
-       // first stage.
+       // Authentication is performed by sending an enquiry to test if a key is
+       // acceptable to the remote. If the key is acceptable, the client will
+       // attempt to authenticate with the valid key.  If not the client will repeat
+       // the process with the remaining keys.
 
        signers, err := cb()
        if err != nil {
                return false, nil, err
        }
-       var validKeys []Signer
+       var methods []string
        for _, signer := range signers {
-               if ok, err := validateKey(signer.PublicKey(), user, c); ok {
-                       validKeys = append(validKeys, signer)
-               } else {
-                       if err != nil {
-                               return false, nil, err
-                       }
+               ok, err := validateKey(signer.PublicKey(), user, c)
+               if err != nil {
+                       return false, nil, err
+               }
+               if !ok {
+                       continue
                }
-       }
 
-       // methods that may continue if this auth is not successful.
-       var methods []string
-       for _, signer := range validKeys {
                pub := signer.PublicKey()
-
                pubKey := pub.Marshal()
                sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
                        User:    user,
@@ -234,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
                if err != nil {
                        return false, nil, err
                }
-               if success {
+
+               // If authentication succeeds or the list of available methods does not
+               // contain the "publickey" method, do not attempt to authenticate with any
+               // other keys.  According to RFC 4252 Section 7, the latter can occur when
+               // additional authentication methods are required.
+               if success || !containsMethod(methods, cb.method()) {
                        return success, methods, err
                }
        }
+
        return false, methods, nil
 }
 
+func containsMethod(methods []string, method string) bool {
+       for _, m := range methods {
+               if m == method {
+                       return true
+               }
+       }
+
+       return false
+}
+
 // validateKey validates the key provided is acceptable to the server.
 func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
        pubKey := key.Marshal()
@@ -270,7 +283,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
                }
                switch packet[0] {
                case msgUserAuthBanner:
-                       // TODO(gpaul): add callback to present the banner to the user
+                       if err := handleBannerResponse(c, packet); err != nil {
+                               return false, err
+                       }
                case msgUserAuthPubKeyOk:
                        var msg userAuthPubKeyOkMsg
                        if err := Unmarshal(packet, &msg); err != nil {
@@ -312,7 +327,9 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
 
                switch packet[0] {
                case msgUserAuthBanner:
-                       // TODO: add callback to present the banner to the user
+                       if err := handleBannerResponse(c, packet); err != nil {
+                               return false, nil, err
+                       }
                case msgUserAuthFailure:
                        var msg userAuthFailureMsg
                        if err := Unmarshal(packet, &msg); err != nil {
@@ -327,6 +344,24 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
        }
 }
 
+func handleBannerResponse(c packetConn, packet []byte) error {
+       var msg userAuthBannerMsg
+       if err := Unmarshal(packet, &msg); err != nil {
+               return err
+       }
+
+       transport, ok := c.(*handshakeTransport)
+       if !ok {
+               return nil
+       }
+
+       if transport.bannerCallback != nil {
+               return transport.bannerCallback(msg.Message)
+       }
+
+       return nil
+}
+
 // KeyboardInteractiveChallenge should print questions, optionally
 // disabling echoing (e.g. for passwords), and return all the answers.
 // Challenge may be called multiple times in a single session. After
@@ -336,7 +371,7 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
 // both CLI and GUI environments.
 type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
 
-// KeyboardInteractive returns a AuthMethod using a prompt/response
+// KeyboardInteractive returns an AuthMethod using a prompt/response
 // sequence controlled by the server.
 func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
        return challenge
@@ -372,7 +407,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
                // like handleAuthResponse, but with less options.
                switch packet[0] {
                case msgUserAuthBanner:
-                       // TODO: Print banners during userauth.
+                       if err := handleBannerResponse(c, packet); err != nil {
+                               return false, nil, err
+                       }
                        continue
                case msgUserAuthInfoRequest:
                        // OK
index 2c72ab544b07b4a8dd9602acfe7ccf080a6015fe..dc39e4d2318294d8ac3db54f9a67f49207209224 100644 (file)
@@ -9,6 +9,7 @@ import (
        "crypto/rand"
        "fmt"
        "io"
+       "math"
        "sync"
 
        _ "crypto/sha1"
@@ -40,7 +41,7 @@ var supportedKexAlgos = []string{
        kexAlgoDH14SHA1, kexAlgoDH1SHA1,
 }
 
-// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
+// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
 // of authenticating servers) in preference order.
 var supportedHostKeyAlgos = []string{
        CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
@@ -56,7 +57,7 @@ var supportedHostKeyAlgos = []string{
 // This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
 // because they have reached the end of their useful life.
 var supportedMACs = []string{
-       "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
+       "hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
 }
 
 var supportedCompressions = []string{compressionNone}
@@ -104,6 +105,21 @@ type directionAlgorithms struct {
        Compression string
 }
 
+// rekeyBytes returns a rekeying intervals in bytes.
+func (a *directionAlgorithms) rekeyBytes() int64 {
+       // According to RFC4344 block ciphers should rekey after
+       // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
+       // 128.
+       switch a.Cipher {
+       case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
+               return 16 * (1 << 32)
+
+       }
+
+       // For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
+       return 1 << 30
+}
+
 type algorithms struct {
        kex     string
        hostKey string
@@ -171,7 +187,7 @@ type Config struct {
 
        // The maximum number of bytes sent or received after which a
        // new key is negotiated. It must be at least 256. If
-       // unspecified, 1 gigabyte is used.
+       // unspecified, a size suitable for the chosen cipher is used.
        RekeyThreshold uint64
 
        // The allowed key exchanges algorithms. If unspecified then a
@@ -215,11 +231,12 @@ func (c *Config) SetDefaults() {
        }
 
        if c.RekeyThreshold == 0 {
-               // RFC 4253, section 9 suggests rekeying after 1G.
-               c.RekeyThreshold = 1 << 30
-       }
-       if c.RekeyThreshold < minRekeyThreshold {
+               // cipher specific default
+       } else if c.RekeyThreshold < minRekeyThreshold {
                c.RekeyThreshold = minRekeyThreshold
+       } else if c.RekeyThreshold >= math.MaxInt64 {
+               // Avoid weirdness if somebody uses -1 as a threshold.
+               c.RekeyThreshold = math.MaxInt64
        }
 }
 
index e786f2f9a20562511a303d451aef07486c143569..fd6b0681b5121fd0a8b92c7be7edc7fd5f5998b4 100644 (file)
@@ -25,7 +25,7 @@ type ConnMetadata interface {
        // User returns the user ID for this connection.
        User() string
 
-       // SessionID returns the sesson hash, also denoted by H.
+       // SessionID returns the session hash, also denoted by H.
        SessionID() []byte
 
        // ClientVersion returns the client's version string as hashed
index d6be8946629210740906a38f1f00fd9bff358122..67b7322c058058f4f794aab619d49b68b7289bd6 100644 (file)
@@ -14,5 +14,8 @@ others.
 References:
   [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
   [SSH-PARAMETERS]:    http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
+
+This package does not fall under the stability promise of the Go language itself,
+so its API may be changed when pressing needs arise.
 */
 package ssh // import "golang.org/x/crypto/ssh"
index 37d42e47f71bb93351f66b0efdfcd68c43276cd0..4f7912ecd6565bf1b5405b7c2a68d31493411697 100644 (file)
@@ -19,6 +19,11 @@ import (
 // messages are wrong when using ECDH.
 const debugHandshake = false
 
+// chanSize sets the amount of buffering SSH connections. This is
+// primarily for testing: setting chanSize=0 uncovers deadlocks more
+// quickly.
+const chanSize = 16
+
 // keyingTransport is a packet based transport that supports key
 // changes. It need not be thread-safe. It should pass through
 // msgNewKeys in both directions.
@@ -53,34 +58,65 @@ type handshakeTransport struct {
        incoming  chan []byte
        readError error
 
+       mu             sync.Mutex
+       writeError     error
+       sentInitPacket []byte
+       sentInitMsg    *kexInitMsg
+       pendingPackets [][]byte // Used when a key exchange is in progress.
+
+       // If the read loop wants to schedule a kex, it pings this
+       // channel, and the write loop will send out a kex
+       // message.
+       requestKex chan struct{}
+
+       // If the other side requests or confirms a kex, its kexInit
+       // packet is sent here for the write loop to find it.
+       startKex chan *pendingKex
+
        // data for host key checking
-       hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+       hostKeyCallback HostKeyCallback
        dialAddress     string
        remoteAddr      net.Addr
 
-       readSinceKex uint64
+       // bannerCallback is non-empty if we are the client and it has been set in
+       // ClientConfig. In that case it is called during the user authentication
+       // dance to handle a custom server's message.
+       bannerCallback BannerCallback
+
+       // Algorithms agreed in the last key exchange.
+       algorithms *algorithms
 
-       // Protects the writing side of the connection
-       mu              sync.Mutex
-       cond            *sync.Cond
-       sentInitPacket  []byte
-       sentInitMsg     *kexInitMsg
-       writtenSinceKex uint64
-       writeError      error
+       readPacketsLeft uint32
+       readBytesLeft   int64
+
+       writePacketsLeft uint32
+       writeBytesLeft   int64
 
        // The session ID or nil if first kex did not complete yet.
        sessionID []byte
 }
 
+type pendingKex struct {
+       otherInit []byte
+       done      chan error
+}
+
 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
        t := &handshakeTransport{
                conn:          conn,
                serverVersion: serverVersion,
                clientVersion: clientVersion,
-               incoming:      make(chan []byte, 16),
-               config:        config,
+               incoming:      make(chan []byte, chanSize),
+               requestKex:    make(chan struct{}, 1),
+               startKex:      make(chan *pendingKex, 1),
+
+               config: config,
        }
-       t.cond = sync.NewCond(&t.mu)
+       t.resetReadThresholds()
+       t.resetWriteThresholds()
+
+       // We always start with a mandatory key exchange.
+       t.requestKex <- struct{}{}
        return t
 }
 
@@ -89,12 +125,14 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
        t.dialAddress = dialAddr
        t.remoteAddr = addr
        t.hostKeyCallback = config.HostKeyCallback
+       t.bannerCallback = config.BannerCallback
        if config.HostKeyAlgorithms != nil {
                t.hostKeyAlgorithms = config.HostKeyAlgorithms
        } else {
                t.hostKeyAlgorithms = supportedHostKeyAlgos
        }
        go t.readLoop()
+       go t.kexLoop()
        return t
 }
 
@@ -102,6 +140,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
        t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
        t.hostKeys = config.hostKeys
        go t.readLoop()
+       go t.kexLoop()
        return t
 }
 
@@ -109,6 +148,20 @@ func (t *handshakeTransport) getSessionID() []byte {
        return t.sessionID
 }
 
+// waitSession waits for the session to be established. This should be
+// the first thing to call after instantiating handshakeTransport.
+func (t *handshakeTransport) waitSession() error {
+       p, err := t.readPacket()
+       if err != nil {
+               return err
+       }
+       if p[0] != msgNewKeys {
+               return fmt.Errorf("ssh: first packet should be msgNewKeys")
+       }
+
+       return nil
+}
+
 func (t *handshakeTransport) id() string {
        if len(t.hostKeys) > 0 {
                return "server"
@@ -116,6 +169,20 @@ func (t *handshakeTransport) id() string {
        return "client"
 }
 
+func (t *handshakeTransport) printPacket(p []byte, write bool) {
+       action := "got"
+       if write {
+               action = "sent"
+       }
+
+       if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
+               log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
+       } else {
+               msg, err := decode(p)
+               log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
+       }
+}
+
 func (t *handshakeTransport) readPacket() ([]byte, error) {
        p, ok := <-t.incoming
        if !ok {
@@ -125,8 +192,10 @@ func (t *handshakeTransport) readPacket() ([]byte, error) {
 }
 
 func (t *handshakeTransport) readLoop() {
+       first := true
        for {
-               p, err := t.readOnePacket()
+               p, err := t.readOnePacket(first)
+               first = false
                if err != nil {
                        t.readError = err
                        close(t.incoming)
@@ -138,67 +207,217 @@ func (t *handshakeTransport) readLoop() {
                t.incoming <- p
        }
 
-       // If we can't read, declare the writing part dead too.
+       // Stop writers too.
+       t.recordWriteError(t.readError)
+
+       // Unblock the writer should it wait for this.
+       close(t.startKex)
+
+       // Don't close t.requestKex; it's also written to from writePacket.
+}
+
+func (t *handshakeTransport) pushPacket(p []byte) error {
+       if debugHandshake {
+               t.printPacket(p, true)
+       }
+       return t.conn.writePacket(p)
+}
+
+func (t *handshakeTransport) getWriteError() error {
        t.mu.Lock()
        defer t.mu.Unlock()
-       if t.writeError == nil {
-               t.writeError = t.readError
+       return t.writeError
+}
+
+func (t *handshakeTransport) recordWriteError(err error) {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       if t.writeError == nil && err != nil {
+               t.writeError = err
        }
-       t.cond.Broadcast()
 }
 
-func (t *handshakeTransport) readOnePacket() ([]byte, error) {
-       if t.readSinceKex > t.config.RekeyThreshold {
-               if err := t.requestKeyChange(); err != nil {
-                       return nil, err
+func (t *handshakeTransport) requestKeyExchange() {
+       select {
+       case t.requestKex <- struct{}{}:
+       default:
+               // something already requested a kex, so do nothing.
+       }
+}
+
+func (t *handshakeTransport) resetWriteThresholds() {
+       t.writePacketsLeft = packetRekeyThreshold
+       if t.config.RekeyThreshold > 0 {
+               t.writeBytesLeft = int64(t.config.RekeyThreshold)
+       } else if t.algorithms != nil {
+               t.writeBytesLeft = t.algorithms.w.rekeyBytes()
+       } else {
+               t.writeBytesLeft = 1 << 30
+       }
+}
+
+func (t *handshakeTransport) kexLoop() {
+
+write:
+       for t.getWriteError() == nil {
+               var request *pendingKex
+               var sent bool
+
+               for request == nil || !sent {
+                       var ok bool
+                       select {
+                       case request, ok = <-t.startKex:
+                               if !ok {
+                                       break write
+                               }
+                       case <-t.requestKex:
+                               break
+                       }
+
+                       if !sent {
+                               if err := t.sendKexInit(); err != nil {
+                                       t.recordWriteError(err)
+                                       break
+                               }
+                               sent = true
+                       }
+               }
+
+               if err := t.getWriteError(); err != nil {
+                       if request != nil {
+                               request.done <- err
+                       }
+                       break
+               }
+
+               // We're not servicing t.requestKex, but that is OK:
+               // we never block on sending to t.requestKex.
+
+               // We're not servicing t.startKex, but the remote end
+               // has just sent us a kexInitMsg, so it can't send
+               // another key change request, until we close the done
+               // channel on the pendingKex request.
+
+               err := t.enterKeyExchange(request.otherInit)
+
+               t.mu.Lock()
+               t.writeError = err
+               t.sentInitPacket = nil
+               t.sentInitMsg = nil
+
+               t.resetWriteThresholds()
+
+               // we have completed the key exchange. Since the
+               // reader is still blocked, it is safe to clear out
+               // the requestKex channel. This avoids the situation
+               // where: 1) we consumed our own request for the
+               // initial kex, and 2) the kex from the remote side
+               // caused another send on the requestKex channel,
+       clear:
+               for {
+                       select {
+                       case <-t.requestKex:
+                               //
+                       default:
+                               break clear
+                       }
                }
+
+               request.done <- t.writeError
+
+               // kex finished. Push packets that we received while
+               // the kex was in progress. Don't look at t.startKex
+               // and don't increment writtenSinceKex: if we trigger
+               // another kex while we are still busy with the last
+               // one, things will become very confusing.
+               for _, p := range t.pendingPackets {
+                       t.writeError = t.pushPacket(p)
+                       if t.writeError != nil {
+                               break
+                       }
+               }
+               t.pendingPackets = t.pendingPackets[:0]
+               t.mu.Unlock()
        }
 
+       // drain startKex channel. We don't service t.requestKex
+       // because nobody does blocking sends there.
+       go func() {
+               for init := range t.startKex {
+                       init.done <- t.writeError
+               }
+       }()
+
+       // Unblock reader.
+       t.conn.Close()
+}
+
+// The protocol uses uint32 for packet counters, so we can't let them
+// reach 1<<32.  We will actually read and write more packets than
+// this, though: the other side may send more packets, and after we
+// hit this limit on writing we will send a few more packets for the
+// key exchange itself.
+const packetRekeyThreshold = (1 << 31)
+
+func (t *handshakeTransport) resetReadThresholds() {
+       t.readPacketsLeft = packetRekeyThreshold
+       if t.config.RekeyThreshold > 0 {
+               t.readBytesLeft = int64(t.config.RekeyThreshold)
+       } else if t.algorithms != nil {
+               t.readBytesLeft = t.algorithms.r.rekeyBytes()
+       } else {
+               t.readBytesLeft = 1 << 30
+       }
+}
+
+func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
        p, err := t.conn.readPacket()
        if err != nil {
                return nil, err
        }
 
-       t.readSinceKex += uint64(len(p))
+       if t.readPacketsLeft > 0 {
+               t.readPacketsLeft--
+       } else {
+               t.requestKeyExchange()
+       }
+
+       if t.readBytesLeft > 0 {
+               t.readBytesLeft -= int64(len(p))
+       } else {
+               t.requestKeyExchange()
+       }
+
        if debugHandshake {
-               if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
-                       log.Printf("%s got data (packet %d bytes)", t.id(), len(p))
-               } else {
-                       msg, err := decode(p)
-                       log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
-               }
+               t.printPacket(p, false)
        }
+
+       if first && p[0] != msgKexInit {
+               return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
+       }
+
        if p[0] != msgKexInit {
                return p, nil
        }
 
-       t.mu.Lock()
-
        firstKex := t.sessionID == nil
 
-       err = t.enterKeyExchangeLocked(p)
-       if err != nil {
-               // drop connection
-               t.conn.Close()
-               t.writeError = err
+       kex := pendingKex{
+               done:      make(chan error, 1),
+               otherInit: p,
        }
+       t.startKex <- &kex
+       err = <-kex.done
 
        if debugHandshake {
                log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
        }
 
-       // Unblock writers.
-       t.sentInitMsg = nil
-       t.sentInitPacket = nil
-       t.cond.Broadcast()
-       t.writtenSinceKex = 0
-       t.mu.Unlock()
-
        if err != nil {
                return nil, err
        }
 
-       t.readSinceKex = 0
+       t.resetReadThresholds()
 
        // By default, a key exchange is hidden from higher layers by
        // translating it into msgIgnore.
@@ -213,61 +432,16 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
        return successPacket, nil
 }
 
-// keyChangeCategory describes whether a key exchange is the first on a
-// connection, or a subsequent one.
-type keyChangeCategory bool
-
-const (
-       firstKeyExchange      keyChangeCategory = true
-       subsequentKeyExchange keyChangeCategory = false
-)
-
-// sendKexInit sends a key change message, and returns the message
-// that was sent. After initiating the key change, all writes will be
-// blocked until the change is done, and a failed key change will
-// close the underlying transport. This function is safe for
-// concurrent use by multiple goroutines.
-func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
-       var err error
-
+// sendKexInit sends a key change message.
+func (t *handshakeTransport) sendKexInit() error {
        t.mu.Lock()
-       // If this is the initial key change, but we already have a sessionID,
-       // then do nothing because the key exchange has already completed
-       // asynchronously.
-       if !isFirst || t.sessionID == nil {
-               _, _, err = t.sendKexInitLocked(isFirst)
-       }
-       t.mu.Unlock()
-       if err != nil {
-               return err
-       }
-       if isFirst {
-               if packet, err := t.readPacket(); err != nil {
-                       return err
-               } else if packet[0] != msgNewKeys {
-                       return unexpectedMessageError(msgNewKeys, packet[0])
-               }
-       }
-       return nil
-}
-
-func (t *handshakeTransport) requestInitialKeyChange() error {
-       return t.sendKexInit(firstKeyExchange)
-}
-
-func (t *handshakeTransport) requestKeyChange() error {
-       return t.sendKexInit(subsequentKeyExchange)
-}
-
-// sendKexInitLocked sends a key change message. t.mu must be locked
-// while this happens.
-func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
-       // kexInits may be sent either in response to the other side,
-       // or because our side wants to initiate a key change, so we
-       // may have already sent a kexInit. In that case, don't send a
-       // second kexInit.
+       defer t.mu.Unlock()
        if t.sentInitMsg != nil {
-               return t.sentInitMsg, t.sentInitPacket, nil
+               // kexInits may be sent either in response to the other side,
+               // or because our side wants to initiate a key change, so we
+               // may have already sent a kexInit. In that case, don't send a
+               // second kexInit.
+               return nil
        }
 
        msg := &kexInitMsg{
@@ -295,53 +469,65 @@ func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexI
        packetCopy := make([]byte, len(packet))
        copy(packetCopy, packet)
 
-       if err := t.conn.writePacket(packetCopy); err != nil {
-               return nil, nil, err
+       if err := t.pushPacket(packetCopy); err != nil {
+               return err
        }
 
        t.sentInitMsg = msg
        t.sentInitPacket = packet
-       return msg, packet, nil
+
+       return nil
 }
 
 func (t *handshakeTransport) writePacket(p []byte) error {
+       switch p[0] {
+       case msgKexInit:
+               return errors.New("ssh: only handshakeTransport can send kexInit")
+       case msgNewKeys:
+               return errors.New("ssh: only handshakeTransport can send newKeys")
+       }
+
        t.mu.Lock()
        defer t.mu.Unlock()
+       if t.writeError != nil {
+               return t.writeError
+       }
 
-       if t.writtenSinceKex > t.config.RekeyThreshold {
-               t.sendKexInitLocked(subsequentKeyExchange)
+       if t.sentInitMsg != nil {
+               // Copy the packet so the writer can reuse the buffer.
+               cp := make([]byte, len(p))
+               copy(cp, p)
+               t.pendingPackets = append(t.pendingPackets, cp)
+               return nil
        }
-       for t.sentInitMsg != nil && t.writeError == nil {
-               t.cond.Wait()
+
+       if t.writeBytesLeft > 0 {
+               t.writeBytesLeft -= int64(len(p))
+       } else {
+               t.requestKeyExchange()
        }
-       if t.writeError != nil {
-               return t.writeError
+
+       if t.writePacketsLeft > 0 {
+               t.writePacketsLeft--
+       } else {
+               t.requestKeyExchange()
        }
-       t.writtenSinceKex += uint64(len(p))
 
-       switch p[0] {
-       case msgKexInit:
-               return errors.New("ssh: only handshakeTransport can send kexInit")
-       case msgNewKeys:
-               return errors.New("ssh: only handshakeTransport can send newKeys")
-       default:
-               return t.conn.writePacket(p)
+       if err := t.pushPacket(p); err != nil {
+               t.writeError = err
        }
+
+       return nil
 }
 
 func (t *handshakeTransport) Close() error {
        return t.conn.Close()
 }
 
-// enterKeyExchange runs the key exchange. t.mu must be held while running this.
-func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
+func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
        if debugHandshake {
                log.Printf("%s entered key exchange", t.id())
        }
-       myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
-       if err != nil {
-               return err
-       }
 
        otherInit := &kexInitMsg{}
        if err := Unmarshal(otherInitPacket, otherInit); err != nil {
@@ -352,20 +538,20 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
                clientVersion: t.clientVersion,
                serverVersion: t.serverVersion,
                clientKexInit: otherInitPacket,
-               serverKexInit: myInitPacket,
+               serverKexInit: t.sentInitPacket,
        }
 
        clientInit := otherInit
-       serverInit := myInit
+       serverInit := t.sentInitMsg
        if len(t.hostKeys) == 0 {
-               clientInit = myInit
-               serverInit = otherInit
+               clientInit, serverInit = serverInit, clientInit
 
-               magics.clientKexInit = myInitPacket
+               magics.clientKexInit = t.sentInitPacket
                magics.serverKexInit = otherInitPacket
        }
 
-       algs, err := findAgreedAlgorithms(clientInit, serverInit)
+       var err error
+       t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
        if err != nil {
                return err
        }
@@ -388,16 +574,16 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
                }
        }
 
-       kex, ok := kexAlgoMap[algs.kex]
+       kex, ok := kexAlgoMap[t.algorithms.kex]
        if !ok {
-               return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
+               return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
        }
 
        var result *kexResult
        if len(t.hostKeys) > 0 {
-               result, err = t.server(kex, algs, &magics)
+               result, err = t.server(kex, t.algorithms, &magics)
        } else {
-               result, err = t.client(kex, algs, &magics)
+               result, err = t.client(kex, t.algorithms, &magics)
        }
 
        if err != nil {
@@ -409,7 +595,9 @@ func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) erro
        }
        result.SessionID = t.sessionID
 
-       t.conn.prepareKeyChange(algs, result)
+       if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
+               return err
+       }
        if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
                return err
        }
@@ -449,11 +637,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
                return nil, err
        }
 
-       if t.hostKeyCallback != nil {
-               err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
-               if err != nil {
-                       return nil, err
-               }
+       err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
+       if err != nil {
+               return nil, err
        }
 
        return result, nil
index c87fbebfde88a2be08fdd9440241cb1a3c4c145a..f91c2770edc2b03a5a0d041b9c763ef1d29cc815 100644 (file)
@@ -383,8 +383,8 @@ func init() {
        // 4253 and Oakley Group 2 in RFC 2409.
        p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
        kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
-               g: new(big.Int).SetInt64(2),
-               p: p,
+               g:       new(big.Int).SetInt64(2),
+               p:       p,
                pMinus1: new(big.Int).Sub(p, bigOne),
        }
 
@@ -393,8 +393,8 @@ func init() {
        p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
 
        kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
-               g: new(big.Int).SetInt64(2),
-               p: p,
+               g:       new(big.Int).SetInt64(2),
+               p:       p,
                pMinus1: new(big.Int).Sub(p, bigOne),
        }
 
index f2fc9b6c99d88d06c3991f355d4e3bd205a64bd7..b682c1741b73be40a02cb51e6202c0e12c73c166 100644 (file)
@@ -10,10 +10,13 @@ import (
        "crypto/dsa"
        "crypto/ecdsa"
        "crypto/elliptic"
+       "crypto/md5"
        "crypto/rsa"
+       "crypto/sha256"
        "crypto/x509"
        "encoding/asn1"
        "encoding/base64"
+       "encoding/hex"
        "encoding/pem"
        "errors"
        "fmt"
@@ -364,6 +367,17 @@ func (r *dsaPublicKey) Type() string {
        return "ssh-dss"
 }
 
+func checkDSAParams(param *dsa.Parameters) error {
+       // SSH specifies FIPS 186-2, which only provided a single size
+       // (1024 bits) DSA key. FIPS 186-3 allows for larger key
+       // sizes, which would confuse SSH.
+       if l := param.P.BitLen(); l != 1024 {
+               return fmt.Errorf("ssh: unsupported DSA key size %d", l)
+       }
+
+       return nil
+}
+
 // parseDSA parses an DSA key according to RFC 4253, section 6.6.
 func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
        var w struct {
@@ -374,13 +388,18 @@ func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
                return nil, nil, err
        }
 
+       param := dsa.Parameters{
+               P: w.P,
+               Q: w.Q,
+               G: w.G,
+       }
+       if err := checkDSAParams(&param); err != nil {
+               return nil, nil, err
+       }
+
        key := &dsaPublicKey{
-               Parameters: dsa.Parameters{
-                       P: w.P,
-                       Q: w.Q,
-                       G: w.G,
-               },
-               Y: w.Y,
+               Parameters: param,
+               Y:          w.Y,
        }
        return key, w.Rest, nil
 }
@@ -627,19 +646,28 @@ func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey {
 }
 
 // NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
-// *ecdsa.PrivateKey or any other crypto.Signer and returns a corresponding
-// Signer instance. ECDSA keys must use P-256, P-384 or P-521.
+// *ecdsa.PrivateKey or any other crypto.Signer and returns a
+// corresponding Signer instance. ECDSA keys must use P-256, P-384 or
+// P-521. DSA keys must use parameter size L1024N160.
 func NewSignerFromKey(key interface{}) (Signer, error) {
        switch key := key.(type) {
        case crypto.Signer:
                return NewSignerFromSigner(key)
        case *dsa.PrivateKey:
-               return &dsaPrivateKey{key}, nil
+               return newDSAPrivateKey(key)
        default:
                return nil, fmt.Errorf("ssh: unsupported key type %T", key)
        }
 }
 
+func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) {
+       if err := checkDSAParams(&key.PublicKey.Parameters); err != nil {
+               return nil, err
+       }
+
+       return &dsaPrivateKey{key}, nil
+}
+
 type wrappedSigner struct {
        signer crypto.Signer
        pubKey PublicKey
@@ -753,6 +781,18 @@ func ParsePrivateKey(pemBytes []byte) (Signer, error) {
        return NewSignerFromKey(key)
 }
 
+// ParsePrivateKeyWithPassphrase returns a Signer from a PEM encoded private
+// key and passphrase. It supports the same keys as
+// ParseRawPrivateKeyWithPassphrase.
+func ParsePrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (Signer, error) {
+       key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase)
+       if err != nil {
+               return nil, err
+       }
+
+       return NewSignerFromKey(key)
+}
+
 // encryptedBlock tells whether a private key is
 // encrypted by examining its Proc-Type header
 // for a mention of ENCRYPTED
@@ -787,6 +827,43 @@ func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
        }
 }
 
+// ParseRawPrivateKeyWithPassphrase returns a private key decrypted with
+// passphrase from a PEM encoded private key. If wrong passphrase, return
+// x509.IncorrectPasswordError.
+func ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (interface{}, error) {
+       block, _ := pem.Decode(pemBytes)
+       if block == nil {
+               return nil, errors.New("ssh: no key found")
+       }
+       buf := block.Bytes
+
+       if encryptedBlock(block) {
+               if x509.IsEncryptedPEMBlock(block) {
+                       var err error
+                       buf, err = x509.DecryptPEMBlock(block, passPhrase)
+                       if err != nil {
+                               if err == x509.IncorrectPasswordError {
+                                       return nil, err
+                               }
+                               return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
+                       }
+               }
+       }
+
+       switch block.Type {
+       case "RSA PRIVATE KEY":
+               return x509.ParsePKCS1PrivateKey(buf)
+       case "EC PRIVATE KEY":
+               return x509.ParseECPrivateKey(buf)
+       case "DSA PRIVATE KEY":
+               return ParseDSAPrivateKey(buf)
+       case "OPENSSH PRIVATE KEY":
+               return parseOpenSSHPrivateKey(buf)
+       default:
+               return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
+       }
+}
+
 // ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
 // specified by the OpenSSL DSA man page.
 func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
@@ -795,8 +872,8 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
                P       *big.Int
                Q       *big.Int
                G       *big.Int
-               Priv    *big.Int
                Pub     *big.Int
+               Priv    *big.Int
        }
        rest, err := asn1.Unmarshal(der, &k)
        if err != nil {
@@ -813,15 +890,15 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
                                Q: k.Q,
                                G: k.G,
                        },
-                       Y: k.Priv,
+                       Y: k.Pub,
                },
-               X: k.Pub,
+               X: k.Priv,
        }, nil
 }
 
 // Implemented based on the documentation at
 // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
-func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
+func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
        magic := append([]byte("openssh-key-v1"), 0)
        if !bytes.Equal(magic, key[0:len(magic)]) {
                return nil, errors.New("ssh: invalid openssh private key format")
@@ -841,14 +918,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
                return nil, err
        }
 
+       if w.KdfName != "none" || w.CipherName != "none" {
+               return nil, errors.New("ssh: cannot decode encrypted private keys")
+       }
+
        pk1 := struct {
                Check1  uint32
                Check2  uint32
                Keytype string
-               Pub     []byte
-               Priv    []byte
-               Comment string
-               Pad     []byte `ssh:"rest"`
+               Rest    []byte `ssh:"rest"`
        }{}
 
        if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
@@ -859,22 +937,95 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
                return nil, errors.New("ssh: checkint mismatch")
        }
 
-       // we only handle ed25519 keys currently
-       if pk1.Keytype != KeyAlgoED25519 {
-               return nil, errors.New("ssh: unhandled key type")
-       }
+       // we only handle ed25519 and rsa keys currently
+       switch pk1.Keytype {
+       case KeyAlgoRSA:
+               // https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
+               key := struct {
+                       N       *big.Int
+                       E       *big.Int
+                       D       *big.Int
+                       Iqmp    *big.Int
+                       P       *big.Int
+                       Q       *big.Int
+                       Comment string
+                       Pad     []byte `ssh:"rest"`
+               }{}
+
+               if err := Unmarshal(pk1.Rest, &key); err != nil {
+                       return nil, err
+               }
 
-       for i, b := range pk1.Pad {
-               if int(b) != i+1 {
-                       return nil, errors.New("ssh: padding not as expected")
+               for i, b := range key.Pad {
+                       if int(b) != i+1 {
+                               return nil, errors.New("ssh: padding not as expected")
+                       }
+               }
+
+               pk := &rsa.PrivateKey{
+                       PublicKey: rsa.PublicKey{
+                               N: key.N,
+                               E: int(key.E.Int64()),
+                       },
+                       D:      key.D,
+                       Primes: []*big.Int{key.P, key.Q},
                }
+
+               if err := pk.Validate(); err != nil {
+                       return nil, err
+               }
+
+               pk.Precompute()
+
+               return pk, nil
+       case KeyAlgoED25519:
+               key := struct {
+                       Pub     []byte
+                       Priv    []byte
+                       Comment string
+                       Pad     []byte `ssh:"rest"`
+               }{}
+
+               if err := Unmarshal(pk1.Rest, &key); err != nil {
+                       return nil, err
+               }
+
+               if len(key.Priv) != ed25519.PrivateKeySize {
+                       return nil, errors.New("ssh: private key unexpected length")
+               }
+
+               for i, b := range key.Pad {
+                       if int(b) != i+1 {
+                               return nil, errors.New("ssh: padding not as expected")
+                       }
+               }
+
+               pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
+               copy(pk, key.Priv)
+               return &pk, nil
+       default:
+               return nil, errors.New("ssh: unhandled key type")
        }
+}
 
-       if len(pk1.Priv) != ed25519.PrivateKeySize {
-               return nil, errors.New("ssh: private key unexpected length")
+// FingerprintLegacyMD5 returns the user presentation of the key's
+// fingerprint as described by RFC 4716 section 4.
+func FingerprintLegacyMD5(pubKey PublicKey) string {
+       md5sum := md5.Sum(pubKey.Marshal())
+       hexarray := make([]string, len(md5sum))
+       for i, c := range md5sum {
+               hexarray[i] = hex.EncodeToString([]byte{c})
        }
+       return strings.Join(hexarray, ":")
+}
 
-       pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
-       copy(pk, pk1.Priv)
-       return &pk, nil
+// FingerprintSHA256 returns the user presentation of the key's
+// fingerprint as unpadded base64 encoded sha256 hash.
+// This format was introduced from OpenSSH 6.8.
+// https://www.openssh.com/txt/release-6.8
+// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding)
+func FingerprintSHA256(pubKey PublicKey) string {
+       sha256sum := sha256.Sum256(pubKey.Marshal())
+       hash := base64.RawStdEncoding.EncodeToString(sha256sum[:])
+       return "SHA256:" + hash
 }
index 07744ad67138ac87ebfd4afe622c0d4ec33db237..c07a06285e66617febc7bac604e91ba10a967e34 100644 (file)
@@ -15,6 +15,7 @@ import (
 
 type macMode struct {
        keySize int
+       etm     bool
        new     func(key []byte) hash.Hash
 }
 
@@ -45,13 +46,16 @@ func (t truncatingMAC) Size() int {
 func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
 
 var macModes = map[string]*macMode{
-       "hmac-sha2-256": {32, func(key []byte) hash.Hash {
+       "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
                return hmac.New(sha256.New, key)
        }},
-       "hmac-sha1": {20, func(key []byte) hash.Hash {
+       "hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
+               return hmac.New(sha256.New, key)
+       }},
+       "hmac-sha1": {20, false, func(key []byte) hash.Hash {
                return hmac.New(sha1.New, key)
        }},
-       "hmac-sha1-96": {20, func(key []byte) hash.Hash {
+       "hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
                return truncatingMAC{12, hmac.New(sha1.New, key)}
        }},
 }
index e6ecd3afa5adc5843de6f7819f540d07e1ca8a67..c96e1bec591eade6d927bdd918cc9c7441335ac2 100644 (file)
@@ -23,10 +23,6 @@ const (
        msgUnimplemented = 3
        msgDebug         = 4
        msgNewKeys       = 21
-
-       // Standard authentication messages
-       msgUserAuthSuccess = 52
-       msgUserAuthBanner  = 53
 )
 
 // SSH messages:
@@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
        PartialSuccess bool
 }
 
+// See RFC 4252, section 5.1
+const msgUserAuthSuccess = 52
+
+// See RFC 4252, section 5.4
+const msgUserAuthBanner = 53
+
+type userAuthBannerMsg struct {
+       Message string `sshtype:"53"`
+       // unused, but required to allow message parsing
+       Language string
+}
+
 // See RFC 4256, section 3.2
 const msgUserAuthInfoRequest = 60
 const msgUserAuthInfoResponse = 61
index f3a3ddd782fd64d5fc1c3d1c394623668ecc3a03..27a527c106bfbd6fd3aad6f8a8d2ceb3300fa8e3 100644 (file)
@@ -116,9 +116,9 @@ func (m *mux) Wait() error {
 func newMux(p packetConn) *mux {
        m := &mux{
                conn:             p,
-               incomingChannels: make(chan NewChannel, 16),
+               incomingChannels: make(chan NewChannel, chanSize),
                globalResponses:  make(chan interface{}, 1),
-               incomingRequests: make(chan *Request, 16),
+               incomingRequests: make(chan *Request, chanSize),
                errCond:          newCond(),
        }
        if debugMux {
index 37df1b30252f655b4a90818303951d2494bc8d13..148d2cb245f4073edcaedafe40ba60df5cd4c49f 100644 (file)
@@ -10,26 +10,38 @@ import (
        "fmt"
        "io"
        "net"
+       "strings"
 )
 
 // The Permissions type holds fine-grained permissions that are
-// specific to a user or a specific authentication method for a
-// user. Permissions, except for "source-address", must be enforced in
-// the server application layer, after successful authentication. The
-// Permissions are passed on in ServerConn so a server implementation
-// can honor them.
+// specific to a user or a specific authentication method for a user.
+// The Permissions value for a successful authentication attempt is
+// available in ServerConn, so it can be used to pass information from
+// the user-authentication phase to the application layer.
 type Permissions struct {
-       // Critical options restrict default permissions. Common
-       // restrictions are "source-address" and "force-command". If
-       // the server cannot enforce the restriction, or does not
-       // recognize it, the user should not authenticate.
+       // CriticalOptions indicate restrictions to the default
+       // permissions, and are typically used in conjunction with
+       // user certificates. The standard for SSH certificates
+       // defines "force-command" (only allow the given command to
+       // execute) and "source-address" (only allow connections from
+       // the given address). The SSH package currently only enforces
+       // the "source-address" critical option. It is up to server
+       // implementations to enforce other critical options, such as
+       // "force-command", by checking them after the SSH handshake
+       // is successful. In general, SSH servers should reject
+       // connections that specify critical options that are unknown
+       // or not supported.
        CriticalOptions map[string]string
 
        // Extensions are extra functionality that the server may
-       // offer on authenticated connections. Common extensions are
-       // "permit-agent-forwarding", "permit-X11-forwarding". Lack of
-       // support for an extension does not preclude authenticating a
-       // user.
+       // offer on authenticated connections. Lack of support for an
+       // extension does not preclude authenticating a user. Common
+       // extensions are "permit-agent-forwarding",
+       // "permit-X11-forwarding". The Go SSH library currently does
+       // not act on any extension, and it is up to server
+       // implementations to honor them. Extensions can be used to
+       // pass data from the authentication callbacks to the server
+       // application layer.
        Extensions map[string]string
 }
 
@@ -44,13 +56,24 @@ type ServerConfig struct {
        // authenticating.
        NoClientAuth bool
 
+       // MaxAuthTries specifies the maximum number of authentication attempts
+       // permitted per connection. If set to a negative number, the number of
+       // attempts are unlimited. If set to zero, the number of attempts are limited
+       // to 6.
+       MaxAuthTries int
+
        // PasswordCallback, if non-nil, is called when a user
        // attempts to authenticate using a password.
        PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
 
-       // PublicKeyCallback, if non-nil, is called when a client attempts public
-       // key authentication. It must return true if the given public key is
-       // valid for the given user. For example, see CertChecker.Authenticate.
+       // PublicKeyCallback, if non-nil, is called when a client
+       // offers a public key for authentication. It must return a nil error
+       // if the given public key can be used to authenticate the
+       // given user. For example, see CertChecker.Authenticate. A
+       // call to this function does not guarantee that the key
+       // offered is in fact used to authenticate. To record any data
+       // depending on the public key, store it inside a
+       // Permissions.Extensions entry.
        PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
 
        // KeyboardInteractiveCallback, if non-nil, is called when
@@ -72,6 +95,10 @@ type ServerConfig struct {
        // Note that RFC 4253 section 4.2 requires that this string start with
        // "SSH-2.0-".
        ServerVersion string
+
+       // BannerCallback, if present, is called and the return string is sent to
+       // the client after key exchange completed but before authentication.
+       BannerCallback func(conn ConnMetadata) string
 }
 
 // AddHostKey adds a private key as a host key. If an existing host
@@ -142,6 +169,10 @@ type ServerConn struct {
 func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
        fullConf := *config
        fullConf.SetDefaults()
+       if fullConf.MaxAuthTries == 0 {
+               fullConf.MaxAuthTries = 6
+       }
+
        s := &connection{
                sshConn: sshConn{conn: c},
        }
@@ -188,7 +219,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
        tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
        s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
 
-       if err := s.transport.requestInitialKeyChange(); err != nil {
+       if err := s.transport.waitSession(); err != nil {
                return nil, err
        }
 
@@ -231,7 +262,7 @@ func isAcceptableAlgo(algo string) bool {
        return false
 }
 
-func checkSourceAddress(addr net.Addr, sourceAddr string) error {
+func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
        if addr == nil {
                return errors.New("ssh: no address known for client, but source-address match required")
        }
@@ -241,33 +272,71 @@ func checkSourceAddress(addr net.Addr, sourceAddr string) error {
                return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
        }
 
-       if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
-               if bytes.Equal(allowedIP, tcpAddr.IP) {
-                       return nil
-               }
-       } else {
-               _, ipNet, err := net.ParseCIDR(sourceAddr)
-               if err != nil {
-                       return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
-               }
+       for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
+               if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
+                       if allowedIP.Equal(tcpAddr.IP) {
+                               return nil
+                       }
+               } else {
+                       _, ipNet, err := net.ParseCIDR(sourceAddr)
+                       if err != nil {
+                               return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
+                       }
 
-               if ipNet.Contains(tcpAddr.IP) {
-                       return nil
+                       if ipNet.Contains(tcpAddr.IP) {
+                               return nil
+                       }
                }
        }
 
        return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
 }
 
+// ServerAuthError implements the error interface. It appends any authentication
+// errors that may occur, and is returned if all of the authentication methods
+// provided by the user failed to authenticate.
+type ServerAuthError struct {
+       // Errors contains authentication errors returned by the authentication
+       // callback methods.
+       Errors []error
+}
+
+func (l ServerAuthError) Error() string {
+       var errs []string
+       for _, err := range l.Errors {
+               errs = append(errs, err.Error())
+       }
+       return "[" + strings.Join(errs, ", ") + "]"
+}
+
 func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
-       var err error
+       sessionID := s.transport.getSessionID()
        var cache pubKeyCache
        var perms *Permissions
 
+       authFailures := 0
+       var authErrs []error
+
 userAuthLoop:
        for {
+               if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
+                       discMsg := &disconnectMsg{
+                               Reason:  2,
+                               Message: "too many authentication failures",
+                       }
+
+                       if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
+                               return nil, err
+                       }
+
+                       return nil, discMsg
+               }
+
                var userAuthReq userAuthRequestMsg
                if packet, err := s.transport.readPacket(); err != nil {
+                       if err == io.EOF {
+                               return nil, &ServerAuthError{Errors: authErrs}
+                       }
                        return nil, err
                } else if err = Unmarshal(packet, &userAuthReq); err != nil {
                        return nil, err
@@ -278,6 +347,19 @@ userAuthLoop:
                }
 
                s.user = userAuthReq.User
+
+               if authFailures == 0 && config.BannerCallback != nil {
+                       msg := config.BannerCallback(s)
+                       if msg != "" {
+                               bannerMsg := &userAuthBannerMsg{
+                                       Message: msg,
+                               }
+                               if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+                                       return nil, err
+                               }
+                       }
+               }
+
                perms = nil
                authErr := errors.New("no auth passed yet")
 
@@ -286,6 +368,11 @@ userAuthLoop:
                        if config.NoClientAuth {
                                authErr = nil
                        }
+
+                       // allow initial attempt of 'none' without penalty
+                       if authFailures == 0 {
+                               authFailures--
+                       }
                case "password":
                        if config.PasswordCallback == nil {
                                authErr = errors.New("ssh: password auth not configured")
@@ -357,6 +444,7 @@ userAuthLoop:
                        if isQuery {
                                // The client can query if the given public key
                                // would be okay.
+
                                if len(payload) > 0 {
                                        return nil, parseError(msgUserAuthRequest)
                                }
@@ -385,7 +473,7 @@ userAuthLoop:
                                if !isAcceptableAlgo(sig.Format) {
                                        break
                                }
-                               signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
+                               signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
 
                                if err := pubKey.Verify(signedData, sig); err != nil {
                                        return nil, err
@@ -398,6 +486,8 @@ userAuthLoop:
                        authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
                }
 
+               authErrs = append(authErrs, authErr)
+
                if config.AuthLogCallback != nil {
                        config.AuthLogCallback(s, userAuthReq.Method, authErr)
                }
@@ -406,6 +496,8 @@ userAuthLoop:
                        break userAuthLoop
                }
 
+               authFailures++
+
                var failureMsg userAuthFailureMsg
                if config.PasswordCallback != nil {
                        failureMsg.Methods = append(failureMsg.Methods, "password")
@@ -421,12 +513,12 @@ userAuthLoop:
                        return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
                }
 
-               if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+               if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
                        return nil, err
                }
        }
 
-       if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+       if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
                return nil, err
        }
        return perms, nil
index 17e2aa85c1f9a450a8596cf09de3fb9945bd2861..cc06e03f5c1a5804ca1ef1ff7371e5eaa1e2d102 100644 (file)
@@ -231,6 +231,26 @@ func (s *Session) RequestSubsystem(subsystem string) error {
        return err
 }
 
+// RFC 4254 Section 6.7.
+type ptyWindowChangeMsg struct {
+       Columns uint32
+       Rows    uint32
+       Width   uint32
+       Height  uint32
+}
+
+// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
+func (s *Session) WindowChange(h, w int) error {
+       req := ptyWindowChangeMsg{
+               Columns: uint32(w),
+               Rows:    uint32(h),
+               Width:   uint32(w * 8),
+               Height:  uint32(h * 8),
+       }
+       _, err := s.ch.SendRequest("window-change", false, Marshal(&req))
+       return err
+}
+
 // RFC 4254 Section 6.9.
 type signalMsg struct {
        Signal string
diff --git a/vendor/golang.org/x/crypto/ssh/streamlocal.go b/vendor/golang.org/x/crypto/ssh/streamlocal.go
new file mode 100644 (file)
index 0000000..a2dccc6
--- /dev/null
@@ -0,0 +1,115 @@
+package ssh
+
+import (
+       "errors"
+       "io"
+       "net"
+)
+
+// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "direct-streamlocal@openssh.com" string.
+//
+// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
+type streamLocalChannelOpenDirectMsg struct {
+       socketPath string
+       reserved0  string
+       reserved1  uint32
+}
+
+// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "forwarded-streamlocal@openssh.com" string.
+type forwardedStreamLocalPayload struct {
+       SocketPath string
+       Reserved0  string
+}
+
+// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
+// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
+type streamLocalChannelForwardMsg struct {
+       socketPath string
+}
+
+// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
+func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
+       m := streamLocalChannelForwardMsg{
+               socketPath,
+       }
+       // send message
+       ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
+       if err != nil {
+               return nil, err
+       }
+       if !ok {
+               return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
+       }
+       ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
+
+       return &unixListener{socketPath, c, ch}, nil
+}
+
+func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
+       msg := streamLocalChannelOpenDirectMsg{
+               socketPath: socketPath,
+       }
+       ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
+       if err != nil {
+               return nil, err
+       }
+       go DiscardRequests(in)
+       return ch, err
+}
+
+type unixListener struct {
+       socketPath string
+
+       conn *Client
+       in   <-chan forward
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *unixListener) Accept() (net.Conn, error) {
+       s, ok := <-l.in
+       if !ok {
+               return nil, io.EOF
+       }
+       ch, incoming, err := s.newCh.Accept()
+       if err != nil {
+               return nil, err
+       }
+       go DiscardRequests(incoming)
+
+       return &chanConn{
+               Channel: ch,
+               laddr: &net.UnixAddr{
+                       Name: l.socketPath,
+                       Net:  "unix",
+               },
+               raddr: &net.UnixAddr{
+                       Name: "@",
+                       Net:  "unix",
+               },
+       }, nil
+}
+
+// Close closes the listener.
+func (l *unixListener) Close() error {
+       // this also closes the listener.
+       l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
+       m := streamLocalChannelForwardMsg{
+               l.socketPath,
+       }
+       ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
+       if err == nil && !ok {
+               err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
+       }
+       return err
+}
+
+// Addr returns the listener's network address.
+func (l *unixListener) Addr() net.Addr {
+       return &net.UnixAddr{
+               Name: l.socketPath,
+               Net:  "unix",
+       }
+}
index 6151241ff08e3fd4d5b9ad6c66159ba044ced239..acf17175dffb1f504dc64f668e5b2cfe0a6ebb6a 100644 (file)
@@ -20,12 +20,20 @@ import (
 // addr. Incoming connections will be available by calling Accept on
 // the returned net.Listener. The listener must be serviced, or the
 // SSH connection may hang.
+// N must be "tcp", "tcp4", "tcp6", or "unix".
 func (c *Client) Listen(n, addr string) (net.Listener, error) {
-       laddr, err := net.ResolveTCPAddr(n, addr)
-       if err != nil {
-               return nil, err
+       switch n {
+       case "tcp", "tcp4", "tcp6":
+               laddr, err := net.ResolveTCPAddr(n, addr)
+               if err != nil {
+                       return nil, err
+               }
+               return c.ListenTCP(laddr)
+       case "unix":
+               return c.ListenUnix(addr)
+       default:
+               return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
        }
-       return c.ListenTCP(laddr)
 }
 
 // Automatic port allocation is broken with OpenSSH before 6.0. See
@@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
        }
 
        // Register this forward, using the port number we obtained.
-       ch := c.forwards.add(*laddr)
+       ch := c.forwards.add(laddr)
 
        return &tcpListener{laddr, c, ch}, nil
 }
@@ -131,7 +139,7 @@ type forwardList struct {
 // forwardEntry represents an established mapping of a laddr on a
 // remote ssh server to a channel connected to a tcpListener.
 type forwardEntry struct {
-       laddr net.TCPAddr
+       laddr net.Addr
        c     chan forward
 }
 
@@ -139,16 +147,16 @@ type forwardEntry struct {
 // arguments to add/remove/lookup should be address as specified in
 // the original forward-request.
 type forward struct {
-       newCh NewChannel   // the ssh client channel underlying this forward
-       raddr *net.TCPAddr // the raddr of the incoming connection
+       newCh NewChannel // the ssh client channel underlying this forward
+       raddr net.Addr   // the raddr of the incoming connection
 }
 
-func (l *forwardList) add(addr net.TCPAddr) chan forward {
+func (l *forwardList) add(addr net.Addr) chan forward {
        l.Lock()
        defer l.Unlock()
        f := forwardEntry{
-               addr,
-               make(chan forward, 1),
+               laddr: addr,
+               c:     make(chan forward, 1),
        }
        l.entries = append(l.entries, f)
        return f.c
@@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
 
 func (l *forwardList) handleChannels(in <-chan NewChannel) {
        for ch := range in {
-               var payload forwardedTCPPayload
-               if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
-                       ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
-                       continue
+               var (
+                       laddr net.Addr
+                       raddr net.Addr
+                       err   error
+               )
+               switch channelType := ch.ChannelType(); channelType {
+               case "forwarded-tcpip":
+                       var payload forwardedTCPPayload
+                       if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+                               ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
+                               continue
+                       }
+
+                       // RFC 4254 section 7.2 specifies that incoming
+                       // addresses should list the address, in string
+                       // format. It is implied that this should be an IP
+                       // address, as it would be impossible to connect to it
+                       // otherwise.
+                       laddr, err = parseTCPAddr(payload.Addr, payload.Port)
+                       if err != nil {
+                               ch.Reject(ConnectionFailed, err.Error())
+                               continue
+                       }
+                       raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
+                       if err != nil {
+                               ch.Reject(ConnectionFailed, err.Error())
+                               continue
+                       }
+
+               case "forwarded-streamlocal@openssh.com":
+                       var payload forwardedStreamLocalPayload
+                       if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+                               ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
+                               continue
+                       }
+                       laddr = &net.UnixAddr{
+                               Name: payload.SocketPath,
+                               Net:  "unix",
+                       }
+                       raddr = &net.UnixAddr{
+                               Name: "@",
+                               Net:  "unix",
+                       }
+               default:
+                       panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
                }
-
-               // RFC 4254 section 7.2 specifies that incoming
-               // addresses should list the address, in string
-               // format. It is implied that this should be an IP
-               // address, as it would be impossible to connect to it
-               // otherwise.
-               laddr, err := parseTCPAddr(payload.Addr, payload.Port)
-               if err != nil {
-                       ch.Reject(ConnectionFailed, err.Error())
-                       continue
-               }
-               raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
-               if err != nil {
-                       ch.Reject(ConnectionFailed, err.Error())
-                       continue
-               }
-
-               if ok := l.forward(*laddr, *raddr, ch); !ok {
+               if ok := l.forward(laddr, raddr, ch); !ok {
                        // Section 7.2, implementations MUST reject spurious incoming
                        // connections.
                        ch.Reject(Prohibited, "no forward for address")
                        continue
                }
+
        }
 }
 
 // remove removes the forward entry, and the channel feeding its
 // listener.
-func (l *forwardList) remove(addr net.TCPAddr) {
+func (l *forwardList) remove(addr net.Addr) {
        l.Lock()
        defer l.Unlock()
        for i, f := range l.entries {
-               if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
+               if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
                        l.entries = append(l.entries[:i], l.entries[i+1:]...)
                        close(f.c)
                        return
@@ -231,12 +264,12 @@ func (l *forwardList) closeAll() {
        l.entries = nil
 }
 
-func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
+func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
        l.Lock()
        defer l.Unlock()
        for _, f := range l.entries {
-               if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
-                       f.c <- forward{ch, &raddr}
+               if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
+                       f.c <- forward{newCh: ch, raddr: raddr}
                        return true
                }
        }
@@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
        }
        go DiscardRequests(incoming)
 
-       return &tcpChanConn{
+       return &chanConn{
                Channel: ch,
                laddr:   l.laddr,
                raddr:   s.raddr,
@@ -277,7 +310,7 @@ func (l *tcpListener) Close() error {
        }
 
        // this also closes the listener.
-       l.conn.forwards.remove(*l.laddr)
+       l.conn.forwards.remove(l.laddr)
        ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
        if err == nil && !ok {
                err = errors.New("ssh: cancel-tcpip-forward failed")
@@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr {
 // Dial initiates a connection to the addr from the remote host.
 // The resulting connection has a zero LocalAddr() and RemoteAddr().
 func (c *Client) Dial(n, addr string) (net.Conn, error) {
-       // Parse the address into host and numeric port.
-       host, portString, err := net.SplitHostPort(addr)
-       if err != nil {
-               return nil, err
-       }
-       port, err := strconv.ParseUint(portString, 10, 16)
-       if err != nil {
-               return nil, err
-       }
-       // Use a zero address for local and remote address.
-       zeroAddr := &net.TCPAddr{
-               IP:   net.IPv4zero,
-               Port: 0,
-       }
-       ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
-       if err != nil {
-               return nil, err
+       var ch Channel
+       switch n {
+       case "tcp", "tcp4", "tcp6":
+               // Parse the address into host and numeric port.
+               host, portString, err := net.SplitHostPort(addr)
+               if err != nil {
+                       return nil, err
+               }
+               port, err := strconv.ParseUint(portString, 10, 16)
+               if err != nil {
+                       return nil, err
+               }
+               ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
+               if err != nil {
+                       return nil, err
+               }
+               // Use a zero address for local and remote address.
+               zeroAddr := &net.TCPAddr{
+                       IP:   net.IPv4zero,
+                       Port: 0,
+               }
+               return &chanConn{
+                       Channel: ch,
+                       laddr:   zeroAddr,
+                       raddr:   zeroAddr,
+               }, nil
+       case "unix":
+               var err error
+               ch, err = c.dialStreamLocal(addr)
+               if err != nil {
+                       return nil, err
+               }
+               return &chanConn{
+                       Channel: ch,
+                       laddr: &net.UnixAddr{
+                               Name: "@",
+                               Net:  "unix",
+                       },
+                       raddr: &net.UnixAddr{
+                               Name: addr,
+                               Net:  "unix",
+                       },
+               }, nil
+       default:
+               return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
        }
-       return &tcpChanConn{
-               Channel: ch,
-               laddr:   zeroAddr,
-               raddr:   zeroAddr,
-       }, nil
 }
 
 // DialTCP connects to the remote address raddr on the network net,
@@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
        if err != nil {
                return nil, err
        }
-       return &tcpChanConn{
+       return &chanConn{
                Channel: ch,
                laddr:   laddr,
                raddr:   raddr,
@@ -366,26 +422,26 @@ type tcpChan struct {
        Channel // the backing channel
 }
 
-// tcpChanConn fulfills the net.Conn interface without
+// chanConn fulfills the net.Conn interface without
 // the tcpChan having to hold laddr or raddr directly.
-type tcpChanConn struct {
+type chanConn struct {
        Channel
        laddr, raddr net.Addr
 }
 
 // LocalAddr returns the local network address.
-func (t *tcpChanConn) LocalAddr() net.Addr {
+func (t *chanConn) LocalAddr() net.Addr {
        return t.laddr
 }
 
 // RemoteAddr returns the remote network address.
-func (t *tcpChanConn) RemoteAddr() net.Addr {
+func (t *chanConn) RemoteAddr() net.Addr {
        return t.raddr
 }
 
 // SetDeadline sets the read and write deadlines associated
 // with the connection.
-func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
+func (t *chanConn) SetDeadline(deadline time.Time) error {
        if err := t.SetReadDeadline(deadline); err != nil {
                return err
        }
@@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
 // A zero value for t means Read will not time out.
 // After the deadline, the error from Read will implement net.Error
 // with Timeout() == true.
-func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
+func (t *chanConn) SetReadDeadline(deadline time.Time) error {
+       // for compatibility with previous version,
+       // the error message contains "tcpChan"
        return errors.New("ssh: tcpChan: deadline not supported")
 }
 
 // SetWriteDeadline exists to satisfy the net.Conn interface
 // but is not implemented by this type.  It always returns an error.
-func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
+func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
        return errors.New("ssh: tcpChan: deadline not supported")
 }
index 62fba629ee06003a9fb9099f2b0c5f88633f94a5..ab2b88765a8a8827b88885aaf634d0d3f076dcab 100644 (file)
@@ -8,8 +8,13 @@ import (
        "bufio"
        "errors"
        "io"
+       "log"
 )
 
+// debugTransport if set, will print packet types as they go over the
+// wire. No message decoding is done, to minimize the impact on timing.
+const debugTransport = false
+
 const (
        gcmCipherID    = "aes128-gcm@openssh.com"
        aes128cbcID    = "aes128-cbc"
@@ -22,7 +27,9 @@ type packetConn interface {
        // Encrypt and send a packet of data to the remote peer.
        writePacket(packet []byte) error
 
-       // Read a packet from the connection
+       // Read a packet from the connection. The read is blocking,
+       // i.e. if error is nil, then the returned byte slice is
+       // always non-empty.
        readPacket() ([]byte, error)
 
        // Close closes the write-side of the connection.
@@ -38,7 +45,7 @@ type transport struct {
        bufReader *bufio.Reader
        bufWriter *bufio.Writer
        rand      io.Reader
-
+       isClient  bool
        io.Closer
 }
 
@@ -84,9 +91,38 @@ func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) err
        return nil
 }
 
+func (t *transport) printPacket(p []byte, write bool) {
+       if len(p) == 0 {
+               return
+       }
+       who := "server"
+       if t.isClient {
+               who = "client"
+       }
+       what := "read"
+       if write {
+               what = "write"
+       }
+
+       log.Println(what, who, p[0])
+}
+
 // Read and decrypt next packet.
-func (t *transport) readPacket() ([]byte, error) {
-       return t.reader.readPacket(t.bufReader)
+func (t *transport) readPacket() (p []byte, err error) {
+       for {
+               p, err = t.reader.readPacket(t.bufReader)
+               if err != nil {
+                       break
+               }
+               if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
+                       break
+               }
+       }
+       if debugTransport {
+               t.printPacket(p, false)
+       }
+
+       return p, err
 }
 
 func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
@@ -129,6 +165,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
 }
 
 func (t *transport) writePacket(packet []byte) error {
+       if debugTransport {
+               t.printPacket(packet, true)
+       }
        return t.writer.writePacket(t.bufWriter, t.rand, packet)
 }
 
@@ -169,6 +208,8 @@ func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transp
                },
                Closer: rwc,
        }
+       t.isClient = isClient
+
        if isClient {
                t.reader.dir = serverKeys
                t.writer.dir = clientKeys
@@ -213,7 +254,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
        iv, key, macKey := generateKeys(d, algs, kex)
 
        if algs.Cipher == gcmCipherID {
-               return newGCMCipher(iv, key, macKey)
+               return newGCMCipher(iv, key)
        }
 
        if algs.Cipher == aes128cbcID {
@@ -226,6 +267,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
 
        c := &streamPacketCipher{
                mac: macModes[algs.MAC].new(macKey),
+               etm: macModes[algs.MAC].etm,
        }
        c.macResult = make([]byte, c.mac.Size())
 
index a26f1f8d40a867a3d243116cbfb4a49885dc818d..b5bc882ecf7fdff0b944868f0ae5d9b3e77f3a85 100644 (file)
                        "revisionTime": "2016-09-14T08:04:27Z"
                },
                {
-                       "checksumSHA1": "dwOedwBJ1EIK9+S3t108Bx054Y8=",
+                       "checksumSHA1": "IQkUIOnvlf0tYloFx9mLaXSvXWQ=",
                        "path": "golang.org/x/crypto/curve25519",
-                       "revision": "9477e0b78b9ac3d0b03822fd95422e2fe07627cd",
-                       "revisionTime": "2016-10-31T15:37:30Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
-                       "checksumSHA1": "wGb//LjBPNxYHqk+dcLo7BjPXK8=",
+                       "checksumSHA1": "1hwn8cgg4EVXhCpJIqmMbzqnUo0=",
                        "path": "golang.org/x/crypto/ed25519",
-                       "revision": "9477e0b78b9ac3d0b03822fd95422e2fe07627cd",
-                       "revisionTime": "2016-10-31T15:37:30Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
                        "checksumSHA1": "LXFcVx8I587SnWmKycSDEq9yvK8=",
                        "path": "golang.org/x/crypto/ed25519/internal/edwards25519",
-                       "revision": "9477e0b78b9ac3d0b03822fd95422e2fe07627cd",
-                       "revisionTime": "2016-10-31T15:37:30Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
                        "checksumSHA1": "MCeXr2RNeiG1XG6V+er1OR0qyeo=",
                        "path": "golang.org/x/crypto/md4",
-                       "revision": "ede567c8e044a5913dad1d1af3696d9da953104c",
-                       "revisionTime": "2016-11-04T19:41:44Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
                        "checksumSHA1": "1MGpGDQqnUoRpv7VEcQrXOBydXE=",
                        "path": "golang.org/x/crypto/pbkdf2",
-                       "revision": "8e06e8ddd9629eb88639aba897641bff8031f1d3",
-                       "revisionTime": "2016-09-10T18:59:01Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
-                       "checksumSHA1": "LlElMHeTC34ng8eHzjvtUhAgrr8=",
+                       "checksumSHA1": "YXeyyvak2xbvsqj5MBHMzyG+22M=",
                        "path": "golang.org/x/crypto/ssh",
-                       "revision": "9477e0b78b9ac3d0b03822fd95422e2fe07627cd",
-                       "revisionTime": "2016-10-31T15:37:30Z"
+                       "revision": "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94",
+                       "revisionTime": "2017-09-21T17:41:56Z"
                },
                {
                        "checksumSHA1": "9jjO5GjLa0XF/nfWihF02RoH4qc=",