diff options
| author | radhitya <alif@radhitya.org> | 2026-06-14 18:48:11 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-14 18:48:11 +0700 |
| commit | a8df5f3b6d18371b8c4cc98d01f4004586a50b21 (patch) | |
| tree | 0e8e384d4fcebdc0034bf8884f51dccfedc53b65 /internal/resolver | |
| parent | d173554892339e5211020c60d6af610840eef7ed (diff) | |
fix iterative dns resolver, tcp fallback
Diffstat (limited to 'internal/resolver')
| -rw-r--r-- | internal/resolver/resolver.go | 55 |
1 files changed, 44 insertions, 11 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 4ad023a..9ee9cd6 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -20,6 +20,7 @@ type Resolver struct { maxDelegations int timeout time.Duration retries int + client *dns.Client } type Option func(*Resolver) @@ -30,6 +31,11 @@ func New(opts ...Option) *Resolver { maxDelegations: 30, timeout: 2 * time.Second, retries: 2, + client: &dns.Client{ + Net: "udp", + UDPSize: 4096, + Timeout: 2 * time.Second, + }, } for _, opt := range opts { opt(r) @@ -84,6 +90,8 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn return nil, err } servers = next + case reply.Rcode == dns.RcodeSuccess: + return reply, nil default: if len(servers) > 1 { servers = servers[1:] @@ -162,31 +170,56 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16) (*dns.Msg, error) { + msg := new(dns.Msg) msg.SetQuestion(qname, qtype) msg.SetEdns0(4096, false) msg.RecursionDesired = false - client := &dns.Client{ - Net: "udp", - UDPSize: 4096, - Timeout: r.timeout, + type result struct { + reply *dns.Msg + err error } - var lastErr error + ch := make(chan result, len(servers)) for _, srv := range servers { addr := srv if _, _, err := net.SplitHostPort(srv); err != nil { addr = net.JoinHostPort(srv, "53") } - for attempt := 0; attempt < r.retries; attempt++ { - reply, _, err := client.ExchangeContext(ctx, msg, addr) - if err == nil { - return reply, nil + go func(addr string) { + for attempt := 0; attempt < r.retries; attempt++ { + reply, _, err := r.client.ExchangeContext(ctx, msg, addr) + if err == nil { + if reply.Truncated { + tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout} + reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr) + if err == nil { + ch <- result{reply: reply, err: nil} + return + } + } + ch <- result{reply: reply, err: nil} + return + } + delay := time.Duration(1<<attempt) * 100 * time.Millisecond + select { + case <-time.After(delay): + case <-ctx.Done(): + ch <- result{nil, ctx.Err()} + return + } } - lastErr = err - time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond) + ch <- result{nil, fmt.Errorf("all retries failed for %s", addr)} + }(addr) + } + var lastErr error + for i := 0; i < len(servers); i++ { + res := <-ch + if res.err == nil { + return res.reply, nil } + lastErr = res.err } return nil, lastErr } |
