diff options
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r-- | vendor/github.com/miekg/dns/server.go | 120 |
1 files changed, 92 insertions, 28 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go index 3cf1a02401..b2a63bda49 100644 --- a/vendor/github.com/miekg/dns/server.go +++ b/vendor/github.com/miekg/dns/server.go @@ -72,13 +72,22 @@ type response struct { tsigStatus error tsigRequestMAC string tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used + udp net.PacketConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right + pcSession net.Addr // address to use when writing to a generic net.PacketConn writer Writer // writer to output the raw DNS bits } +// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets. +func handleRefused(w ResponseWriter, r *Msg) { + m := new(Msg) + m.SetRcode(r, RcodeRefused) + w.WriteMsg(m) +} + // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. +// Deprecated: This function is going away. func HandleFailed(w ResponseWriter, r *Msg) { m := new(Msg) m.SetRcode(r, RcodeServerFailure) @@ -139,12 +148,24 @@ type Reader interface { ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) } -// defaultReader is an adapter for the Server struct that implements the Reader interface -// using the readTCP and readUDP func of the embedded Server. +// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns. +type PacketConnReader interface { + Reader + + // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may + // alter connection properties, for example the read-deadline. + ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) +} + +// defaultReader is an adapter for the Server struct that implements the Reader and +// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs +// of the embedded Server. type defaultReader struct { *Server } +var _ PacketConnReader = defaultReader{} + func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { return dr.readTCP(conn, timeout) } @@ -153,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt return dr.readUDP(conn, timeout) } +func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + return dr.readPacketConn(conn, timeout) +} + // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. // Implementations should never return a nil Reader. +// Readers should also implement the optional PacketConnReader interface. +// PacketConnReader is required to use a generic net.PacketConn. type DecorateReader func(Reader) Reader // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. @@ -294,6 +321,7 @@ func (srv *Server) ListenAndServe() error { } u := l.(*net.UDPConn) if e := setUDPSocketOptions(u); e != nil { + u.Close() return e } srv.PacketConn = l @@ -317,24 +345,22 @@ func (srv *Server) ActivateAndServe() error { srv.init() - pConn := srv.PacketConn - l := srv.Listener - if pConn != nil { + if srv.PacketConn != nil { // Check PacketConn interface's type is valid and value // is not nil - if t, ok := pConn.(*net.UDPConn); ok && t != nil { + if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil { if e := setUDPSocketOptions(t); e != nil { return e } - srv.started = true - unlock() - return srv.serveUDP(t) } + srv.started = true + unlock() + return srv.serveUDP(srv.PacketConn) } - if l != nil { + if srv.Listener != nil { srv.started = true unlock() - return srv.serveTCP(l) + return srv.serveTCP(srv.Listener) } return &Error{err: "bad listeners"} } @@ -438,18 +464,24 @@ func (srv *Server) serveTCP(l net.Listener) error { } // serveUDP starts a UDP listener for the server. -func (srv *Server) serveUDP(l *net.UDPConn) error { +func (srv *Server) serveUDP(l net.PacketConn) error { defer l.Close() - if srv.NotifyStartedFunc != nil { - srv.NotifyStartedFunc() - } - reader := Reader(defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } + lUDP, isUDP := l.(*net.UDPConn) + readerPC, canPacketConn := reader.(PacketConnReader) + if !isUDP && !canPacketConn { + return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"} + } + + if srv.NotifyStartedFunc != nil { + srv.NotifyStartedFunc() + } + var wg sync.WaitGroup defer func() { wg.Wait() @@ -459,7 +491,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout := srv.getReadTimeout() // deadline is not used here for srv.isStarted() { - m, s, err := reader.ReadUDP(l, rtimeout) + var ( + m []byte + sPC net.Addr + sUDP *SessionUDP + err error + ) + if isUDP { + m, sUDP, err = reader.ReadUDP(lUDP, rtimeout) + } else { + m, sPC, err = readerPC.ReadPacketConn(l, rtimeout) + } if err != nil { if !srv.isStarted() { return nil @@ -476,7 +518,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { continue } wg.Add(1) - go srv.serveUDPPacket(&wg, m, l, s) + go srv.serveUDPPacket(&wg, m, l, sUDP, sPC) } return nil @@ -538,8 +580,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { } // Serve a new UDP request. -func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} +func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -651,6 +693,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S return m, s, nil } +func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + srv.lock.RLock() + if srv.started { + // See the comment in readTCP above. + conn.SetReadDeadline(time.Now().Add(timeout)) + } + srv.lock.RUnlock() + + m := srv.udpPool.Get().([]byte) + n, addr, err := conn.ReadFrom(m) + if err != nil { + srv.udpPool.Put(m) + return nil, nil, err + } + m = m[:n] + return m, addr, nil +} + // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { if w.closed { @@ -684,17 +744,19 @@ func (w *response) Write(m []byte) (int, error) { switch { case w.udp != nil: - return WriteToSessionUDP(w.udp, m, w.udpSession) + if u, ok := w.udp.(*net.UDPConn); ok { + return WriteToSessionUDP(u, m, w.udpSession) + } + return w.udp.WriteTo(m, w.pcSession) case w.tcp != nil: if len(m) > MaxMsgSize { return 0, &Error{err: "message too large"} } - l := make([]byte, 2) - binary.BigEndian.PutUint16(l, uint16(len(m))) - - n, err := (&net.Buffers{l, m}).WriteTo(w.tcp) - return int(n), err + msg := make([]byte, 2+len(m)) + binary.BigEndian.PutUint16(msg, uint16(len(m))) + copy(msg[2:], m) + return w.tcp.Write(msg) default: panic("dns: internal error: udp and tcp both nil") } @@ -717,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr { switch { case w.udpSession != nil: return w.udpSession.RemoteAddr() + case w.pcSession != nil: + return w.pcSession case w.tcp != nil: return w.tcp.RemoteAddr() default: - panic("dns: internal error: udpSession and tcp both nil") + panic("dns: internal error: udpSession, pcSession and tcp are all nil") } } |