package resolver import ( "context" "errors" "fmt" "net" "time" "github.com/miekg/dns" ) var ( ErrMaxDelegations = errors.New("max delegations exceeded") ErrNoServers = errors.New("no nameservers available") ) type Resolver struct { roots []string maxDelegations int timeout time.Duration retries int } type Option func(*Resolver) func New(opts ...Option) *Resolver { r := &Resolver{ roots: loadRootServers(), maxDelegations: 30, timeout: 2 * time.Second, retries: 2, } for _, opt := range opts { opt(r) } return r } func WithRootAddresses(addrs []string) Option { return func(r *Resolver) { r.roots = addrs } } func WithTimeout(d time.Duration) Option { return func(r *Resolver) { r.timeout = d } } func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { if ctx == nil { ctx = context.Background() } return r.resolve(ctx, qname, qtype) } func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { servers := make([]string, len(r.roots)) copy(servers, r.roots) for depth := 0; depth < r.maxDelegations; depth++ { select { case <-ctx.Done(): return nil, ctx.Err() default: } if len(servers) == 0 { return nil, ErrNoServers } reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype) if err != nil { return nil, fmt.Errorf("resolve %s %s: %w", qname, dns.TypeToString[qtype], err) } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: return reply, nil case reply.Rcode == dns.RcodeNameError: return reply, nil case reply.Rcode == dns.RcodeSuccess && isReferral(reply): next, err := r.nextServers(ctx, reply) if err != nil { return nil, err } servers = next default: if len(servers) > 1 { servers = servers[1:] } else { return reply, nil } } } return nil, ErrMaxDelegations } func isReferral(msg *dns.Msg) bool { return !msg.Authoritative && len(msg.Ns) > 0 } func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, error) { var targets []string glue := make(map[string]string) for _, rr := range msg.Ns { ns, ok := rr.(*dns.NS) if !ok { continue } targets = append(targets, ns.Ns) } for _, rr := range msg.Extra { switch v := rr.(type) { case *dns.A: if _, exists := glue[v.Hdr.Name]; !exists { glue[v.Hdr.Name] = v.A.String() } case *dns.AAAA: if _, exists := glue[v.Hdr.Name]; !exists { glue[v.Hdr.Name] = v.AAAA.String() } } } var addrs []string var unresolved []string for _, t := range targets { if ip, ok := glue[t]; ok { addrs = append(addrs, ip) } else { unresolved = append(unresolved, t) } } if len(addrs) > 0 { return addrs, nil } for _, name := range unresolved { reply, err := r.resolve(ctx, name, dns.TypeA) if err != nil { continue } if reply.Rcode != dns.RcodeSuccess { continue } for _, rr := range reply.Answer { if a, ok := rr.(*dns.A); ok { addrs = append(addrs, a.A.String()) break } } } if len(addrs) == 0 { return nil, ErrNoServers } return addrs, nil } func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16) (*dns.Msg, error) { msg := new(dns.Msg) msg.SetQuestion(qname, qtype) msg.SetEdns0(4096, false) msg.RecursionDesired = false client := &dns.Client{ Net: "udp", UDPSize: 4096, Timeout: r.timeout, } var lastErr error for _, srv := range servers { addr := srv if _, _, err := net.SplitHostPort(srv); err != nil { addr = net.JoinHostPort(srv, "53") } for attempt := 0; attempt < r.retries; attempt++ { reply, _, err := client.ExchangeContext(ctx, msg, addr) if err == nil { return reply, nil } lastErr = err time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond) } } return nil, lastErr }