From f5753c6a8cac5a57a042b0388f38abeff5d1f37d Mon Sep 17 00:00:00 2001 From: radhitya Date: Thu, 18 Jun 2026 12:42:29 +0700 Subject: migration to new dns library --- internal/resolver/resolver.go | 91 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 17 deletions(-) (limited to 'internal/resolver/resolver.go') diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 9ee9cd6..5aa7bc1 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -7,7 +7,7 @@ import ( "net" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" ) var ( @@ -26,16 +26,15 @@ type Resolver struct { type Option func(*Resolver) func New(opts ...Option) *Resolver { + transport := dns.NewTransport() + transport.ReadTimeout = 2 * time.Second + r := &Resolver{ roots: loadRootServers(), maxDelegations: 30, timeout: 2 * time.Second, retries: 2, - client: &dns.Client{ - Net: "udp", - UDPSize: 4096, - Timeout: 2 * time.Second, - }, + client: &dns.Client{Transport: transport}, } for _, opt := range opts { opt(r) @@ -52,6 +51,7 @@ func WithRootAddresses(addrs []string) Option { func WithTimeout(d time.Duration) Option { return func(r *Resolver) { r.timeout = d + r.client.Transport.ReadTimeout = d } } @@ -81,6 +81,9 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: + if needsCNAMEResolution(reply, qtype) { + return r.resolveCNAME(ctx, reply, qtype) + } return reply, nil case reply.Rcode == dns.RcodeNameError: return reply, nil @@ -104,6 +107,62 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn return nil, ErrMaxDelegations } +func needsCNAMEResolution(reply *dns.Msg, qtype uint16) bool { + hasCNAME := false + hasTarget := false + for _, rr := range reply.Answer { + if _, ok := rr.(*dns.CNAME); ok { + hasCNAME = true + } + if dns.RRToType(rr) == qtype && dns.RRToType(rr) != dns.TypeCNAME { + hasTarget = true + } + } + return hasCNAME && !hasTarget +} +func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint16) (*dns.Msg, error) { + var target string + for _, rr := range reply.Answer { + if cn, ok := rr.(*dns.CNAME); ok { + target = cn.Target + } + } + if target == "" { + return reply, nil + } + + const maxChain = 10 + seen := map[string]bool{reply.Question[0].Header().Name: true} + for i := 0; i < maxChain; i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if seen[target] { + break + } + seen[target] = true + + chainReply, err := r.resolve(ctx, target, qtype) + if err != nil { + return reply, nil + } + reply.Answer = append(reply.Answer, chainReply.Answer...) + + nextTarget := "" + for _, rr := range chainReply.Answer { + if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { + nextTarget = cn.Target + } + } + if nextTarget == "" { + break + } + target = nextTarget + } + return reply, nil +} func isReferral(msg *dns.Msg) bool { return !msg.Authoritative && len(msg.Ns) > 0 } @@ -122,12 +181,12 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err for _, rr := range msg.Extra { switch v := rr.(type) { case *dns.A: - if _, exists := glue[v.Hdr.Name]; !exists { - glue[v.Hdr.Name] = v.A.String() + if _, exists := glue[v.Header().Name]; !exists { + glue[v.Header().Name] = v.A.Addr.String() } case *dns.AAAA: - if _, exists := glue[v.Hdr.Name]; !exists { - glue[v.Hdr.Name] = v.AAAA.String() + if _, exists := glue[v.Header().Name]; !exists { + glue[v.Header().Name] = v.AAAA.Addr.String() } } } @@ -156,7 +215,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err } for _, rr := range reply.Answer { if a, ok := rr.(*dns.A); ok { - addrs = append(addrs, a.A.String()) + addrs = append(addrs, a.A.Addr.String()) break } } @@ -171,9 +230,8 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16) (*dns.Msg, error) { - msg := new(dns.Msg) - msg.SetQuestion(qname, qtype) - msg.SetEdns0(4096, false) + msg := dns.NewMsg(qname, qtype) + msg.UDPSize = 4096 msg.RecursionDesired = false type result struct { @@ -189,11 +247,10 @@ func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, } go func(addr string) { for attempt := 0; attempt < r.retries; attempt++ { - reply, _, err := r.client.ExchangeContext(ctx, msg, addr) + reply, _, err := r.client.Exchange(ctx, msg, "udp", addr) if err == nil { if reply.Truncated { - tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout} - reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr) + reply, _, err = r.client.Exchange(ctx, msg, "tcp", addr) if err == nil { ch <- result{reply: reply, err: nil} return -- cgit v1.2.3