diff options
| author | radhitya <alif@radhitya.org> | 2026-06-13 16:09:53 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-13 16:09:53 +0700 |
| commit | 3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (patch) | |
| tree | 66932e0f386ba1277506e9d1fb18eaaad70bfef3 /internal/resolver/resolver.go | |
| parent | d802d4a685016be8b79c89b4f21099b9a1569532 (diff) | |
root hints, glue record, delegation loop, iterative, ns fallback, timeout, glue record
Diffstat (limited to 'internal/resolver/resolver.go')
| -rw-r--r-- | internal/resolver/resolver.go | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go new file mode 100644 index 0000000..4ad023a --- /dev/null +++ b/internal/resolver/resolver.go @@ -0,0 +1,192 @@ +package resolver + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/miekg/dns" +) + +var ( + ErrMaxDelegations = errors.New("max delegations exceeded") + ErrNoServers = errors.New("no nameservers available") +) + +type Resolver struct { + roots []string + maxDelegations int + timeout time.Duration + retries int +} + +type Option func(*Resolver) + +func New(opts ...Option) *Resolver { + r := &Resolver{ + roots: loadRootServers(), + maxDelegations: 30, + timeout: 2 * time.Second, + retries: 2, + } + for _, opt := range opts { + opt(r) + } + return r +} + +func WithRootAddresses(addrs []string) Option { + return func(r *Resolver) { + r.roots = addrs + } +} + +func WithTimeout(d time.Duration) Option { + return func(r *Resolver) { + r.timeout = d + } +} + +func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { + if ctx == nil { + ctx = context.Background() + } + return r.resolve(ctx, qname, qtype) +} + +func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { + servers := make([]string, len(r.roots)) + copy(servers, r.roots) + for depth := 0; depth < r.maxDelegations; depth++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if len(servers) == 0 { + return nil, ErrNoServers + } + reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype) + if err != nil { + return nil, fmt.Errorf("resolve %s %s: %w", + qname, dns.TypeToString[qtype], err) + } + switch { + case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: + return reply, nil + case reply.Rcode == dns.RcodeNameError: + return reply, nil + case reply.Rcode == dns.RcodeSuccess && isReferral(reply): + next, err := r.nextServers(ctx, reply) + if err != nil { + return nil, err + } + servers = next + default: + if len(servers) > 1 { + servers = servers[1:] + } else { + return reply, nil + } + } + + } + return nil, ErrMaxDelegations +} + +func isReferral(msg *dns.Msg) bool { + return !msg.Authoritative && len(msg.Ns) > 0 +} + +func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, error) { + var targets []string + glue := make(map[string]string) + + for _, rr := range msg.Ns { + ns, ok := rr.(*dns.NS) + if !ok { + continue + } + targets = append(targets, ns.Ns) + } + for _, rr := range msg.Extra { + switch v := rr.(type) { + case *dns.A: + if _, exists := glue[v.Hdr.Name]; !exists { + glue[v.Hdr.Name] = v.A.String() + } + case *dns.AAAA: + if _, exists := glue[v.Hdr.Name]; !exists { + glue[v.Hdr.Name] = v.AAAA.String() + } + } + } + + var addrs []string + var unresolved []string + for _, t := range targets { + if ip, ok := glue[t]; ok { + addrs = append(addrs, ip) + } else { + unresolved = append(unresolved, t) + } + } + + if len(addrs) > 0 { + return addrs, nil + } + + for _, name := range unresolved { + reply, err := r.resolve(ctx, name, dns.TypeA) + if err != nil { + continue + } + if reply.Rcode != dns.RcodeSuccess { + continue + } + for _, rr := range reply.Answer { + if a, ok := rr.(*dns.A); ok { + addrs = append(addrs, a.A.String()) + break + } + } + } + + if len(addrs) == 0 { + return nil, ErrNoServers + } + return addrs, nil +} + +func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, + qname string, qtype uint16) (*dns.Msg, error) { + msg := new(dns.Msg) + msg.SetQuestion(qname, qtype) + msg.SetEdns0(4096, false) + msg.RecursionDesired = false + + client := &dns.Client{ + Net: "udp", + UDPSize: 4096, + Timeout: r.timeout, + } + + var lastErr error + for _, srv := range servers { + addr := srv + if _, _, err := net.SplitHostPort(srv); err != nil { + addr = net.JoinHostPort(srv, "53") + } + for attempt := 0; attempt < r.retries; attempt++ { + reply, _, err := client.ExchangeContext(ctx, msg, addr) + if err == nil { + return reply, nil + } + lastErr = err + time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond) + } + } + return nil, lastErr +} |
