diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/graceful/server.go | 48 | ||||
-rw-r--r-- | modules/graceful/server_http.go | 8 | ||||
-rw-r--r-- | modules/private/internal.go | 28 | ||||
-rw-r--r-- | modules/proxyprotocol/conn.go | 506 | ||||
-rw-r--r-- | modules/proxyprotocol/errors.go | 45 | ||||
-rw-r--r-- | modules/proxyprotocol/listener.go | 47 | ||||
-rw-r--r-- | modules/proxyprotocol/util.go | 15 | ||||
-rw-r--r-- | modules/setting/setting.go | 82 | ||||
-rw-r--r-- | modules/ssh/ssh_graceful.go | 2 |
9 files changed, 734 insertions, 47 deletions
diff --git a/modules/graceful/server.go b/modules/graceful/server.go index 159a9879df..30a460a943 100644 --- a/modules/graceful/server.go +++ b/modules/graceful/server.go @@ -16,6 +16,7 @@ import ( "time" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/proxyprotocol" "code.gitea.io/gitea/modules/setting" ) @@ -79,16 +80,27 @@ func NewServer(network, address, name string) *Server { // ListenAndServe listens on the provided network address and then calls Serve // to handle requests on incoming connections. -func (srv *Server) ListenAndServe(serve ServeFunction) error { +func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error { go srv.awaitShutdown() - l, err := GetListener(srv.network, srv.address) + listener, err := GetListener(srv.network, srv.address) if err != nil { log.Error("Unable to GetListener: %v", err) return err } - srv.listener = newWrappedListener(l, srv) + // we need to wrap the listener to take account of our lifecycle + listener = newWrappedListener(listener, srv) + + // Now we need to take account of ProxyProtocol settings... + if useProxyProtocol { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + srv.listener = listener srv.BeforeBegin(srv.network, srv.address) @@ -97,22 +109,44 @@ func (srv *Server) ListenAndServe(serve ServeFunction) error { // ListenAndServeTLSConfig listens on the provided network address and then calls // Serve to handle requests on incoming TLS connections. -func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error { +func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error { go srv.awaitShutdown() if tlsConfig.MinVersion == 0 { tlsConfig.MinVersion = tls.VersionTLS12 } - l, err := GetListener(srv.network, srv.address) + listener, err := GetListener(srv.network, srv.address) if err != nil { log.Error("Unable to get Listener: %v", err) return err } - wl := newWrappedListener(l, srv) - srv.listener = tls.NewListener(wl, tlsConfig) + // we need to wrap the listener to take account of our lifecycle + listener = newWrappedListener(listener, srv) + + // Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us + if useProxyProtocol && !proxyProtocolTLSBridging { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + + // Now handle the tls protocol + listener = tls.NewListener(listener, tlsConfig) + + // Now if we're bridging then we need the proxy to tell us who we're bridging for... + if useProxyProtocol && proxyProtocolTLSBridging { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + srv.listener = listener srv.BeforeBegin(srv.network, srv.address) return srv.Serve(serve) diff --git a/modules/graceful/server_http.go b/modules/graceful/server_http.go index f7b22ceb5e..8ab2bdf41f 100644 --- a/modules/graceful/server_http.go +++ b/modules/graceful/server_http.go @@ -28,14 +28,14 @@ func newHTTPServer(network, address, name string, handler http.Handler) (*Server // HTTPListenAndServe listens on the provided network address and then calls Serve // to handle requests on incoming connections. -func HTTPListenAndServe(network, address, name string, handler http.Handler) error { +func HTTPListenAndServe(network, address, name string, handler http.Handler, useProxyProtocol bool) error { server, lHandler := newHTTPServer(network, address, name, handler) - return server.ListenAndServe(lHandler) + return server.ListenAndServe(lHandler, useProxyProtocol) } // HTTPListenAndServeTLSConfig listens on the provided network address and then calls Serve // to handle requests on incoming connections. -func HTTPListenAndServeTLSConfig(network, address, name string, tlsConfig *tls.Config, handler http.Handler) error { +func HTTPListenAndServeTLSConfig(network, address, name string, tlsConfig *tls.Config, handler http.Handler, useProxyProtocol, proxyProtocolTLSBridging bool) error { server, lHandler := newHTTPServer(network, address, name, handler) - return server.ListenAndServeTLSConfig(tlsConfig, lHandler) + return server.ListenAndServeTLSConfig(tlsConfig, lHandler, useProxyProtocol, proxyProtocolTLSBridging) } diff --git a/modules/private/internal.go b/modules/private/internal.go index a77a990627..2ea516ba80 100644 --- a/modules/private/internal.go +++ b/modules/private/internal.go @@ -14,6 +14,7 @@ import ( "code.gitea.io/gitea/modules/httplib" "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/proxyprotocol" "code.gitea.io/gitea/modules/setting" ) @@ -50,7 +51,32 @@ func newInternalRequest(ctx context.Context, url, method string) *httplib.Reques req.SetTransport(&http.Transport{ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { var d net.Dialer - return d.DialContext(ctx, "unix", setting.HTTPAddr) + conn, err := d.DialContext(ctx, "unix", setting.HTTPAddr) + if err != nil { + return conn, err + } + if setting.LocalUseProxyProtocol { + if err = proxyprotocol.WriteLocalHeader(conn); err != nil { + _ = conn.Close() + return nil, err + } + } + return conn, err + }, + }) + } else if setting.LocalUseProxyProtocol { + req.SetTransport(&http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, address) + if err != nil { + return conn, err + } + if err = proxyprotocol.WriteLocalHeader(conn); err != nil { + _ = conn.Close() + return nil, err + } + return conn, err }, }) } diff --git a/modules/proxyprotocol/conn.go b/modules/proxyprotocol/conn.go new file mode 100644 index 0000000000..10333b204d --- /dev/null +++ b/modules/proxyprotocol/conn.go @@ -0,0 +1,506 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package proxyprotocol + +import ( + "bufio" + "bytes" + "encoding/binary" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + + "code.gitea.io/gitea/modules/log" +) + +var ( + // v1Prefix is the string we look for at the start of a connection + // to check if this connection is using the proxy protocol + v1Prefix = []byte("PROXY ") + v1PrefixLen = len(v1Prefix) + v2Prefix = []byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A") + v2PrefixLen = len(v2Prefix) +) + +// Conn is used to wrap and underlying connection which is speaking the +// Proxy Protocol. RemoteAddr() will return the address of the client +// instead of the proxy address. +type Conn struct { + bufReader *bufio.Reader + conn net.Conn + localAddr net.Addr + remoteAddr net.Addr + once sync.Once + proxyHeaderTimeout time.Duration + acceptUnknown bool +} + +// NewConn is used to wrap a net.Conn speaking the proxy protocol into +// a proxyprotocol.Conn +func NewConn(conn net.Conn, timeout time.Duration) *Conn { + pConn := &Conn{ + bufReader: bufio.NewReader(conn), + conn: conn, + proxyHeaderTimeout: timeout, + } + return pConn +} + +// Read reads data from the connection. +// It will initially read the proxy protocol header. +// If there is an error parsing the header, it is returned and the socket is closed. +func (p *Conn) Read(b []byte) (int, error) { + if err := p.readProxyHeaderOnce(); err != nil { + return 0, err + } + return p.bufReader.Read(b) +} + +// ReadFrom reads data from a provided reader and copies it to the connection. +func (p *Conn) ReadFrom(r io.Reader) (int64, error) { + if err := p.readProxyHeaderOnce(); err != nil { + return 0, err + } + if rf, ok := p.conn.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(p.conn, r) +} + +// WriteTo reads data from the connection and writes it to the writer. +// It will initially read the proxy protocol header. +// If there is an error parsing the header, it is returned and the socket is closed. +func (p *Conn) WriteTo(w io.Writer) (int64, error) { + if err := p.readProxyHeaderOnce(); err != nil { + return 0, err + } + return p.bufReader.WriteTo(w) +} + +// Write writes data to the connection. +// Write can be made to time out and return an error after a fixed +// time limit; see SetDeadline and SetWriteDeadline. +func (p *Conn) Write(b []byte) (int, error) { + if err := p.readProxyHeaderOnce(); err != nil { + return 0, err + } + return p.conn.Write(b) +} + +// Close closes the connection. +// Any blocked Read or Write operations will be unblocked and return errors. +func (p *Conn) Close() error { + return p.conn.Close() +} + +// LocalAddr returns the local network address. +func (p *Conn) LocalAddr() net.Addr { + _ = p.readProxyHeaderOnce() + if p.localAddr != nil { + return p.localAddr + } + return p.conn.LocalAddr() +} + +// RemoteAddr returns the address of the client if the proxy +// protocol is being used, otherwise just returns the address of +// the socket peer. If there is an error parsing the header, the +// address of the client is not returned, and the socket is closed. +// One implication of this is that the call could block if the +// client is slow. Using a Deadline is recommended if this is called +// before Read() +func (p *Conn) RemoteAddr() net.Addr { + _ = p.readProxyHeaderOnce() + if p.remoteAddr != nil { + return p.remoteAddr + } + return p.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail instead of blocking. The deadline applies to all future +// and pending I/O, not just the immediately following call to +// Read or Write. After a deadline has been exceeded, the +// connection can be refreshed by setting a deadline in the future. +// +// If the deadline is exceeded a call to Read or Write or to other +// I/O methods will return an error that wraps os.ErrDeadlineExceeded. +// This can be tested using errors.Is(err, os.ErrDeadlineExceeded). +// The error's Timeout method will return true, but note that there +// are other possible errors for which the Timeout method will +// return true even if the deadline has not been exceeded. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful Read or Write calls. +// +// A zero value for t means I/O operations will not time out. +func (p *Conn) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +// SetReadDeadline sets the deadline for future Read calls +// and any currently-blocked Read call. +// A zero value for t means Read will not time out. +func (p *Conn) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the deadline for future Write calls +// and any currently-blocked Write call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means Write will not time out. +func (p *Conn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} + +// readProxyHeaderOnce will ensure that the proxy header has been read +func (p *Conn) readProxyHeaderOnce() (err error) { + p.once.Do(func() { + if err = p.readProxyHeader(); err != nil && err != io.EOF { + log.Error("Failed to read proxy prefix: %v", err) + p.Close() + p.bufReader = bufio.NewReader(p.conn) + } + }) + return err +} + +func (p *Conn) readProxyHeader() error { + if p.proxyHeaderTimeout != 0 { + readDeadLine := time.Now().Add(p.proxyHeaderTimeout) + _ = p.conn.SetReadDeadline(readDeadLine) + defer func() { + _ = p.conn.SetReadDeadline(time.Time{}) + }() + } + + inp, err := p.bufReader.Peek(v1PrefixLen) + if err != nil { + return err + } + + if bytes.Equal(inp, v1Prefix) { + return p.readV1ProxyHeader() + } + + inp, err = p.bufReader.Peek(v2PrefixLen) + if err != nil { + return err + } + if bytes.Equal(inp, v2Prefix) { + return p.readV2ProxyHeader() + } + + return &ErrBadHeader{inp} +} + +func (p *Conn) readV2ProxyHeader() error { + // The binary header format starts with a constant 12 bytes block containing the + // protocol signature : + // + // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A + // + // Note that this block contains a null byte at the 5th position, so it must not + // be handled as a null-terminated string. + + if _, err := p.bufReader.Discard(v2PrefixLen); err != nil { + // This shouldn't happen as we have already asserted that there should be enough in the buffer + return err + } + + // The next byte (the 13th one) is the protocol version and command. + version, err := p.bufReader.ReadByte() + if err != nil { + return err + } + + // The 14th byte contains the transport protocol and address family.otocol. + familyByte, err := p.bufReader.ReadByte() + if err != nil { + return err + } + + // The 15th and 16th bytes is the address length in bytes in network endian order. + var addressLen uint16 + if err := binary.Read(p.bufReader, binary.BigEndian, &addressLen); err != nil { + return err + } + + // Now handle the version byte: (14th byte). + // The highest four bits contains the version. As of this specification, it must + // always be sent as \x2 and the receiver must only accept this value. + if version>>4 != 0x2 { + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + + // The lowest four bits represents the command : + switch version & 0xf { + case 0x0: + // - \x0 : LOCAL : the connection was established on purpose by the proxy + // without being relayed. The connection endpoints are the sender and the + // receiver. Such connections exist when the proxy sends health-checks to the + // server. The receiver must accept this connection as valid and must use the + // real connection endpoints and discard the protocol block including the + // family which is ignored. + + // We therefore ignore the 14th, 15th and 16th bytes + p.remoteAddr = p.conn.LocalAddr() + p.localAddr = p.conn.RemoteAddr() + return nil + case 0x1: + // - \x1 : PROXY : the connection was established on behalf of another node, + // and reflects the original connection endpoints. The receiver must then use + // the information provided in the protocol block to get original the address. + default: + // - other values are unassigned and must not be emitted by senders. Receivers + // must drop connections presenting unexpected values here. + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + + // Now handle the familyByte byte: (15th byte). + // The highest 4 bits contain the address family, the lowest 4 bits contain the protocol + + // The address family maps to the original socket family without necessarily + // matching the values internally used by the system. It may be one of : + // + // - 0x0 : AF_UNSPEC : the connection is forwarded for an unknown, unspecified + // or unsupported protocol. The sender should use this family when sending + // LOCAL commands or when dealing with unsupported protocol families. The + // receiver is free to accept the connection anyway and use the real endpoint + // addresses or to reject it. The receiver should ignore address information. + // + // - 0x1 : AF_INET : the forwarded connection uses the AF_INET address family + // (IPv4). The addresses are exactly 4 bytes each in network byte order, + // followed by transport protocol information (typically ports). + // + // - 0x2 : AF_INET6 : the forwarded connection uses the AF_INET6 address family + // (IPv6). The addresses are exactly 16 bytes each in network byte order, + // followed by transport protocol information (typically ports). + // + // - 0x3 : AF_UNIX : the forwarded connection uses the AF_UNIX address family + // (UNIX). The addresses are exactly 108 bytes each. + // + // - other values are unspecified and must not be emitted in version 2 of this + // protocol and must be rejected as invalid by receivers. + + // The transport protocol is specified in the lowest 4 bits of the 14th byte : + // + // - 0x0 : UNSPEC : the connection is forwarded for an unknown, unspecified + // or unsupported protocol. The sender should use this family when sending + // LOCAL commands or when dealing with unsupported protocol families. The + // receiver is free to accept the connection anyway and use the real endpoint + // addresses or to reject it. The receiver should ignore address information. + // + // - 0x1 : STREAM : the forwarded connection uses a SOCK_STREAM protocol (eg: + // TCP or UNIX_STREAM). When used with AF_INET/AF_INET6 (TCP), the addresses + // are followed by the source and destination ports represented on 2 bytes + // each in network byte order. + // + // - 0x2 : DGRAM : the forwarded connection uses a SOCK_DGRAM protocol (eg: + // UDP or UNIX_DGRAM). When used with AF_INET/AF_INET6 (UDP), the addresses + // are followed by the source and destination ports represented on 2 bytes + // each in network byte order. + // + // - other values are unspecified and must not be emitted in version 2 of this + // protocol and must be rejected as invalid by receivers. + + if familyByte>>4 == 0x0 || familyByte&0xf == 0x0 { + // - hi 0x0 : AF_UNSPEC : the connection is forwarded for an unknown address type + // or + // - lo 0x0 : UNSPEC : the connection is forwarded for an unspecified protocol + if !p.acceptUnknown { + p.conn.Close() + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + p.remoteAddr = p.conn.LocalAddr() + p.localAddr = p.conn.RemoteAddr() + _, err = p.bufReader.Discard(int(addressLen)) + return err + } + + // other address or protocol + if (familyByte>>4) > 0x3 || (familyByte&0xf) > 0x2 { + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + + // Handle AF_UNIX addresses + if familyByte>>4 == 0x3 { + // - \x31 : UNIX stream : the forwarded connection uses SOCK_STREAM over the + // AF_UNIX protocol family. Address length is 2*108 = 216 bytes. + // - \x32 : UNIX datagram : the forwarded connection uses SOCK_DGRAM over the + // AF_UNIX protocol family. Address length is 2*108 = 216 bytes. + if addressLen != 216 { + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + remoteName := make([]byte, 108) + localName := make([]byte, 108) + if _, err := p.bufReader.Read(remoteName); err != nil { + return err + } + if _, err := p.bufReader.Read(localName); err != nil { + return err + } + protocol := "unix" + if familyByte&0xf == 2 { + protocol = "unixgram" + } + + p.remoteAddr = &net.UnixAddr{ + Name: string(remoteName), + Net: protocol, + } + p.localAddr = &net.UnixAddr{ + Name: string(localName), + Net: protocol, + } + return nil + } + + var remoteIP []byte + var localIP []byte + var remotePort uint16 + var localPort uint16 + + if familyByte>>4 == 0x1 { + // AF_INET + // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + if addressLen != 12 { + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + + remoteIP = make([]byte, 4) + localIP = make([]byte, 4) + } else { + // AF_INET6 + // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + if addressLen != 36 { + return &ErrBadHeader{append(v2Prefix, version, familyByte, uint8(addressLen>>8), uint8(addressLen&0xff))} + } + + remoteIP = make([]byte, 16) + localIP = make([]byte, 16) + } + + if _, err := p.bufReader.Read(remoteIP); err != nil { + return err + } + if _, err := p.bufReader.Read(localIP); err != nil { + return err + } + if err := binary.Read(p.bufReader, binary.BigEndian, &remotePort); err != nil { + return err + } + if err := binary.Read(p.bufReader, binary.BigEndian, &localPort); err != nil { + return err + } + + if familyByte&0xf == 1 { + p.remoteAddr = &net.TCPAddr{ + IP: remoteIP, + Port: int(remotePort), + } + p.localAddr = &net.TCPAddr{ + IP: localIP, + Port: int(localPort), + } + } else { + p.remoteAddr = &net.UDPAddr{ + IP: remoteIP, + Port: int(remotePort), + } + p.localAddr = &net.UDPAddr{ + IP: localIP, + Port: int(localPort), + } + } + return nil +} + +func (p *Conn) readV1ProxyHeader() error { + // Read until a newline + header, err := p.bufReader.ReadString('\n') + if err != nil { + p.conn.Close() + return err + } + + if header[len(header)-2] != '\r' { + return &ErrBadHeader{[]byte(header)} + } + + // Strip the carriage return and new line + header = header[:len(header)-2] + + // Split on spaces, should be (PROXY <type> <remote addr> <local addr> <remote port> <local port>) + parts := strings.Split(header, " ") + if len(parts) < 2 { + p.conn.Close() + return &ErrBadHeader{[]byte(header)} + } + + // Verify the type is known + switch parts[1] { + case "UNKNOWN": + if !p.acceptUnknown || len(parts) != 2 { + p.conn.Close() + return &ErrBadHeader{[]byte(header)} + } + p.remoteAddr = p.conn.LocalAddr() + p.localAddr = p.conn.RemoteAddr() + return nil + case "TCP4": + case "TCP6": + default: + p.conn.Close() + return &ErrBadAddressType{parts[1]} + } + + if len(parts) != 6 { + p.conn.Close() + return &ErrBadHeader{[]byte(header)} + } + + // Parse out the remote address + ip := net.ParseIP(parts[2]) + if ip == nil { + p.conn.Close() + return &ErrBadRemote{parts[2], parts[4]} + } + port, err := strconv.Atoi(parts[4]) + if err != nil { + p.conn.Close() + return &ErrBadRemote{parts[2], parts[4]} + } + p.remoteAddr = &net.TCPAddr{IP: ip, Port: port} + + // Parse out the destination address + ip = net.ParseIP(parts[3]) + if ip == nil { + p.conn.Close() + return &ErrBadLocal{parts[3], parts[5]} + } + port, err = strconv.Atoi(parts[5]) + if err != nil { + p.conn.Close() + return &ErrBadLocal{parts[3], parts[5]} + } + p.localAddr = &net.TCPAddr{IP: ip, Port: port} + + return nil +} diff --git a/modules/proxyprotocol/errors.go b/modules/proxyprotocol/errors.go new file mode 100644 index 0000000000..2acf9d84b0 --- /dev/null +++ b/modules/proxyprotocol/errors.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package proxyprotocol + +import "fmt" + +// ErrBadHeader is an error demonstrating a bad proxy header +type ErrBadHeader struct { + Header []byte +} + +func (e *ErrBadHeader) Error() string { + return fmt.Sprintf("Unexpected proxy header: %v", e.Header) +} + +// ErrBadAddressType is an error demonstrating a bad proxy header with bad Address type +type ErrBadAddressType struct { + Address string +} + +func (e *ErrBadAddressType) Error() string { + return fmt.Sprintf("Unexpected proxy header address type: %s", e.Address) +} + +// ErrBadRemote is an error demonstrating a bad proxy header with bad Remote +type ErrBadRemote struct { + IP string + Port string +} + +func (e *ErrBadRemote) Error() string { + return fmt.Sprintf("Unexpected proxy header remote IP and port: %s %s", e.IP, e.Port) +} + +// ErrBadLocal is an error demonstrating a bad proxy header with bad Local +type ErrBadLocal struct { + IP string + Port string +} + +func (e *ErrBadLocal) Error() string { + return fmt.Sprintf("Unexpected proxy header local IP and port: %s %s", e.IP, e.Port) +} diff --git a/modules/proxyprotocol/listener.go b/modules/proxyprotocol/listener.go new file mode 100644 index 0000000000..64d9b323e5 --- /dev/null +++ b/modules/proxyprotocol/listener.go @@ -0,0 +1,47 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package proxyprotocol + +import ( + "net" + "time" +) + +// Listener is used to wrap an underlying listener, +// whose connections may be using the HAProxy Proxy Protocol (version 1 or 2). +// If the connection is using the protocol, the RemoteAddr() will return +// the correct client address. +// +// Optionally define ProxyHeaderTimeout to set a maximum time to +// receive the Proxy Protocol Header. Zero means no timeout. +type Listener struct { + Listener net.Listener + ProxyHeaderTimeout time.Duration + AcceptUnknown bool // allow PROXY UNKNOWN +} + +// Accept implements the Accept method in the Listener interface +// it waits for the next call and returns a wrapped Conn. +func (p *Listener) Accept() (net.Conn, error) { + // Get the underlying connection + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + + newConn := NewConn(conn, p.ProxyHeaderTimeout) + newConn.acceptUnknown = p.AcceptUnknown + return newConn, nil +} + +// Close closes the underlying listener. +func (p *Listener) Close() error { + return p.Listener.Close() +} + +// Addr returns the underlying listener's network address. +func (p *Listener) Addr() net.Addr { + return p.Listener.Addr() +} diff --git a/modules/proxyprotocol/util.go b/modules/proxyprotocol/util.go new file mode 100644 index 0000000000..b12771b686 --- /dev/null +++ b/modules/proxyprotocol/util.go @@ -0,0 +1,15 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package proxyprotocol + +import "io" + +var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00', '\x00') + +// WriteLocalHeader will write the ProxyProtocol Header for a local connection to the provided writer +func WriteLocalHeader(w io.Writer) error { + _, err := w.Write(localHeader) + return err +} diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 3ab25fef6e..931b6523ea 100644 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -94,45 +94,52 @@ var ( LocalURL string // Server settings - Protocol Scheme - Domain string - HTTPAddr string - HTTPPort string - RedirectOtherPort bool - PortToRedirect string - OfflineMode bool - CertFile string - KeyFile string - StaticRootPath string - StaticCacheTime time.Duration - EnableGzip bool - LandingPageURL LandingPage - LandingPageCustom string - UnixSocketPermission uint32 - EnablePprof bool - PprofDataPath string - EnableAcme bool - AcmeTOS bool - AcmeLiveDirectory string - AcmeEmail string - AcmeURL string - AcmeCARoot string - SSLMinimumVersion string - SSLMaximumVersion string - SSLCurvePreferences []string - SSLCipherSuites []string - GracefulRestartable bool - GracefulHammerTime time.Duration - StartupTimeout time.Duration - PerWriteTimeout = 30 * time.Second - PerWritePerKbTimeout = 10 * time.Second - StaticURLPrefix string - AbsoluteAssetURL string + Protocol Scheme + UseProxyProtocol bool // `ini:"USE_PROXY_PROTOCOL"` + ProxyProtocolTLSBridging bool //`ini:"PROXY_PROTOCOL_TLS_BRIDGING"` + ProxyProtocolHeaderTimeout time.Duration + ProxyProtocolAcceptUnknown bool + Domain string + HTTPAddr string + HTTPPort string + LocalUseProxyProtocol bool + RedirectOtherPort bool + RedirectorUseProxyProtocol bool + PortToRedirect string + OfflineMode bool + CertFile string + KeyFile string + StaticRootPath string + StaticCacheTime time.Duration + EnableGzip bool + LandingPageURL LandingPage + LandingPageCustom string + UnixSocketPermission uint32 + EnablePprof bool + PprofDataPath string + EnableAcme bool + AcmeTOS bool + AcmeLiveDirectory string + AcmeEmail string + AcmeURL string + AcmeCARoot string + SSLMinimumVersion string + SSLMaximumVersion string + SSLCurvePreferences []string + SSLCipherSuites []string + GracefulRestartable bool + GracefulHammerTime time.Duration + StartupTimeout time.Duration + PerWriteTimeout = 30 * time.Second + PerWritePerKbTimeout = 10 * time.Second + StaticURLPrefix string + AbsoluteAssetURL string SSH = struct { Disabled bool `ini:"DISABLE_SSH"` StartBuiltinServer bool `ini:"START_SSH_SERVER"` BuiltinServerUser string `ini:"BUILTIN_SSH_SERVER_USER"` + UseProxyProtocol bool `ini:"SSH_SERVER_USE_PROXY_PROTOCOL"` Domain string `ini:"SSH_DOMAIN"` Port int `ini:"SSH_PORT"` User string `ini:"SSH_USER"` @@ -717,6 +724,10 @@ func loadFromConf(allowEmpty bool, extraConfig string) { HTTPAddr = filepath.Join(AppWorkPath, HTTPAddr) } } + UseProxyProtocol = sec.Key("USE_PROXY_PROTOCOL").MustBool(false) + ProxyProtocolTLSBridging = sec.Key("PROXY_PROTOCOL_TLS_BRIDGING").MustBool(false) + ProxyProtocolHeaderTimeout = sec.Key("PROXY_PROTOCOL_HEADER_TIMEOUT").MustDuration(5 * time.Second) + ProxyProtocolAcceptUnknown = sec.Key("PROXY_PROTOCOL_ACCEPT_UNKNOWN").MustBool(false) GracefulRestartable = sec.Key("ALLOW_GRACEFUL_RESTARTS").MustBool(true) GracefulHammerTime = sec.Key("GRACEFUL_HAMMER_TIME").MustDuration(60 * time.Second) StartupTimeout = sec.Key("STARTUP_TIMEOUT").MustDuration(0 * time.Second) @@ -770,8 +781,10 @@ func loadFromConf(allowEmpty bool, extraConfig string) { } LocalURL = sec.Key("LOCAL_ROOT_URL").MustString(defaultLocalURL) LocalURL = strings.TrimRight(LocalURL, "/") + "/" + LocalUseProxyProtocol = sec.Key("LOCAL_USE_PROXY_PROTOCOL").MustBool(UseProxyProtocol) RedirectOtherPort = sec.Key("REDIRECT_OTHER_PORT").MustBool(false) PortToRedirect = sec.Key("PORT_TO_REDIRECT").MustString("80") + RedirectorUseProxyProtocol = sec.Key("REDIRECTOR_USE_PROXY_PROTOCOL").MustBool(UseProxyProtocol) OfflineMode = sec.Key("OFFLINE_MODE").MustBool() DisableRouterLog = sec.Key("DISABLE_ROUTER_LOG").MustBool() if len(StaticRootPath) == 0 { @@ -836,6 +849,7 @@ func loadFromConf(allowEmpty bool, extraConfig string) { SSH.KeygenPath = sec.Key("SSH_KEYGEN_PATH").MustString("ssh-keygen") SSH.Port = sec.Key("SSH_PORT").MustInt(22) SSH.ListenPort = sec.Key("SSH_LISTEN_PORT").MustInt(SSH.Port) + SSH.UseProxyProtocol = sec.Key("SSH_SERVER_USE_PROXY_PROTOCOL").MustBool(false) // When disable SSH, start builtin server value is ignored. if SSH.Disabled { diff --git a/modules/ssh/ssh_graceful.go b/modules/ssh/ssh_graceful.go index 9b91baf09e..166ea0b982 100644 --- a/modules/ssh/ssh_graceful.go +++ b/modules/ssh/ssh_graceful.go @@ -17,7 +17,7 @@ func listen(server *ssh.Server) { gracefulServer.PerWriteTimeout = setting.SSH.PerWriteTimeout gracefulServer.PerWritePerKbTimeout = setting.SSH.PerWritePerKbTimeout - err := gracefulServer.ListenAndServe(server.Serve) + err := gracefulServer.ListenAndServe(server.Serve, setting.SSH.UseProxyProtocol) if err != nil { select { case <-graceful.GetManager().IsShutdown(): |