aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r--vendor/github.com/miekg/dns/server.go120
1 files changed, 92 insertions, 28 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go
index 3cf1a02401..b2a63bda49 100644
--- a/vendor/github.com/miekg/dns/server.go
+++ b/vendor/github.com/miekg/dns/server.go
@@ -72,13 +72,22 @@ type response struct {
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
- udp *net.UDPConn // i/o connection if UDP was used
+ udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
+ pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}
+// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
+func handleRefused(w ResponseWriter, r *Msg) {
+ m := new(Msg)
+ m.SetRcode(r, RcodeRefused)
+ w.WriteMsg(m)
+}
+
// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
+// Deprecated: This function is going away.
func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
@@ -139,12 +148,24 @@ type Reader interface {
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
-// defaultReader is an adapter for the Server struct that implements the Reader interface
-// using the readTCP and readUDP func of the embedded Server.
+// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
+type PacketConnReader interface {
+ Reader
+
+ // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
+ // alter connection properties, for example the read-deadline.
+ ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
+}
+
+// defaultReader is an adapter for the Server struct that implements the Reader and
+// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
+// of the embedded Server.
type defaultReader struct {
*Server
}
+var _ PacketConnReader = defaultReader{}
+
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
}
@@ -153,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
return dr.readUDP(conn, timeout)
}
+func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+ return dr.readPacketConn(conn, timeout)
+}
+
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// Implementations should never return a nil Reader.
+// Readers should also implement the optional PacketConnReader interface.
+// PacketConnReader is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
@@ -294,6 +321,7 @@ func (srv *Server) ListenAndServe() error {
}
u := l.(*net.UDPConn)
if e := setUDPSocketOptions(u); e != nil {
+ u.Close()
return e
}
srv.PacketConn = l
@@ -317,24 +345,22 @@ func (srv *Server) ActivateAndServe() error {
srv.init()
- pConn := srv.PacketConn
- l := srv.Listener
- if pConn != nil {
+ if srv.PacketConn != nil {
// Check PacketConn interface's type is valid and value
// is not nil
- if t, ok := pConn.(*net.UDPConn); ok && t != nil {
+ if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
if e := setUDPSocketOptions(t); e != nil {
return e
}
- srv.started = true
- unlock()
- return srv.serveUDP(t)
}
+ srv.started = true
+ unlock()
+ return srv.serveUDP(srv.PacketConn)
}
- if l != nil {
+ if srv.Listener != nil {
srv.started = true
unlock()
- return srv.serveTCP(l)
+ return srv.serveTCP(srv.Listener)
}
return &Error{err: "bad listeners"}
}
@@ -438,18 +464,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
}
// serveUDP starts a UDP listener for the server.
-func (srv *Server) serveUDP(l *net.UDPConn) error {
+func (srv *Server) serveUDP(l net.PacketConn) error {
defer l.Close()
- if srv.NotifyStartedFunc != nil {
- srv.NotifyStartedFunc()
- }
-
reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
+ lUDP, isUDP := l.(*net.UDPConn)
+ readerPC, canPacketConn := reader.(PacketConnReader)
+ if !isUDP && !canPacketConn {
+ return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
+ }
+
+ if srv.NotifyStartedFunc != nil {
+ srv.NotifyStartedFunc()
+ }
+
var wg sync.WaitGroup
defer func() {
wg.Wait()
@@ -459,7 +491,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
rtimeout := srv.getReadTimeout()
// deadline is not used here
for srv.isStarted() {
- m, s, err := reader.ReadUDP(l, rtimeout)
+ var (
+ m []byte
+ sPC net.Addr
+ sUDP *SessionUDP
+ err error
+ )
+ if isUDP {
+ m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
+ } else {
+ m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
+ }
if err != nil {
if !srv.isStarted() {
return nil
@@ -476,7 +518,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
continue
}
wg.Add(1)
- go srv.serveUDPPacket(&wg, m, l, s)
+ go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
}
return nil
@@ -538,8 +580,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
}
// Serve a new UDP request.
-func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
- w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
+func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
+ w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
@@ -651,6 +693,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
return m, s, nil
}
+func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+ srv.lock.RLock()
+ if srv.started {
+ // See the comment in readTCP above.
+ conn.SetReadDeadline(time.Now().Add(timeout))
+ }
+ srv.lock.RUnlock()
+
+ m := srv.udpPool.Get().([]byte)
+ n, addr, err := conn.ReadFrom(m)
+ if err != nil {
+ srv.udpPool.Put(m)
+ return nil, nil, err
+ }
+ m = m[:n]
+ return m, addr, nil
+}
+
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
@@ -684,17 +744,19 @@ func (w *response) Write(m []byte) (int, error) {
switch {
case w.udp != nil:
- return WriteToSessionUDP(w.udp, m, w.udpSession)
+ if u, ok := w.udp.(*net.UDPConn); ok {
+ return WriteToSessionUDP(u, m, w.udpSession)
+ }
+ return w.udp.WriteTo(m, w.pcSession)
case w.tcp != nil:
if len(m) > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
- l := make([]byte, 2)
- binary.BigEndian.PutUint16(l, uint16(len(m)))
-
- n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
- return int(n), err
+ msg := make([]byte, 2+len(m))
+ binary.BigEndian.PutUint16(msg, uint16(len(m)))
+ copy(msg[2:], m)
+ return w.tcp.Write(msg)
default:
panic("dns: internal error: udp and tcp both nil")
}
@@ -717,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
switch {
case w.udpSession != nil:
return w.udpSession.RemoteAddr()
+ case w.pcSession != nil:
+ return w.pcSession
case w.tcp != nil:
return w.tcp.RemoteAddr()
default:
- panic("dns: internal error: udpSession and tcp both nil")
+ panic("dns: internal error: udpSession, pcSession and tcp are all nil")
}
}