summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/caddyserver/certmagic/dnsutil.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/caddyserver/certmagic/dnsutil.go')
-rw-r--r--vendor/github.com/caddyserver/certmagic/dnsutil.go339
1 files changed, 339 insertions, 0 deletions
diff --git a/vendor/github.com/caddyserver/certmagic/dnsutil.go b/vendor/github.com/caddyserver/certmagic/dnsutil.go
new file mode 100644
index 0000000000..85f7714a80
--- /dev/null
+++ b/vendor/github.com/caddyserver/certmagic/dnsutil.go
@@ -0,0 +1,339 @@
+package certmagic
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Code in this file adapted from go-acme/lego, July 2020:
+// https://github.com/go-acme/lego
+// by Ludovic Fernandez and Dominik Menke
+//
+// It has been modified.
+
+// findZoneByFQDN determines the zone apex for the given fqdn by recursing
+// up the domain labels until the nameserver returns a SOA record in the
+// answer section.
+func findZoneByFQDN(fqdn string, nameservers []string) (string, error) {
+ if !strings.HasSuffix(fqdn, ".") {
+ fqdn += "."
+ }
+ soa, err := lookupSoaByFqdn(fqdn, nameservers)
+ if err != nil {
+ return "", err
+ }
+ return soa.zone, nil
+}
+
+func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
+ if !strings.HasSuffix(fqdn, ".") {
+ fqdn += "."
+ }
+
+ fqdnSOACacheMu.Lock()
+ defer fqdnSOACacheMu.Unlock()
+
+ // prefer cached version if fresh
+ if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() {
+ return ent, nil
+ }
+
+ ent, err := fetchSoaByFqdn(fqdn, nameservers)
+ if err != nil {
+ return nil, err
+ }
+
+ // save result to cache, but don't allow
+ // the cache to grow out of control
+ if len(fqdnSOACache) >= 1000 {
+ for key := range fqdnSOACache {
+ delete(fqdnSOACache, key)
+ break
+ }
+ }
+ fqdnSOACache[fqdn] = ent
+
+ return ent, nil
+}
+
+func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) {
+ var err error
+ var in *dns.Msg
+
+ labelIndexes := dns.Split(fqdn)
+ for _, index := range labelIndexes {
+ domain := fqdn[index:]
+
+ in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
+ if err != nil {
+ continue
+ }
+ if in == nil {
+ continue
+ }
+
+ switch in.Rcode {
+ case dns.RcodeSuccess:
+ // Check if we got a SOA RR in the answer section
+ if len(in.Answer) == 0 {
+ continue
+ }
+
+ // CNAME records cannot/should not exist at the root of a zone.
+ // So we skip a domain when a CNAME is found.
+ if dnsMsgContainsCNAME(in) {
+ continue
+ }
+
+ for _, ans := range in.Answer {
+ if soa, ok := ans.(*dns.SOA); ok {
+ return newSoaCacheEntry(soa), nil
+ }
+ }
+ case dns.RcodeNameError:
+ // NXDOMAIN
+ default:
+ // Any response code other than NOERROR and NXDOMAIN is treated as error
+ return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
+ }
+ }
+
+ return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
+}
+
+// dnsMsgContainsCNAME checks for a CNAME answer in msg
+func dnsMsgContainsCNAME(msg *dns.Msg) bool {
+ for _, ans := range msg.Answer {
+ if _, ok := ans.(*dns.CNAME); ok {
+ return true
+ }
+ }
+ return false
+}
+
+func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
+ m := createDNSMsg(fqdn, rtype, recursive)
+ var in *dns.Msg
+ var err error
+ for _, ns := range nameservers {
+ in, err = sendDNSQuery(m, ns)
+ if err == nil && len(in.Answer) > 0 {
+ break
+ }
+ }
+ return in, err
+}
+
+func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion(fqdn, rtype)
+ m.SetEdns0(4096, false)
+ if !recursive {
+ m.RecursionDesired = false
+ }
+ return m
+}
+
+func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
+ udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
+ in, _, err := udp.Exchange(m, ns)
+ // two kinds of errors we can handle by retrying with TCP:
+ // truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639
+ truncated := in != nil && in.Truncated
+ timeoutErr := err != nil && strings.Contains(err.Error(), "timeout")
+ if truncated || timeoutErr {
+ tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
+ in, _, err = tcp.Exchange(m, ns)
+ }
+ return in, err
+}
+
+func formatDNSError(msg *dns.Msg, err error) string {
+ var parts []string
+ if msg != nil {
+ parts = append(parts, dns.RcodeToString[msg.Rcode])
+ }
+ if err != nil {
+ parts = append(parts, err.Error())
+ }
+ if len(parts) > 0 {
+ return ": " + strings.Join(parts, " ")
+ }
+ return ""
+}
+
+// soaCacheEntry holds a cached SOA record (only selected fields)
+type soaCacheEntry struct {
+ zone string // zone apex (a domain name)
+ primaryNs string // primary nameserver for the zone apex
+ expires time.Time // time when this cache entry should be evicted
+}
+
+func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry {
+ return &soaCacheEntry{
+ zone: soa.Hdr.Name,
+ primaryNs: soa.Ns,
+ expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second),
+ }
+}
+
+// isExpired checks whether a cache entry should be considered expired.
+func (cache *soaCacheEntry) isExpired() bool {
+ return time.Now().After(cache.expires)
+}
+
+// systemOrDefaultNameservers attempts to get system nameservers from the
+// resolv.conf file given by path before falling back to hard-coded defaults.
+func systemOrDefaultNameservers(path string, defaults []string) []string {
+ config, err := dns.ClientConfigFromFile(path)
+ if err != nil || len(config.Servers) == 0 {
+ return defaults
+ }
+ return config.Servers
+}
+
+// populateNameserverPorts ensures that all nameservers have a port number.
+func populateNameserverPorts(servers []string) {
+ for i := range servers {
+ _, port, _ := net.SplitHostPort(servers[i])
+ if port == "" {
+ servers[i] = net.JoinHostPort(servers[i], "53")
+ }
+ }
+}
+
+// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
+func checkDNSPropagation(fqdn, value string, resolvers []string) (bool, error) {
+ if !strings.HasSuffix(fqdn, ".") {
+ fqdn += "."
+ }
+
+ // Initial attempt to resolve at the recursive NS
+ r, err := dnsQuery(fqdn, dns.TypeTXT, resolvers, true)
+ if err != nil {
+ return false, err
+ }
+
+ // TODO: make this configurable, maybe
+ // if !p.requireCompletePropagation {
+ // return true, nil
+ // }
+
+ if r.Rcode == dns.RcodeSuccess {
+ fqdn = updateDomainWithCName(r, fqdn)
+ }
+
+ authoritativeNss, err := lookupNameservers(fqdn, resolvers)
+ if err != nil {
+ return false, err
+ }
+
+ return checkAuthoritativeNss(fqdn, value, authoritativeNss)
+}
+
+// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
+func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
+ for _, ns := range nameservers {
+ r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
+ if err != nil {
+ return false, err
+ }
+
+ if r.Rcode != dns.RcodeSuccess {
+ if r.Rcode == dns.RcodeNameError {
+ // if Present() succeeded, then it must show up eventually, or else
+ // something is really broken in the DNS provider or their API;
+ // no need for error here, simply have the caller try again
+ return false, nil
+ }
+ return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
+ }
+
+ var found bool
+ for _, rr := range r.Answer {
+ if txt, ok := rr.(*dns.TXT); ok {
+ record := strings.Join(txt.Txt, "")
+ if record == value {
+ found = true
+ break
+ }
+ }
+ }
+
+ if !found {
+ return false, nil
+ }
+ }
+
+ return true, nil
+}
+
+// lookupNameservers returns the authoritative nameservers for the given fqdn.
+func lookupNameservers(fqdn string, resolvers []string) ([]string, error) {
+ var authoritativeNss []string
+
+ zone, err := findZoneByFQDN(fqdn, resolvers)
+ if err != nil {
+ return nil, fmt.Errorf("could not determine the zone: %w", err)
+ }
+
+ r, err := dnsQuery(zone, dns.TypeNS, resolvers, true)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, rr := range r.Answer {
+ if ns, ok := rr.(*dns.NS); ok {
+ authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
+ }
+ }
+
+ if len(authoritativeNss) > 0 {
+ return authoritativeNss, nil
+ }
+ return nil, errors.New("could not determine authoritative nameservers")
+}
+
+// Update FQDN with CNAME if any
+func updateDomainWithCName(r *dns.Msg, fqdn string) string {
+ for _, rr := range r.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok {
+ if cn.Hdr.Name == fqdn {
+ return cn.Target
+ }
+ }
+ }
+ return fqdn
+}
+
+// recursiveNameservers are used to pre-check DNS propagation. It
+// prepends user-configured nameservers (custom) to the defaults
+// obtained from resolv.conf and defaultNameservers and ensures
+// that all server addresses have a port value.
+func recursiveNameservers(custom []string) []string {
+ servers := append(custom, systemOrDefaultNameservers(defaultResolvConf, defaultNameservers)...)
+ populateNameserverPorts(servers)
+ return servers
+}
+
+var defaultNameservers = []string{
+ "8.8.8.8:53",
+ "8.8.4.4:53",
+ "1.1.1.1:53",
+ "1.0.0.1:53",
+}
+
+var dnsTimeout = 10 * time.Second
+
+var (
+ fqdnSOACache = map[string]*soaCacheEntry{}
+ fqdnSOACacheMu sync.Mutex
+)
+
+const defaultResolvConf = "/etc/resolv.conf"