summaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/graceful/server.go48
-rw-r--r--modules/graceful/server_http.go8
-rw-r--r--modules/private/internal.go28
-rw-r--r--modules/proxyprotocol/conn.go506
-rw-r--r--modules/proxyprotocol/errors.go45
-rw-r--r--modules/proxyprotocol/listener.go47
-rw-r--r--modules/proxyprotocol/util.go15
-rw-r--r--modules/setting/setting.go82
-rw-r--r--modules/ssh/ssh_graceful.go2
9 files changed, 734 insertions, 47 deletions
diff --git a/modules/graceful/server.go b/modules/graceful/server.go
index 159a9879df..30a460a943 100644
--- a/modules/graceful/server.go
+++ b/modules/graceful/server.go
@@ -16,6 +16,7 @@ import (
"time"
"code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/proxyprotocol"
"code.gitea.io/gitea/modules/setting"
)
@@ -79,16 +80,27 @@ func NewServer(network, address, name string) *Server {
// ListenAndServe listens on the provided network address and then calls Serve
// to handle requests on incoming connections.
-func (srv *Server) ListenAndServe(serve ServeFunction) error {
+func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
go srv.awaitShutdown()
- l, err := GetListener(srv.network, srv.address)
+ listener, err := GetListener(srv.network, srv.address)
if err != nil {
log.Error("Unable to GetListener: %v", err)
return err
}
- srv.listener = newWrappedListener(l, srv)
+ // we need to wrap the listener to take account of our lifecycle
+ listener = newWrappedListener(listener, srv)
+
+ // Now we need to take account of ProxyProtocol settings...
+ if useProxyProtocol {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+ srv.listener = listener
srv.BeforeBegin(srv.network, srv.address)
@@ -97,22 +109,44 @@ func (srv *Server) ListenAndServe(serve ServeFunction) error {
// ListenAndServeTLSConfig listens on the provided network address and then calls
// Serve to handle requests on incoming TLS connections.
-func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error {
+func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
go srv.awaitShutdown()
if tlsConfig.MinVersion == 0 {
tlsConfig.MinVersion = tls.VersionTLS12
}
- l, err := GetListener(srv.network, srv.address)
+ listener, err := GetListener(srv.network, srv.address)
if err != nil {
log.Error("Unable to get Listener: %v", err)
return err
}
- wl := newWrappedListener(l, srv)
- srv.listener = tls.NewListener(wl, tlsConfig)
+ // we need to wrap the listener to take account of our lifecycle
+ listener = newWrappedListener(listener, srv)
+
+ // Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us
+ if useProxyProtocol && !proxyProtocolTLSBridging {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+
+ // Now handle the tls protocol
+ listener = tls.NewListener(listener, tlsConfig)
+
+ // Now if we're bridging then we need the proxy to tell us who we're bridging for...
+ if useProxyProtocol && proxyProtocolTLSBridging {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+ srv.listener = listener
srv.BeforeBegin(srv.network, srv.address)
return srv.Serve(serve)
diff --git a/modules/graceful/server_http.go b/modules/graceful/server_http.go
index f7b22ceb5e..8ab2bdf41f 100644
--- a/modules/graceful/server_http.go
+++ b/modules/graceful/server_http.go
@@ -28,14 +28,14 @@ func newHTTPServer(network, address, name string, handler http.Handler) (*Server
// HTTPListenAndServe listens on the provided network address and then calls Serve
// to handle requests on incoming connections.
-func HTTPListenAndServe(network, address, name string, handler http.Handler) error {
+func HTTPListenAndServe(network, address, name string, handler http.Handler, useProxyProtocol bool) error {
server, lHandler := newHTTPServer(network, address, name, handler)
- return server.ListenAndServe(lHandler)
+ return server.ListenAndServe(lHandler, useProxyProtocol)
}
// HTTPListenAndServeTLSConfig listens on the provided network address and then calls Serve
// to handle requests on incoming connections.
-func HTTPListenAndServeTLSConfig(network, address, name string, tlsConfig *tls.Config, handler http.Handler) error {
+func HTTPListenAndServeTLSConfig(network, address, name string, tlsConfig *tls.Config, handler http.Handler, useProxyProtocol, proxyProtocolTLSBridging bool) error {
server, lHandler := newHTTPServer(network, address, name, handler)
- return server.ListenAndServeTLSConfig(tlsConfig, lHandler)
+ return server.ListenAndServeTLSConfig(tlsConfig, lHandler, useProxyProtocol, proxyProtocolTLSBridging)
}
diff --git a/modules/private/internal.go b/modules/private/internal.go
index a77a990627..2ea516ba80 100644
--- a/modules/private/internal.go
+++ b/modules/private/internal.go
@@ -14,6 +14,7 @@ import (
"code.gitea.io/gitea/modules/httplib"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/proxyprotocol"
"code.gitea.io/gitea/modules/setting"
)
@@ -50,7 +51,32 @@ func newInternalRequest(ctx context.Context, url, method string) *httplib.Reques
req.SetTransport(&http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
- return d.DialContext(ctx, "unix", setting.HTTPAddr)
+ conn, err := d.DialContext(ctx, "unix", setting.HTTPAddr)
+ if err != nil {
+ return conn, err
+ }
+ if setting.LocalUseProxyProtocol {
+ if err = proxyprotocol.WriteLocalHeader(conn); err != nil {
+ _ = conn.Close()
+ return nil, err
+ }
+ }
+ return conn, err
+ },
+ })
+ } else if setting.LocalUseProxyProtocol {
+ req.SetTransport(&http.Transport{
+ DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, network, address)
+ if err != nil {
+ return conn, err
+ }
+ if err = proxyprotocol.WriteLocalHeader(conn); err != nil {
+ _ = conn.Close()
+ return nil, err
+ }
+ return conn, err
},
})
}
diff --git a/modules/proxyprotocol/conn.go b/modules/proxyprotocol/conn.go
new file mode 100644
index 0000000000..10333b204d
--- /dev/null
+++ b/modules/proxyprotocol/conn.go
@@ -0,0 +1,506 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package proxyprotocol
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "code.gitea.io/gitea/modules/log"
+)
+
+var (
+ // v1Prefix is the string we look for at the start of a connection
+ // to check if this connection is using the proxy protocol
+ v1Prefix = []byte("PROXY ")
+ v1PrefixLen = len(v1Prefix)
+ v2Prefix = []byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A")
+ v2PrefixLen = len(v2Prefix)
+)
+
+// Conn is used to wrap and underlying connection which is speaking the
+// Proxy Protocol. RemoteAddr() will return the address of the client
+// instead of the proxy address.
+type Conn struct {
+ bufReader *bufio.Reader
+ conn net.Conn
+ localAddr net.Addr
+ remoteAddr net.Addr
+ once sync.Once
+ proxyHeaderTimeout time.Duration
+ acceptUnknown bool
+}
+
+// NewConn is used to wrap a net.Conn speaking the proxy protocol into
+// a proxyprotocol.Conn
+func NewConn(conn net.Conn, timeout time.Duration) *Conn {
+ pConn := &Conn{
+ bufReader: bufio.NewReader(conn),
+ conn: conn,
+ proxyHeaderTimeout: timeout,
+ }
+ return pConn
+}
+
+// Read reads data from the connection.
+// It will initially read the proxy protocol header.
+// If there is an error parsing the header, it is returned and the socket is closed.
+func (p *Conn) Read(b []byte) (int, error) {
+ if err := p.readProxyHeaderOnce(); err != nil {
+ return 0, err
+ }
+ return p.bufReader.Read(b)
+}
+
+// ReadFrom reads data from a provided reader and copies it to the connection.
+func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
+ if err := p.readProxyHeaderOnce(); err != nil {
+ return 0, err
+ }
+ if rf, ok := p.conn.(io.ReaderFrom); ok {
+ return rf.ReadFrom(r)
+ }
+ return io.Copy(p.conn, r)
+}
+
+// WriteTo reads data from the connection and writes it to the writer.
+// It will initially read the proxy protocol header.
+// If there is an error parsing the header, it is returned and the socket is closed.
+func (p *Conn) WriteTo(w io.Writer) (int64, error) {
+ if err := p.readProxyHeaderOnce(); err != nil {
+ return 0, err
+ }
+ return p.bufReader.WriteTo(w)
+}
+
+// Write writes data to the connection.
+// Write can be made to time out and return an error after a fixed
+// time limit; see SetDeadline and SetWriteDeadline.
+func (p *Conn) Write(b []byte) (int, error) {
+ if err := p.readProxyHeaderOnce(); err != nil {
+ return 0, err
+ }
+ return p.conn.Write(b)
+}
+
+// Close closes the connection.
+// Any blocked Read or Write operations will be unblocked and return errors.
+func (p *Conn) Close() error {
+ return p.conn.Close()
+}
+
+// LocalAddr returns the local network address.
+func (p *Conn) LocalAddr() net.Addr {
+ _ = p.readProxyHeaderOnce()
+ if p.localAddr != nil {
+ return p.localAddr
+ }
+ return p.conn.LocalAddr()
+}
+
+// RemoteAddr returns the address of the client if the proxy
+// protocol is being used, otherwise just returns the address of
+// the socket peer. If there is an error parsing the header, the
+// address of the client is not returned, and the socket is closed.
+// One implication of this is that the call could block if the
+// client is slow. Using a Deadline is recommended if this is called
+// before Read()
+func (p *Conn) RemoteAddr() net.Addr {
+ _ = p.readProxyHeaderOnce()
+ if p.remoteAddr != nil {
+ return p.remoteAddr
+ }
+ return p.conn.RemoteAddr()
+}
+
+// SetDeadline sets the read and write deadlines associated
+// with the connection. It is equivalent to calling both
+// SetReadDeadline and SetWriteDeadline.
+//
+// A deadline is an absolute time after which I/O operations
+// fail instead of blocking. The deadline applies to all future
+// and pending I/O, not just the immediately following call to
+// Read or Write. After a deadline has been exceeded, the
+// connection can be refreshed by setting a deadline in the future.
+//
+// If the deadline is exceeded a call to Read or Write or to other
+// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
+// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
+// The error's Timeout method will return true, but note that there
+// are other possible errors for which the Timeout method will
+// return true even if the deadline has not been exceeded.
+//
+// An idle timeout can be implemented by repeatedly extending
+// the deadline after successful Read or Write calls.
+//
+// A zero value for t means I/O operations will not time out.
+func (p *Conn) SetDeadline(t time.Time) error {
+ return p.conn.SetDeadline(t)
+}
+
+// SetReadDeadline sets the deadline for future Read calls
+// and any currently-blocked Read call.
+// A zero value for t means Read will not time out.
+func (p *Conn) SetReadDeadline(t time.Time) error {
+ return p.conn.SetReadDeadline(t)
+}
+
+// SetWriteDeadline sets the deadline for future Write calls
+// and any currently-blocked Write call.
+// Even if write times out, it may return n > 0, indicating that
+// some of the data was successfully written.
+// A zero value for t means Write will not time out.
+func (p *Conn) SetWriteDeadline(t time.Time) error {
+ return p.conn.SetWriteDeadline(t)
+}
+
+// readProxyHeaderOnce will ensure that the proxy header has been read
+func (p *Conn) readProxyHeaderOnce() (err error) {
+ p.once.Do(func() {
+ if err = p.readProxyHeader(); err != nil && err != io.EOF {
+ log.Error("Failed to read proxy prefix: %v", err)
+ p.Close()
+ p.bufReader = bufio.NewReader(p.conn)
+ }
+ })
+ return err
+}
+
+func (p *Conn) readProxyHeader() error {
+ if p.proxyHeaderTimeout != 0 {
+ readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
+ _ = p.conn.SetReadDeadline(readDeadLine)
+ defer func() {
+ _ = p.conn.SetReadDeadline(time.Time{})
+ }()
+ }
+
+ inp, err := p.bufReader.Peek(v1PrefixLen)
+ if err != nil {
+ return err
+ }
+
+ if bytes.Equal(inp, v1Prefix) {
+ return p.readV1ProxyHeader()
+ }
+
+ inp, err = p.bufReader.Peek(v2PrefixLen)
+ if err != nil {
+ return err
+ }
+ if bytes.Equal(inp, v2Prefix) {
+ return p.readV2ProxyHeader()
+ }
+
+ return &ErrBadHeader{inp}
+}
+
+func (p *Conn) readV2ProxyHeader() error {
+ // The binary header format starts with a constant 12 bytes block containing the
+ // protocol signature :
+ //
+ // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
+ //
+ // Note that this block contains a null byte at the 5th position, so it must not
+ // be handled as a null-terminated string.
+
+ if _, err := p.bufReader.Discard(v2PrefixLen); err != nil {
+ // This shouldn't happen as we have already asserted that there should be enough in the buffer
+ return err
+ }
+
+ // The next byte (the 13th one) is the protocol version and command.
+ version, err := p.bufReader.ReadByte()
+ if err != nil {
+ return err
+ }
+
+ // The 14th byte contains the transport protocol and address family.otocol.
+ familyByte, err := p.bufReader.ReadByte()
+ if err != nil {
+ return err
+ }
+
+ // The 15th and 16th bytes is the address length in bytes in network endian order.
+ var addressLen uint16
+ if err := binary.Read(p.bufReader, binary.BigEndian, &addressLen); err != nil {
+ return err
+ }
+
+ // Now handle the version byte: (14th byte).
+ // The highest four bits contains the version. As of this specification, it must
+ // always be sent as \x2 and the receiver must only accept this value.
+ if version>>4 != 0x2 {
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+
+ // The lowest four bits represents the command :
+ switch version & 0xf {
+ case 0x0:
+ // - \x0 : LOCAL : the connection was established on purpose by the proxy
+ // without being relayed. The connection endpoints are the sender and the
+ // receiver. Such connections exist when the proxy sends health-checks to the
+ // server. The receiver must accept this connection as valid and must use the
+ // real connection endpoints and discard the protocol block including the
+ // family which is ignored.
+
+ // We therefore ignore the 14th, 15th and 16th bytes
+ p.remoteAddr = p.conn.LocalAddr()
+ p.localAddr = p.conn.RemoteAddr()
+ return nil
+ case 0x1:
+ // - \x1 : PROXY : the connection was established on behalf of another node,
+ // and reflects the original connection endpoints. The receiver must then use
+ // the information provided in the protocol block to get original the address.
+ default:
+ // - other values are unassigned and must not be emitted by senders. Receivers
+ // must drop connections presenting unexpected values here.
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+
+ // Now handle the familyByte byte: (15th byte).
+ // The highest 4 bits contain the address family, the lowest 4 bits contain the protocol
+
+ // The address family maps to the original socket family without necessarily
+ // matching the values internally used by the system. It may be one of :
+ //
+ // - 0x0 : AF_UNSPEC : the connection is forwarded for an unknown, unspecified
+ // or unsupported protocol. The sender should use this family when sending
+ // LOCAL commands or when dealing with unsupported protocol families. The
+ // receiver is free to accept the connection anyway and use the real endpoint
+ // addresses or to reject it. The receiver should ignore address information.
+ //
+ // - 0x1 : AF_INET : the forwarded connection uses the AF_INET address family
+ // (IPv4). The addresses are exactly 4 bytes each in network byte order,
+ // followed by transport protocol information (typically ports).
+ //
+ // - 0x2 : AF_INET6 : the forwarded connection uses the AF_INET6 address family
+ // (IPv6). The addresses are exactly 16 bytes each in network byte order,
+ // followed by transport protocol information (typically ports).
+ //
+ // - 0x3 : AF_UNIX : the forwarded connection uses the AF_UNIX address family
+ // (UNIX). The addresses are exactly 108 bytes each.
+ //
+ // - other values are unspecified and must not be emitted in version 2 of this
+ // protocol and must be rejected as invalid by receivers.
+
+ // The transport protocol is specified in the lowest 4 bits of the 14th byte :
+ //
+ // - 0x0 : UNSPEC : the connection is forwarded for an unknown, unspecified
+ // or unsupported protocol. The sender should use this family when sending
+ // LOCAL commands or when dealing with unsupported protocol families. The
+ // receiver is free to accept the connection anyway and use the real endpoint
+ // addresses or to reject it. The receiver should ignore address information.
+ //
+ // - 0x1 : STREAM : the forwarded connection uses a SOCK_STREAM protocol (eg:
+ // TCP or UNIX_STREAM). When used with AF_INET/AF_INET6 (TCP), the addresses
+ // are followed by the source and destination ports represented on 2 bytes
+ // each in network byte order.
+ //
+ // - 0x2 : DGRAM : the forwarded connection uses a SOCK_DGRAM protocol (eg:
+ // UDP or UNIX_DGRAM). When used with AF_INET/AF_INET6 (UDP), the addresses
+ // are followed by the source and destination ports represented on 2 bytes
+ // each in network byte order.
+ //
+ // - other values are unspecified and must not be emitted in version 2 of this
+ // protocol and must be rejected as invalid by receivers.
+
+ if familyByte>>4 == 0x0 || familyByte&0xf == 0x0 {
+ // - hi 0x0 : AF_UNSPEC : the connection is forwarded for an unknown address type
+ // or
+ // - lo 0x0 : UNSPEC : the connection is forwarded for an unspecified protocol
+ if !p.acceptUnknown {
+ p.conn.Close()
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+ p.remoteAddr = p.conn.LocalAddr()
+ p.localAddr = p.conn.RemoteAddr()
+ _, err = p.bufReader.Discard(int(addressLen))
+ return err
+ }
+
+ // other address or protocol
+ if (familyByte>>4) > 0x3 || (familyByte&0xf) > 0x2 {
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+
+ // Handle AF_UNIX addresses
+ if familyByte>>4 == 0x3 {
+ // - \x31 : UNIX stream : the forwarded connection uses SOCK_STREAM over the
+ // AF_UNIX protocol family. Address length is 2*108 = 216 bytes.
+ // - \x32 : UNIX datagram : the forwarded connection uses SOCK_DGRAM over the
+ // AF_UNIX protocol family. Address length is 2*108 = 216 bytes.
+ if addressLen != 216 {
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+ remoteName := make([]byte, 108)
+ localName := make([]byte, 108)
+ if _, err := p.bufReader.Read(remoteName); err != nil {
+ return err
+ }
+ if _, err := p.bufReader.Read(localName); err != nil {
+ return err
+ }
+ protocol := "unix"
+ if familyByte&0xf == 2 {
+ protocol = "unixgram"
+ }
+
+ p.remoteAddr = &net.UnixAddr{
+ Name: string(remoteName),
+ Net: protocol,
+ }
+ p.localAddr = &net.UnixAddr{
+ Name: string(localName),
+ Net: protocol,
+ }
+ return nil
+ }
+
+ var remoteIP []byte
+ var localIP []byte
+ var remotePort uint16
+ var localPort uint16
+
+ if familyByte>>4 == 0x1 {
+ // AF_INET
+ // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
+ // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
+ // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
+ // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
+ if addressLen != 12 {
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+
+ remoteIP = make([]byte, 4)
+ localIP = make([]byte, 4)
+ } else {
+ // AF_INET6
+ // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
+ // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
+ // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
+ // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
+ if addressLen != 36 {
+ return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))}
+ }
+
+ remoteIP = make([]byte, 16)
+ localIP = make([]byte, 16)
+ }
+
+ if _, err := p.bufReader.Read(remoteIP); err != nil {
+ return err
+ }
+ if _, err := p.bufReader.Read(localIP); err != nil {
+ return err
+ }
+ if err := binary.Read(p.bufReader, binary.BigEndian, &remotePort); err != nil {
+ return err
+ }
+ if err := binary.Read(p.bufReader, binary.BigEndian, &localPort); err != nil {
+ return err
+ }
+
+ if familyByte&0xf == 1 {
+ p.remoteAddr = &net.TCPAddr{
+ IP: remoteIP,
+ Port: int(remotePort),
+ }
+ p.localAddr = &net.TCPAddr{
+ IP: localIP,
+ Port: int(localPort),
+ }
+ } else {
+ p.remoteAddr = &net.UDPAddr{
+ IP: remoteIP,
+ Port: int(remotePort),
+ }
+ p.localAddr = &net.UDPAddr{
+ IP: localIP,
+ Port: int(localPort),
+ }
+ }
+ return nil
+}
+
+func (p *Conn) readV1ProxyHeader() error {
+ // Read until a newline
+ header, err := p.bufReader.ReadString('\n')
+ if err != nil {
+ p.conn.Close()
+ return err
+ }
+
+ if header[len(header)-2] != '\r' {
+ return &ErrBadHeader{[]byte(header)}
+ }
+
+ // Strip the carriage return and new line
+ header = header[:len(header)-2]
+
+ // Split on spaces, should be (PROXY <type> <remote addr> <local addr> <remote port> <local port>)
+ parts := strings.Split(header, " ")
+ if len(parts) < 2 {
+ p.conn.Close()
+ return &ErrBadHeader{[]byte(header)}
+ }
+
+ // Verify the type is known
+ switch parts[1] {
+ case "UNKNOWN":
+ if !p.acceptUnknown || len(parts) != 2 {
+ p.conn.Close()
+ return &ErrBadHeader{[]byte(header)}
+ }
+ p.remoteAddr = p.conn.LocalAddr()
+ p.localAddr = p.conn.RemoteAddr()
+ return nil
+ case "TCP4":
+ case "TCP6":
+ default:
+ p.conn.Close()
+ return &ErrBadAddressType{parts[1]}
+ }
+
+ if len(parts) != 6 {
+ p.conn.Close()
+ return &ErrBadHeader{[]byte(header)}
+ }
+
+ // Parse out the remote address
+ ip := net.ParseIP(parts[2])
+ if ip == nil {
+ p.conn.Close()
+ return &ErrBadRemote{parts[2], parts[4]}
+ }
+ port, err := strconv.Atoi(parts[4])
+ if err != nil {
+ p.conn.Close()
+ return &ErrBadRemote{parts[2], parts[4]}
+ }
+ p.remoteAddr = &net.TCPAddr{IP: ip, Port: port}
+
+ // Parse out the destination address
+ ip = net.ParseIP(parts[3])
+ if ip == nil {
+ p.conn.Close()
+ return &ErrBadLocal{parts[3], parts[5]}
+ }
+ port, err = strconv.Atoi(parts[5])
+ if err != nil {
+ p.conn.Close()
+ return &ErrBadLocal{parts[3], parts[5]}
+ }
+ p.localAddr = &net.TCPAddr{IP: ip, Port: port}
+
+ return nil
+}
diff --git a/modules/proxyprotocol/errors.go b/modules/proxyprotocol/errors.go
new file mode 100644
index 0000000000..2acf9d84b0
--- /dev/null
+++ b/modules/proxyprotocol/errors.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package proxyprotocol
+
+import "fmt"
+
+// ErrBadHeader is an error demonstrating a bad proxy header
+type ErrBadHeader struct {
+ Header []byte
+}
+
+func (e *ErrBadHeader) Error() string {
+ return fmt.Sprintf("Unexpected proxy header: %v", e.Header)
+}
+
+// ErrBadAddressType is an error demonstrating a bad proxy header with bad Address type
+type ErrBadAddressType struct {
+ Address string
+}
+
+func (e *ErrBadAddressType) Error() string {
+ return fmt.Sprintf("Unexpected proxy header address type: %s", e.Address)
+}
+
+// ErrBadRemote is an error demonstrating a bad proxy header with bad Remote
+type ErrBadRemote struct {
+ IP string
+ Port string
+}
+
+func (e *ErrBadRemote) Error() string {
+ return fmt.Sprintf("Unexpected proxy header remote IP and port: %s %s", e.IP, e.Port)
+}
+
+// ErrBadLocal is an error demonstrating a bad proxy header with bad Local
+type ErrBadLocal struct {
+ IP string
+ Port string
+}
+
+func (e *ErrBadLocal) Error() string {
+ return fmt.Sprintf("Unexpected proxy header local IP and port: %s %s", e.IP, e.Port)
+}
diff --git a/modules/proxyprotocol/listener.go b/modules/proxyprotocol/listener.go
new file mode 100644
index 0000000000..64d9b323e5
--- /dev/null
+++ b/modules/proxyprotocol/listener.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package proxyprotocol
+
+import (
+ "net"
+ "time"
+)
+
+// Listener is used to wrap an underlying listener,
+// whose connections may be using the HAProxy Proxy Protocol (version 1 or 2).
+// If the connection is using the protocol, the RemoteAddr() will return
+// the correct client address.
+//
+// Optionally define ProxyHeaderTimeout to set a maximum time to
+// receive the Proxy Protocol Header. Zero means no timeout.
+type Listener struct {
+ Listener net.Listener
+ ProxyHeaderTimeout time.Duration
+ AcceptUnknown bool // allow PROXY UNKNOWN
+}
+
+// Accept implements the Accept method in the Listener interface
+// it waits for the next call and returns a wrapped Conn.
+func (p *Listener) Accept() (net.Conn, error) {
+ // Get the underlying connection
+ conn, err := p.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+
+ newConn := NewConn(conn, p.ProxyHeaderTimeout)
+ newConn.acceptUnknown = p.AcceptUnknown
+ return newConn, nil
+}
+
+// Close closes the underlying listener.
+func (p *Listener) Close() error {
+ return p.Listener.Close()
+}
+
+// Addr returns the underlying listener's network address.
+func (p *Listener) Addr() net.Addr {
+ return p.Listener.Addr()
+}
diff --git a/modules/proxyprotocol/util.go b/modules/proxyprotocol/util.go
new file mode 100644
index 0000000000..b12771b686
--- /dev/null
+++ b/modules/proxyprotocol/util.go
@@ -0,0 +1,15 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package proxyprotocol
+
+import "io"
+
+var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00', '\x00')
+
+// WriteLocalHeader will write the ProxyProtocol Header for a local connection to the provided writer
+func WriteLocalHeader(w io.Writer) error {
+ _, err := w.Write(localHeader)
+ return err
+}
diff --git a/modules/setting/setting.go b/modules/setting/setting.go
index 3ab25fef6e..931b6523ea 100644
--- a/modules/setting/setting.go
+++ b/modules/setting/setting.go
@@ -94,45 +94,52 @@ var (
LocalURL string
// Server settings
- Protocol Scheme
- Domain string
- HTTPAddr string
- HTTPPort string
- RedirectOtherPort bool
- PortToRedirect string
- OfflineMode bool
- CertFile string
- KeyFile string
- StaticRootPath string
- StaticCacheTime time.Duration
- EnableGzip bool
- LandingPageURL LandingPage
- LandingPageCustom string
- UnixSocketPermission uint32
- EnablePprof bool
- PprofDataPath string
- EnableAcme bool
- AcmeTOS bool
- AcmeLiveDirectory string
- AcmeEmail string
- AcmeURL string
- AcmeCARoot string
- SSLMinimumVersion string
- SSLMaximumVersion string
- SSLCurvePreferences []string
- SSLCipherSuites []string
- GracefulRestartable bool
- GracefulHammerTime time.Duration
- StartupTimeout time.Duration
- PerWriteTimeout = 30 * time.Second
- PerWritePerKbTimeout = 10 * time.Second
- StaticURLPrefix string
- AbsoluteAssetURL string
+ Protocol Scheme
+ UseProxyProtocol bool // `ini:"USE_PROXY_PROTOCOL"`
+ ProxyProtocolTLSBridging bool //`ini:"PROXY_PROTOCOL_TLS_BRIDGING"`
+ ProxyProtocolHeaderTimeout time.Duration
+ ProxyProtocolAcceptUnknown bool
+ Domain string
+ HTTPAddr string
+ HTTPPort string
+ LocalUseProxyProtocol bool
+ RedirectOtherPort bool
+ RedirectorUseProxyProtocol bool
+ PortToRedirect string
+ OfflineMode bool
+ CertFile string
+ KeyFile string
+ StaticRootPath string
+ StaticCacheTime time.Duration
+ EnableGzip bool
+ LandingPageURL LandingPage
+ LandingPageCustom string
+ UnixSocketPermission uint32
+ EnablePprof bool
+ PprofDataPath string
+ EnableAcme bool
+ AcmeTOS bool
+ AcmeLiveDirectory string
+ AcmeEmail string
+ AcmeURL string
+ AcmeCARoot string
+ SSLMinimumVersion string
+ SSLMaximumVersion string
+ SSLCurvePreferences []string
+ SSLCipherSuites []string
+ GracefulRestartable bool
+ GracefulHammerTime time.Duration
+ StartupTimeout time.Duration
+ PerWriteTimeout = 30 * time.Second
+ PerWritePerKbTimeout = 10 * time.Second
+ StaticURLPrefix string
+ AbsoluteAssetURL string
SSH = struct {
Disabled bool `ini:"DISABLE_SSH"`
StartBuiltinServer bool `ini:"START_SSH_SERVER"`
BuiltinServerUser string `ini:"BUILTIN_SSH_SERVER_USER"`
+ UseProxyProtocol bool `ini:"SSH_SERVER_USE_PROXY_PROTOCOL"`
Domain string `ini:"SSH_DOMAIN"`
Port int `ini:"SSH_PORT"`
User string `ini:"SSH_USER"`
@@ -717,6 +724,10 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
HTTPAddr = filepath.Join(AppWorkPath, HTTPAddr)
}
}
+ UseProxyProtocol = sec.Key("USE_PROXY_PROTOCOL").MustBool(false)
+ ProxyProtocolTLSBridging = sec.Key("PROXY_PROTOCOL_TLS_BRIDGING").MustBool(false)
+ ProxyProtocolHeaderTimeout = sec.Key("PROXY_PROTOCOL_HEADER_TIMEOUT").MustDuration(5 * time.Second)
+ ProxyProtocolAcceptUnknown = sec.Key("PROXY_PROTOCOL_ACCEPT_UNKNOWN").MustBool(false)
GracefulRestartable = sec.Key("ALLOW_GRACEFUL_RESTARTS").MustBool(true)
GracefulHammerTime = sec.Key("GRACEFUL_HAMMER_TIME").MustDuration(60 * time.Second)
StartupTimeout = sec.Key("STARTUP_TIMEOUT").MustDuration(0 * time.Second)
@@ -770,8 +781,10 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
}
LocalURL = sec.Key("LOCAL_ROOT_URL").MustString(defaultLocalURL)
LocalURL = strings.TrimRight(LocalURL, "/") + "/"
+ LocalUseProxyProtocol = sec.Key("LOCAL_USE_PROXY_PROTOCOL").MustBool(UseProxyProtocol)
RedirectOtherPort = sec.Key("REDIRECT_OTHER_PORT").MustBool(false)
PortToRedirect = sec.Key("PORT_TO_REDIRECT").MustString("80")
+ RedirectorUseProxyProtocol = sec.Key("REDIRECTOR_USE_PROXY_PROTOCOL").MustBool(UseProxyProtocol)
OfflineMode = sec.Key("OFFLINE_MODE").MustBool()
DisableRouterLog = sec.Key("DISABLE_ROUTER_LOG").MustBool()
if len(StaticRootPath) == 0 {
@@ -836,6 +849,7 @@ func loadFromConf(allowEmpty bool, extraConfig string) {
SSH.KeygenPath = sec.Key("SSH_KEYGEN_PATH").MustString("ssh-keygen")
SSH.Port = sec.Key("SSH_PORT").MustInt(22)
SSH.ListenPort = sec.Key("SSH_LISTEN_PORT").MustInt(SSH.Port)
+ SSH.UseProxyProtocol = sec.Key("SSH_SERVER_USE_PROXY_PROTOCOL").MustBool(false)
// When disable SSH, start builtin server value is ignored.
if SSH.Disabled {
diff --git a/modules/ssh/ssh_graceful.go b/modules/ssh/ssh_graceful.go
index 9b91baf09e..166ea0b982 100644
--- a/modules/ssh/ssh_graceful.go
+++ b/modules/ssh/ssh_graceful.go
@@ -17,7 +17,7 @@ func listen(server *ssh.Server) {
gracefulServer.PerWriteTimeout = setting.SSH.PerWriteTimeout
gracefulServer.PerWritePerKbTimeout = setting.SSH.PerWritePerKbTimeout
- err := gracefulServer.ListenAndServe(server.Serve)
+ err := gracefulServer.ListenAndServe(server.Serve, setting.SSH.UseProxyProtocol)
if err != nil {
select {
case <-graceful.GetManager().IsShutdown():