// Package resolver performs iterative DNS resolution: it starts at a set of
// root servers, sends non-recursive queries and follows referrals until it
// reaches a final reply.
package resolver

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/miekg/dns"
)

var (
	// ErrMaxDelegations is returned when a lookup (including any nested
	// nameserver lookups it triggers) uses up its delegation budget before
	// reaching a final reply.
	ErrMaxDelegations = errors.New("max delegations exceeded")

	// ErrNoServers is returned when there is no nameserver left to query.
	ErrNoServers = errors.New("no nameservers available")
)

// Resolver resolves names iteratively. Construct it with New.
type Resolver struct {
	// roots are the addresses queried at the start of every lookup, written
	// as "host" or "host:port".
	roots []string
	// maxDelegations is the number of query rounds allowed per top-level
	// lookup, nested nameserver lookups included.
	maxDelegations int
	// timeout bounds a single exchange; it is also used for the TCP retry
	// of truncated replies.
	timeout time.Duration
	// retries is the number of attempts made against each server per round.
	retries int
	// client is the UDP client used for queries.
	client *dns.Client
}

// Option configures a Resolver. Options are passed to New.
type Option func(*Resolver)

// New returns a Resolver that starts from the addresses returned by
// loadRootServers. It allows 30 delegations per lookup, uses a 2s timeout
// and 2 attempts per server, and queries over UDP with a 4096-byte buffer.
// Options are applied in order after these defaults are set.
func New(opts ...Option) *Resolver {
	r := &Resolver{
		roots:          loadRootServers(),
		maxDelegations: 30,
		timeout:        2 * time.Second,
		retries:        2,
		client: &dns.Client{
			Net:     "udp",
			UDPSize: 4096,
			Timeout: 2 * time.Second,
		},
	}
	for _, opt := range opts {
		opt(r)
	}
	return r
}

// WithRootAddresses replaces the nameservers every lookup starts from.
// Entries may be "host" or "host:port"; a missing port defaults to 53.
// The slice is copied, so later changes by the caller do not affect the
// Resolver.
func WithRootAddresses(addrs []string) Option {
	return func(r *Resolver) {
		r.roots = append([]string(nil), addrs...)
	}
}

// WithTimeout sets the per-exchange timeout. It applies both to the UDP
// client and to the TCP client used to retry truncated replies.
func WithTimeout(d time.Duration) Option {
	return func(r *Resolver) {
		r.timeout = d
		// New creates the UDP client before running options, so the client
		// keeps its own copy of the timeout. It must be updated here or the
		// option would only affect the TCP fallback.
		if r.client != nil {
			r.client.Timeout = d
		}
	}
}

// Lookup resolves qname/qtype iteratively, starting at the root servers.
// qname does not need a trailing dot; it is made fully qualified first.
// A nil ctx is treated as context.Background().
//
// On success the returned message is the final server reply. It is one of:
//   - a reply carrying answers;
//   - an NXDOMAIN reply;
//   - a successful reply that has no answers and is not a referral;
//   - an error-rcode reply from the last remaining server.
//
// A non-nil error means no such reply was obtained. Possible causes are
// ctx cancellation, ErrNoServers, ErrMaxDelegations, or a wrapped exchange
// failure.
func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	// miekg/dns cannot pack a question whose name is not fully qualified.
	return r.resolve(ctx, dns.Fqdn(qname), qtype)
}

// delegationBudgetKey is the context key under which resolve stores the
// delegation counter shared by a lookup and its nested nameserver lookups.
type delegationBudgetKey struct{}

// resolve walks the delegation chain for qname/qtype starting at r.roots.
//
// Each query round consumes one unit of a budget of r.maxDelegations. The
// outermost call creates the budget and stores it in ctx. nextServers gets
// that ctx and calls resolve to find addresses for nameservers that came
// without glue, so nested lookups draw from the same budget. Without this,
// a cycle of glueless delegations would recurse without bound, because each
// nested call used to start a fresh counter.
//
// The counter is not synchronized. This relies on nested lookups running
// sequentially on the calling goroutine, which is what nextServers does.
func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	budget, ok := ctx.Value(delegationBudgetKey{}).(*int)
	if !ok {
		remaining := r.maxDelegations
		budget = &remaining
		ctx = context.WithValue(ctx, delegationBudgetKey{}, budget)
	}

	servers := make([]string, len(r.roots))
	copy(servers, r.roots)

	for *budget > 0 {
		*budget--

		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		if len(servers) == 0 {
			return nil, ErrNoServers
		}

		reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype)
		if err != nil {
			return nil, fmt.Errorf("resolve %s %s: %w", qname, dns.TypeToString[qtype], err)
		}

		switch {
		case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
			return reply, nil
		case reply.Rcode == dns.RcodeNameError:
			return reply, nil
		case reply.Rcode == dns.RcodeSuccess && isReferral(reply):
			next, err := r.nextServers(ctx, reply)
			if err != nil {
				// nextServers skips nested lookups that fail and reports only
				// ErrNoServers. If the real cause was cancellation or the shared
				// budget running out, report that instead.
				if ctxErr := ctx.Err(); ctxErr != nil {
					return nil, ctxErr
				}
				if *budget <= 0 {
					return nil, ErrMaxDelegations
				}
				return nil, err
			}
			servers = next
		case reply.Rcode == dns.RcodeSuccess:
			// No answers and not a referral (e.g. an authoritative empty
			// reply): this is final.
			return reply, nil
		default:
			// Error rcode: give up with this reply once no alternatives remain.
			if len(servers) <= 1 {
				return reply, nil
			}
			// NOTE(review): exchangeWithRetries queries every server in the
			// list concurrently. The server dropped here is therefore not
			// necessarily the one that sent this reply.
			servers = servers[1:]
		}
	}
	return nil, ErrMaxDelegations
}

// isReferral reports whether msg looks like a delegation. That means the
// server did not answer authoritatively and the authority section is
// non-empty.
// NOTE(review): a non-authoritative reply whose authority section holds only
// an SOA also matches. nextServers then finds no NS records and returns
// ErrNoServers.
func isReferral(msg *dns.Msg) bool {
	return !msg.Authoritative &&
len(msg.Ns) > 0 } func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, error) { var targets []string glue := make(map[string]string) for _, rr := range msg.Ns { ns, ok := rr.(*dns.NS) if !ok { continue } targets = append(targets, ns.Ns) } for _, rr := range msg.Extra { switch v := rr.(type) { case *dns.A: if _, exists := glue[v.Hdr.Name]; !exists { glue[v.Hdr.Name] = v.A.String() } case *dns.AAAA: if _, exists := glue[v.Hdr.Name]; !exists { glue[v.Hdr.Name] = v.AAAA.String() } } } var addrs []string var unresolved []string for _, t := range targets { if ip, ok := glue[t]; ok { addrs = append(addrs, ip) } else { unresolved = append(unresolved, t) } } if len(addrs) > 0 { return addrs, nil } for _, name := range unresolved { reply, err := r.resolve(ctx, name, dns.TypeA) if err != nil { continue } if reply.Rcode != dns.RcodeSuccess { continue } for _, rr := range reply.Answer { if a, ok := rr.(*dns.A); ok { addrs = append(addrs, a.A.String()) break } } } if len(addrs) == 0 { return nil, ErrNoServers } return addrs, nil } func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16) (*dns.Msg, error) { msg := new(dns.Msg) msg.SetQuestion(qname, qtype) msg.SetEdns0(4096, false) msg.RecursionDesired = false type result struct { reply *dns.Msg err error } ch := make(chan result, len(servers)) for _, srv := range servers { addr := srv if _, _, err := net.SplitHostPort(srv); err != nil { addr = net.JoinHostPort(srv, "53") } go func(addr string) { for attempt := 0; attempt < r.retries; attempt++ { reply, _, err := r.client.ExchangeContext(ctx, msg, addr) if err == nil { if reply.Truncated { tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout} reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr) if err == nil { ch <- result{reply: reply, err: nil} return } } ch <- result{reply: reply, err: nil} return } delay := time.Duration(1<