From 6a8d30db1bf649514c7266114ab5b2aa89e1629e Mon Sep 17 00:00:00 2001
From: Jean-Philippe Ruijs
Date: Mon, 15 Jul 2024 09:17:33 +0000
Subject: [PATCH] fix: dns rfc compliant ttl

Cache DNS-SD responses for the smallest TTL found in their answer and
authority sections, so that refreshes falling inside that window are
served from memory instead of querying the resolver again.

---
Review notes:
- getMinimumTTL returns 0 for a nil or record-less response instead of the
  uint32 maximum (about 136 years); such responses are not cached.
- The cache stores and hands out copies (dns.Msg.Copy), because building a
  target group trims SRV/MX/NS names in place and that must not alter the
  cached records.
- getMinimumTTL walks Answer and Ns separately rather than appending to
  msg.Answer.
- Hunks that only deleted existing comments, or removed the newline at the
  end of the file, were dropped.
- The post-image blob id on the "index" line is stale; regenerate the patch
  with git format-patch before sending.

 discovery/dns/dns.go | 124 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 116 insertions(+), 8 deletions(-)

diff --git a/discovery/dns/dns.go b/discovery/dns/dns.go
index 314c3d38cd..6bf64bd316 100644
--- a/discovery/dns/dns.go
+++ b/discovery/dns/dns.go
@@ -50,6 +50,84 @@ const (
 	namespace = "prometheus"
 )
 
+// DNSCache stores DNS responses keyed by query name and serves them until
+// the smallest TTL among their records has elapsed.
+type DNSCache struct {
+	mu    sync.RWMutex
+	cache map[string]*dnsCacheEntry
+}
+
+// dnsCacheEntry is a cached response together with the instant at which it
+// stops being valid.
+type dnsCacheEntry struct {
+	msg    *dns.Msg
+	expiry time.Time
+}
+
+// NewDNSCache returns an empty, ready-to-use DNSCache.
+func NewDNSCache() *DNSCache {
+	return &DNSCache{
+		cache: make(map[string]*dnsCacheEntry),
+	}
+}
+
+// Get returns a copy of the unexpired response cached for name, if any.
+// A copy is handed out because callers rewrite record fields in place
+// (for example trimming the trailing dot from SRV targets), and those edits
+// must not leak into later cache hits.
+func (c *DNSCache) Get(name string) (*dns.Msg, bool) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	entry, ok := c.cache[name]
+	if !ok || time.Now().After(entry.expiry) {
+		return nil, false
+	}
+	return entry.msg.Copy(), true
+}
+
+// Set caches a copy of msg under name for the duration of its minimum TTL.
+// A response without a usable TTL (nil, no records, or a zero TTL) is not
+// cached, and any entry previously stored for name is dropped.
+func (c *DNSCache) Set(name string, msg *dns.Msg) {
+	ttl := getMinimumTTL(msg)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if ttl <= 0 {
+		delete(c.cache, name)
+		return
+	}
+	c.cache[name] = &dnsCacheEntry{
+		msg:    msg.Copy(),
+		expiry: time.Now().Add(ttl),
+	}
+}
+
+// getMinimumTTL returns the smallest TTL found in the answer and authority
+// sections of msg, or 0 when msg is nil or carries no records there.
+func getMinimumTTL(msg *dns.Msg) time.Duration {
+	if msg == nil {
+		return 0
+	}
+
+	var (
+		minTTL uint32
+		found  bool
+	)
+	// Walk the sections separately: append(msg.Answer, msg.Ns...) could write
+	// into spare capacity of msg.Answer's backing array.
+	for _, section := range [][]dns.RR{msg.Answer, msg.Ns} {
+		for _, record := range section {
+			if ttl := record.Header().Ttl; !found || ttl < minTTL {
+				minTTL, found = ttl, true
+			}
+		}
+	}
+	return time.Duration(minTTL) * time.Second
+}
+
 // DefaultSDConfig is the default DNS SD configuration.
 var DefaultSDConfig = SDConfig{
 	RefreshInterval: model.Duration(30 * time.Second),
@@ -113,6 +191,7 @@ type Discovery struct {
 	qtype   uint16
 	logger  log.Logger
 	metrics *dnsMetrics
+	cache   *DNSCache
 
 	lookupFn func(name string, qtype uint16, logger log.Logger) (*dns.Msg, error)
 }
@@ -148,6 +227,7 @@ func NewDiscovery(conf SDConfig, logger log.Logger, metrics discovery.Discoverer
 		logger:   logger,
 		lookupFn: lookupWithSearchPath,
 		metrics:  m,
+		cache:    NewDNSCache(),
 	}
 
 	d.Discovery = refresh.NewDiscovery(
@@ -192,6 +272,23 @@ func (d *Discovery) refresh(ctx context.Context) ([]*targetgroup.Group, error) {
 }
 
 func (d *Discovery) refreshOne(ctx context.Context, name string, ch chan<- *targetgroup.Group) error {
+	// Serve the refresh from the cache while the entry is within its TTL.
+	cachedMsg, found := d.cache.Get(name)
+	if found {
+		d.metrics.dnsSDLookupsCount.Inc()
+		tg, err := d.dnsMsgToTargetGroup(ctx, name, cachedMsg)
+		if err != nil {
+			d.metrics.dnsSDLookupFailuresCount.Inc()
+			return err
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case ch <- tg:
+		}
+		return nil
+	}
+
 	response, err := d.lookupFn(name, d.qtype, d.logger)
 	d.metrics.dnsSDLookupsCount.Inc()
 	if err != nil {
@@ -199,12 +296,29 @@ func (d *Discovery) refreshOne(ctx context.Context, name string, ch chan<- *targ
 		return err
 	}
 
+	d.cache.Set(name, response)
+	tg, err := d.dnsMsgToTargetGroup(ctx, name, response)
+	if err != nil {
+		d.metrics.dnsSDLookupFailuresCount.Inc()
+		return err
+	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case ch <- tg:
+	}
+	return nil
+}
+
+// dnsMsgToTargetGroup builds a target group from the records in msg.Answer
+// and sets its Source to name.
+func (d *Discovery) dnsMsgToTargetGroup(ctx context.Context, name string, msg *dns.Msg) (*targetgroup.Group, error) {
 	tg := &targetgroup.Group{}
 	hostPort := func(a string, p int) model.LabelValue {
 		return model.LabelValue(net.JoinHostPort(a, strconv.Itoa(p)))
 	}
 
-	for _, record := range response.Answer {
+	for _, record := range msg.Answer {
 		var target, dnsSrvRecordTarget, dnsSrvRecordPort, dnsMxRecordTarget, dnsNsRecordTarget model.LabelValue
 
 		switch addr := record.(type) {
@@ -252,13 +366,7 @@ func (d *Discovery) refreshOne(ctx context.Context, name string, ch chan<- *targ
 	}
 
 	tg.Source = name
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case ch <- tg:
-	}
-
-	return nil
+	return tg, nil
 }
 
 // lookupWithSearchPath tries to get an answer for various permutations of