summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/resolver/resolver.go55
1 files changed, 44 insertions, 11 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 4ad023a..9ee9cd6 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -20,6 +20,7 @@ type Resolver struct {
maxDelegations int
timeout time.Duration
retries int
+ client *dns.Client
}
type Option func(*Resolver)
@@ -30,6 +31,11 @@ func New(opts ...Option) *Resolver {
maxDelegations: 30,
timeout: 2 * time.Second,
retries: 2,
+ client: &dns.Client{
+ Net: "udp",
+ UDPSize: 4096,
+ Timeout: 2 * time.Second,
+ },
}
for _, opt := range opts {
opt(r)
@@ -84,6 +90,8 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
return nil, err
}
servers = next
+ case reply.Rcode == dns.RcodeSuccess:
+ return reply, nil
default:
if len(servers) > 1 {
servers = servers[1:]
@@ -162,31 +170,56 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
qname string, qtype uint16) (*dns.Msg, error) {
+
msg := new(dns.Msg)
msg.SetQuestion(qname, qtype)
msg.SetEdns0(4096, false)
msg.RecursionDesired = false
- client := &dns.Client{
- Net: "udp",
- UDPSize: 4096,
- Timeout: r.timeout,
+ type result struct {
+ reply *dns.Msg
+ err error
}
- var lastErr error
+ ch := make(chan result, len(servers))
for _, srv := range servers {
addr := srv
if _, _, err := net.SplitHostPort(srv); err != nil {
addr = net.JoinHostPort(srv, "53")
}
- for attempt := 0; attempt < r.retries; attempt++ {
- reply, _, err := client.ExchangeContext(ctx, msg, addr)
- if err == nil {
- return reply, nil
+ go func(addr string) {
+ for attempt := 0; attempt < r.retries; attempt++ {
+ reply, _, err := r.client.ExchangeContext(ctx, msg, addr)
+ if err == nil {
+ if reply.Truncated {
+ tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout}
+ reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr)
+ if err == nil {
+ ch <- result{reply: reply, err: nil}
+ return
+ }
+ }
+ ch <- result{reply: reply, err: nil}
+ return
+ }
+ delay := time.Duration(1<<attempt) * 100 * time.Millisecond
+ select {
+ case <-time.After(delay):
+ case <-ctx.Done():
+ ch <- result{nil, ctx.Err()}
+ return
+ }
}
- lastErr = err
- time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond)
+ ch <- result{nil, fmt.Errorf("all retries failed for %s", addr)}
+ }(addr)
+ }
+ var lastErr error
+ for i := 0; i < len(servers); i++ {
+ res := <-ch
+ if res.err == nil {
+ return res.reply, nil
}
+ lastErr = res.err
}
return nil, lastErr
}