diff options
Diffstat (limited to 'vendor/github.com/caddyserver/certmagic/dnsutil.go')
-rw-r--r-- | vendor/github.com/caddyserver/certmagic/dnsutil.go | 339 |
1 files changed, 339 insertions, 0 deletions
diff --git a/vendor/github.com/caddyserver/certmagic/dnsutil.go b/vendor/github.com/caddyserver/certmagic/dnsutil.go new file mode 100644 index 0000000000..85f7714a80 --- /dev/null +++ b/vendor/github.com/caddyserver/certmagic/dnsutil.go @@ -0,0 +1,339 @@ +package certmagic + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" + + "github.com/miekg/dns" +) + +// Code in this file adapted from go-acme/lego, July 2020: +// https://github.com/go-acme/lego +// by Ludovic Fernandez and Dominik Menke +// +// It has been modified. + +// findZoneByFQDN determines the zone apex for the given fqdn by recursing +// up the domain labels until the nameserver returns a SOA record in the +// answer section. +func findZoneByFQDN(fqdn string, nameservers []string) (string, error) { + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } + soa, err := lookupSoaByFqdn(fqdn, nameservers) + if err != nil { + return "", err + } + return soa.zone, nil +} + +func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } + + fqdnSOACacheMu.Lock() + defer fqdnSOACacheMu.Unlock() + + // prefer cached version if fresh + if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() { + return ent, nil + } + + ent, err := fetchSoaByFqdn(fqdn, nameservers) + if err != nil { + return nil, err + } + + // save result to cache, but don't allow + // the cache to grow out of control + if len(fqdnSOACache) >= 1000 { + for key := range fqdnSOACache { + delete(fqdnSOACache, key) + break + } + } + fqdnSOACache[fqdn] = ent + + return ent, nil +} + +func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { + var err error + var in *dns.Msg + + labelIndexes := dns.Split(fqdn) + for _, index := range labelIndexes { + domain := fqdn[index:] + + in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true) + if err != nil { + continue + } + if in == nil { + continue + } + + switch in.Rcode { + case dns.RcodeSuccess: + // Check if we got a SOA RR in the answer section + if len(in.Answer) == 0 { + continue + } + + // CNAME records cannot/should not exist at the root of a zone. + // So we skip a domain when a CNAME is found. + if dnsMsgContainsCNAME(in) { + continue + } + + for _, ans := range in.Answer { + if soa, ok := ans.(*dns.SOA); ok { + return newSoaCacheEntry(soa), nil + } + } + case dns.RcodeNameError: + // NXDOMAIN + default: + // Any response code other than NOERROR and NXDOMAIN is treated as error + return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain) + } + } + + return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err)) +} + +// dnsMsgContainsCNAME checks for a CNAME answer in msg +func dnsMsgContainsCNAME(msg *dns.Msg) bool { + for _, ans := range msg.Answer { + if _, ok := ans.(*dns.CNAME); ok { + return true + } + } + return false +} + +func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { + m := createDNSMsg(fqdn, rtype, recursive) + var in *dns.Msg + var err error + for _, ns := range nameservers { + in, err = sendDNSQuery(m, ns) + if err == nil && len(in.Answer) > 0 { + break + } + } + return in, err +} + +func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + if !recursive { + m.RecursionDesired = false + } + return m +} + +func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { + udp := &dns.Client{Net: "udp", Timeout: dnsTimeout} + in, _, err := udp.Exchange(m, ns) + // two kinds of errors we can handle by retrying with TCP: + // truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639 + truncated := in != nil && in.Truncated + timeoutErr := err != nil && strings.Contains(err.Error(), "timeout") + if truncated || timeoutErr { + tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} + in, _, err = tcp.Exchange(m, ns) + } + return in, err +} + +func formatDNSError(msg *dns.Msg, err error) string { + var parts []string + if msg != nil { + parts = append(parts, dns.RcodeToString[msg.Rcode]) + } + if err != nil { + parts = append(parts, err.Error()) + } + if len(parts) > 0 { + return ": " + strings.Join(parts, " ") + } + return "" +} + +// soaCacheEntry holds a cached SOA record (only selected fields) +type soaCacheEntry struct { + zone string // zone apex (a domain name) + primaryNs string // primary nameserver for the zone apex + expires time.Time // time when this cache entry should be evicted +} + +func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry { + return &soaCacheEntry{ + zone: soa.Hdr.Name, + primaryNs: soa.Ns, + expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second), + } +} + +// isExpired checks whether a cache entry should be considered expired. +func (cache *soaCacheEntry) isExpired() bool { + return time.Now().After(cache.expires) +} + +// systemOrDefaultNameservers attempts to get system nameservers from the +// resolv.conf file given by path before falling back to hard-coded defaults. +func systemOrDefaultNameservers(path string, defaults []string) []string { + config, err := dns.ClientConfigFromFile(path) + if err != nil || len(config.Servers) == 0 { + return defaults + } + return config.Servers +} + +// populateNameserverPorts ensures that all nameservers have a port number. +func populateNameserverPorts(servers []string) { + for i := range servers { + _, port, _ := net.SplitHostPort(servers[i]) + if port == "" { + servers[i] = net.JoinHostPort(servers[i], "53") + } + } +} + +// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. +func checkDNSPropagation(fqdn, value string, resolvers []string) (bool, error) { + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } + + // Initial attempt to resolve at the recursive NS + r, err := dnsQuery(fqdn, dns.TypeTXT, resolvers, true) + if err != nil { + return false, err + } + + // TODO: make this configurable, maybe + // if !p.requireCompletePropagation { + // return true, nil + // } + + if r.Rcode == dns.RcodeSuccess { + fqdn = updateDomainWithCName(r, fqdn) + } + + authoritativeNss, err := lookupNameservers(fqdn, resolvers) + if err != nil { + return false, err + } + + return checkAuthoritativeNss(fqdn, value, authoritativeNss) +} + +// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record. +func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { + for _, ns := range nameservers { + r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false) + if err != nil { + return false, err + } + + if r.Rcode != dns.RcodeSuccess { + if r.Rcode == dns.RcodeNameError { + // if Present() succeeded, then it must show up eventually, or else + // something is really broken in the DNS provider or their API; + // no need for error here, simply have the caller try again + return false, nil + } + return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn) + } + + var found bool + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + record := strings.Join(txt.Txt, "") + if record == value { + found = true + break + } + } + } + + if !found { + return false, nil + } + } + + return true, nil +} + +// lookupNameservers returns the authoritative nameservers for the given fqdn. +func lookupNameservers(fqdn string, resolvers []string) ([]string, error) { + var authoritativeNss []string + + zone, err := findZoneByFQDN(fqdn, resolvers) + if err != nil { + return nil, fmt.Errorf("could not determine the zone: %w", err) + } + + r, err := dnsQuery(zone, dns.TypeNS, resolvers, true) + if err != nil { + return nil, err + } + + for _, rr := range r.Answer { + if ns, ok := rr.(*dns.NS); ok { + authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) + } + } + + if len(authoritativeNss) > 0 { + return authoritativeNss, nil + } + return nil, errors.New("could not determine authoritative nameservers") +} + +// Update FQDN with CNAME if any +func updateDomainWithCName(r *dns.Msg, fqdn string) string { + for _, rr := range r.Answer { + if cn, ok := rr.(*dns.CNAME); ok { + if cn.Hdr.Name == fqdn { + return cn.Target + } + } + } + return fqdn +} + +// recursiveNameservers are used to pre-check DNS propagation. It +// prepends user-configured nameservers (custom) to the defaults +// obtained from resolv.conf and defaultNameservers and ensures +// that all server addresses have a port value. +func recursiveNameservers(custom []string) []string { + servers := append(custom, systemOrDefaultNameservers(defaultResolvConf, defaultNameservers)...) + populateNameserverPorts(servers) + return servers +} + +var defaultNameservers = []string{ + "8.8.8.8:53", + "8.8.4.4:53", + "1.1.1.1:53", + "1.0.0.1:53", +} + +var dnsTimeout = 10 * time.Second + +var ( + fqdnSOACache = map[string]*soaCacheEntry{} + fqdnSOACacheMu sync.Mutex +) + +const defaultResolvConf = "/etc/resolv.conf" |