// Package resolver implements a DNS resolver that either forwards queries to
// configured upstream servers or walks delegations itself, starting from a
// list of root server addresses.
package resolver

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"

	"codeberg.org/miekg/dns"
)

var (
	// ErrMaxDelegations is returned when the delegation walk uses up its
	// iteration budget (Resolver.maxDelegations) without a final reply.
	ErrMaxDelegations = errors.New("max delegations exceeded")
	// ErrNoServers is returned when there is no server left to query, or when
	// a referral yields no usable nameserver address.
	ErrNoServers = errors.New("no nameservers available")
)

// Resolver answers DNS queries by forwarding or by iterative resolution.
type Resolver struct {
	roots          []string      // servers the delegation walk starts from
	maxDelegations int           // upper bound on iterations of one delegation walk
	timeout        time.Duration // configured timeout; mirrored into the client's transport ReadTimeout
	retries        int           // attempts per server in exchangeWithRetries
	forwarders     []string      // when non-empty, Lookup forwards instead of resolving
	client         *dns.Client   // client used for every exchange
}

// Option configures a Resolver created by New.
type Option func(*Resolver)

// New returns a Resolver with root servers from loadRootServers, a limit of
// 30 delegations, a 2 second timeout and 2 retries, and then applies opts in
// order.
func New(opts ...Option) *Resolver {
	transport := dns.NewTransport()
	transport.ReadTimeout = 2 * time.Second
	r := &Resolver{
		roots:          loadRootServers(),
		maxDelegations: 30,
		timeout:        2 * time.Second,
		retries:        2,
		client:         &dns.Client{Transport: transport},
	}
	for _, opt := range opts {
		opt(r)
	}
	return r
}

// WithForwarders makes the Resolver forward every query to addrs instead of
// resolving iteratively. The slice is copied, so later changes by the caller
// do not affect the Resolver.
func WithForwarders(addrs []string) Option {
	return func(r *Resolver) {
		r.forwarders = append([]string(nil), addrs...)
	}
}

// WithRootAddresses replaces the root servers the Resolver starts from. The
// slice is copied, so later changes by the caller do not affect the Resolver.
func WithRootAddresses(addrs []string) Option {
	return func(r *Resolver) {
		r.roots = append([]string(nil), addrs...)
	}
}

// WithTimeout sets the Resolver's timeout and the read timeout of its
// client's transport to d.
func WithTimeout(d time.Duration) Option {
	return func(r *Resolver) {
		r.timeout = d
		r.client.Transport.ReadTimeout = d
	}
}

// Lookup resolves qname/qtype. A nil ctx is replaced by context.Background.
// If forwarders are configured the query is forwarded to them; otherwise it is
// resolved iteratively from the root servers.
func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if len(r.forwarders) > 0 {
		return r.forward(ctx, qname, qtype)
	}
	return r.resolve(ctx, qname, qtype)
}

// forward sends the query to the configured forwarders with recursion desired
// and returns the reply, wrapping any exchange error with the query name and
// type.
func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	servers := make([]string, len(r.forwarders))
	copy(servers, r.forwarders)
	reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true)
	if err != nil {
		return nil, fmt.Errorf("forward %s %s: %w", qname, dns.TypeToString[qtype], err)
	}
	return reply, nil
}

// resolve resolves qname/qtype iteratively from the root servers and, when
// the final answer holds only CNAME records for a non-CNAME query, follows the
// alias chain via resolveCNAME.
func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	reply, err := r.walkDelegations(ctx, qname, qtype)
	if err != nil {
		return nil, err
	}
	if reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0 && needsCNAMEResolution(reply, qtype) {
		return r.resolveCNAME(ctx, reply, qtype)
	}
	return reply, nil
}

// walkDelegations performs one iterative resolution of qname/qtype starting
// at the root servers and following referrals. It never follows CNAMEs; that
// is left to the caller so that alias-chain bounds and loop detection live in
// exactly one place (resolveCNAME).
//
// It returns ErrNoServers when the server list is empty, ErrMaxDelegations
// when the iteration budget is exhausted, and ctx.Err() when ctx is done.
func (r *Resolver) walkDelegations(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
	servers := make([]string, len(r.roots))
	copy(servers, r.roots)
	for depth := 0; depth < r.maxDelegations; depth++ {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		if len(servers) == 0 {
			return nil, ErrNoServers
		}
		reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, false)
		if err != nil {
			return nil, fmt.Errorf("resolve %s %s: %w", qname, dns.TypeToString[qtype], err)
		}
		switch {
		case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
			return reply, nil
		case reply.Rcode == dns.RcodeNameError:
			return reply, nil
		case reply.Rcode == dns.RcodeSuccess && isReferral(reply):
			next, err := r.nextServers(ctx, reply)
			if err != nil {
				return nil, err
			}
			servers = next
		case reply.Rcode == dns.RcodeSuccess:
			// Success with neither answers nor a referral: hand it back as is.
			return reply, nil
		default:
			// Any other rcode: drop the first server and try again, or give
			// up and return the reply when no alternative is left.
			if len(servers) <= 1 {
				return reply, nil
			}
			servers = servers[1:]
		}
	}
	return nil, ErrMaxDelegations
}

// needsCNAMEResolution reports whether reply's answer section contains a
// CNAME but no record of type qtype, i.e. the alias still has to be followed.
// Queries for CNAME itself are never restarted (RFC 1034, section 3.6.2).
func needsCNAMEResolution(reply *dns.Msg, qtype uint16) bool {
	if qtype == dns.TypeCNAME {
		return false
	}
	hasCNAME := false
	for _, rr := range reply.Answer {
		if _, ok := rr.(*dns.CNAME); ok {
			hasCNAME = true
			continue
		}
		if dns.RRToType(rr) == qtype {
			return false
		}
	}
	return hasCNAME
}

// resolveCNAME follows the alias found in reply, one hop per iteration, and
// appends the answers of every hop to reply.Answer. The walk ends when a hop
// yields a record of type qtype, yields no further CNAME, revisits a name, or
// after maxChain hops.
//
// A failed hop is treated as best effort and the records gathered so far are
// returned, except when ctx is done, in which case ctx.Err() is returned.
func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint16) (*dns.Msg, error) {
	// The last CNAME in the answer section decides where to continue.
	var target string
	for _, rr := range reply.Answer {
		if cn, ok := rr.(*dns.CNAME); ok {
			target = cn.Target
		}
	}
	if target == "" {
		return reply, nil
	}

	const maxChain = 10
	seen := make(map[string]bool, maxChain+1)
	if len(reply.Question) > 0 {
		// Seed with the original query name so an alias pointing back at it
		// is recognised as a loop.
		seen[reply.Question[0].Header().Name] = true
	}

	for i := 0; i < maxChain; i++ {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		if seen[target] {
			break
		}
		seen[target] = true

		// walkDelegations does not follow CNAMEs, so every hop passes through
		// this loop and is subject to seen and maxChain.
		chainReply, err := r.walkDelegations(ctx, target, qtype)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, ctxErr
			}
			return reply, nil
		}
		reply.Answer = append(reply.Answer, chainReply.Answer...)
		if !needsCNAMEResolution(chainReply, qtype) {
			break
		}

		nextTarget := ""
		for _, rr := range chainReply.Answer {
			if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
				nextTarget = cn.Target
			}
		}
		if nextTarget == "" {
			break
		}
		target = nextTarget
	}
	return reply, nil
}

// isReferral reports whether msg is non-authoritative and carries records in
// its authority section.
func isReferral(msg *dns.Msg) bool {
	return !msg.Authoritative && len(msg.Ns) > 0
}

// nextServers derives the servers to query next from a referral. It collects
// the NS targets from msg.Ns and glue addresses from the A/AAAA records in
// msg.Extra (the first address seen per name wins). If any NS target has glue,
// only those addresses are returned. Otherwise each NS name is resolved for an
// A record and the first address per name is used. ErrNoServers is returned
// when no address can be found.
func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, error) {
	var targets []string
	for _, rr := range msg.Ns {
		if ns, ok := rr.(*dns.NS); ok {
			targets = append(targets, ns.Ns)
		}
	}

	glue := make(map[string]string)
	for _, rr := range msg.Extra {
		switch v := rr.(type) {
		case *dns.A:
			if _, exists := glue[v.Header().Name]; !exists {
				glue[v.Header().Name] = v.A.Addr.String()
			}
		case *dns.AAAA:
			if _, exists := glue[v.Header().Name]; !exists {
				glue[v.Header().Name] = v.AAAA.Addr.String()
			}
		}
	}

	var addrs []string
	var unresolved []string
	for _, t := range targets {
		if ip, ok := glue[t]; ok {
			addrs = append(addrs, ip)
		} else {
			unresolved = append(unresolved, t)
		}
	}
	if len(addrs) > 0 {
		return addrs, nil
	}

	// No glue at all: look up the nameserver names themselves.
	for _, name := range unresolved {
		reply, err := r.resolve(ctx, name, dns.TypeA)
		if err != nil || reply.Rcode != dns.RcodeSuccess {
			continue
		}
		for _, rr := range reply.Answer {
			if a, ok := rr.(*dns.A); ok {
				addrs = append(addrs, a.A.Addr.String())
				break
			}
		}
	}
	if len(addrs) == 0 {
		return nil, ErrNoServers
	}
	return addrs, nil
}

func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16, rd bool) (*dns.Msg, error) {
	msg := dns.NewMsg(qname, qtype)
	msg.UDPSize = 4096
	msg.RecursionDesired = rd
	type result struct {
		reply *dns.Msg
		err   error
	}
	ch := make(chan result, len(servers))
	for _, srv := range servers {
		addr := srv
		if _, _, err := net.SplitHostPort(srv); err != nil {
			addr = net.JoinHostPort(srv, "53")
		}
		go func(addr string) {
			for attempt := 0; attempt < r.retries; attempt++ {
				reply, _, err := r.client.Exchange(ctx, msg, "udp", addr)
				if err == nil {
					if reply.Truncated {
						reply, _, err =
r.client.Exchange(ctx, msg, "tcp", addr) if err == nil { ch <- result{reply: reply, err: nil} return } } ch <- result{reply: reply, err: nil} return } delay := time.Duration(1<